Diffusion models are now turbocharging reinforcement learning systems

Image: diffusion world model (generated with Bing Image Creator)

This article is part of our coverage of the latest in AI research.

Diffusion models are best known for their impressive capabilities to generate highly detailed images. They are the main architecture used in popular text-to-image models such as DALL-E, Stable Diffusion, and Midjourney.

However, diffusion models can be used for more than just generating images. A new paper by researchers at Meta, Princeton University, and the University of Texas at Austin shows that diffusion models can help create better reinforcement learning systems.

The study introduces a technique that uses diffusion-based world models to train RL agents. Diffusion World Models (DWM) enhance current model-based RL systems by being able to predict what the environment will look like several steps ahead.

Model-free vs model-based reinforcement learning

Model-free reinforcement learning algorithms learn a policy or value function directly from interactions with the environment without predicting what the environment will look like in the future. In contrast, model-based reinforcement learning algorithms simulate their environments through world models. These models enable them to predict how their actions will affect their environment and adjust their policies accordingly.
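To make the contrast concrete, here is a minimal, hypothetical sketch in Python (Q-learning-style updates, not the paper's code): the model-free version updates its value estimates from real transitions, while the model-based version updates them from transitions imagined by a learned world model.

```python
import numpy as np

# Illustrative sketch only -- `world_model` and its predict() method are
# hypothetical stand-ins, not the paper's implementation.

def model_free_update(q_table, s, a, r, s_next, alpha=0.1, gamma=0.99):
    """Model-free: update the value estimate directly from a real transition."""
    td_target = r + gamma * np.max(q_table[s_next])
    q_table[s, a] += alpha * (td_target - q_table[s, a])

def model_based_update(q_table, world_model, s, a, alpha=0.1, gamma=0.99):
    """Model-based: use a learned world model to imagine the transition instead."""
    s_next, r = world_model.predict(s, a)  # imagined, not real, experience
    td_target = r + gamma * np.max(q_table[s_next])
    q_table[s, a] += alpha * (td_target - q_table[s, a])
```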

One of the key advantages of model-based RL is that it needs fewer data samples from the real environment. This is especially useful for applications such as self-driving cars and robotics, where gathering data from the real world can be costly or risky.

However, model-based reinforcement learning depends heavily on the accuracy of the world model. In practice, inaccuracies in the world model cause model-based RL systems to perform worse than model-free systems.

Traditional world models use one-step dynamics, which means they only predict the reward and next state based on the current state and action. When planning for multiple steps into the future, the RL system invokes the model recursively with its own output. The problem with this approach is that small errors can compound across multiple steps, making long-horizon predictions unreliable and inaccurate.
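As a rough sketch (with hypothetical `model` and `policy` objects, not the paper's code), a recursive rollout with a one-step model looks like this, with each predicted state fed back in as the next input:

```python
def rollout_one_step_model(model, policy, s0, horizon):
    """Roll a one-step dynamics model forward recursively.

    Each predicted state is fed back in as the next input, so any
    prediction error compounds over the horizon.
    """
    states, rewards = [s0], []
    s = s0
    for _ in range(horizon):
        a = policy(s)
        s, r = model.predict(s, a)  # prediction error enters here...
        states.append(s)            # ...and is reused as the next input
        rewards.append(r)
    return states, rewards
```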

The premise of Diffusion World Model (DWM) is to learn to predict multiple future steps in one go. If done correctly, this approach can reduce errors in long-horizon predictions and improve the performance of model-based reinforcement learning algorithms.

How Diffusion World Models work

Diffusion models work based on a simple principle: they learn to generate data by reversing a process that gradually adds noise to it. For example, when trained to generate images, the model gradually adds noise to a training image and then tries to reverse the process and recover the original. By repeating this across many noise levels, it learns to create detailed images from pure noise. Conditional diffusion models add a layer of control by conditioning the model’s output on a specific input, such as the caption that goes with the image. This is what enables you to give these models textual descriptions and receive the corresponding images.
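For readers who want a concrete picture, here is a minimal DDPM-style training step in Python/PyTorch. It is a generic sketch, not the paper's code: the noise schedule, the `denoiser` network, and its `(x_t, t, cond)` signature are assumptions, and `cond` stands for whatever conditioning input is used (for image generation, a caption embedding).

```python
import torch
import torch.nn as nn

# Assumed noise schedule (standard DDPM-style values, not taken from the paper).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def diffusion_loss(denoiser: nn.Module, x0: torch.Tensor, cond: torch.Tensor):
    """One training step: add noise at a random step t, ask the model to predict it."""
    t = torch.randint(0, T, (x0.shape[0],))
    noise = torch.randn_like(x0)
    a_bar = alphas_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise  # forward (noising) process
    pred = denoiser(x_t, t, cond)                          # reverse process: predict the noise
    return nn.functional.mse_loss(pred, noise)
```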

Diffusion steps (image by Benlisquare, own work, CC BY-SA 4.0)

But while diffusion models are best known for their ability to generate high-quality images, they can also be applied to other data types.

Diffusion World Models (DWM) use the same principle to predict long-horizon outcomes in reinforcement learning systems. Instead of text descriptions, DWM is conditioned on the current state, action, and expected return. Its output is multiple steps of states and rewards into the future. 
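A sketch of what conditional sampling from such a model might look like is shown below. The `dwm.denoise_step` method, tensor shapes, and number of denoising steps are illustrative assumptions, not the paper's API: the model starts from pure noise and iteratively denoises a whole block of future states and rewards, conditioned on the current state, action, and target return.

```python
import torch

@torch.no_grad()
def sample_future(dwm, state, action, target_return, horizon, n_steps=50):
    """Hypothetical conditional sampling from a diffusion world model.

    `dwm` is assumed to expose a denoise_step() method that refines a noisy
    future trajectory given the conditioning (state, action, return).
    """
    # Future trajectory tensor: `horizon` rows of [next_state, reward].
    traj = torch.randn(horizon, state.shape[-1] + 1)  # start from pure noise
    for t in reversed(range(n_steps)):
        traj = dwm.denoise_step(traj, t, state, action, target_return)
    states, rewards = traj[:, :-1], traj[:, -1]
    return states, rewards
```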

The DWM framework has two training stages. In the first stage, the diffusion model is trained on trajectories collected from the environment. The result is a strong world model that can predict multiple steps in a single pass, making it more stable than other model-based methods in long-horizon simulation.
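Continuing the earlier sketch, stage one might look roughly like the loop below, where the diffusion loss is applied to segments of offline trajectories conditioned on the starting state, action, and return-to-go. The dataset format and variable names are assumptions for illustration, and `diffusion_loss` is the function from the sketch above.

```python
import torch

def train_dwm(denoiser, dataset, optimizer, epochs=100):
    """Stage-1 sketch: fit the diffusion model on offline trajectory segments.

    `dataset` is assumed to yield mini-batches of (segment, state, action,
    return_to_go), where `segment` stacks the next H states and rewards.
    """
    for _ in range(epochs):
        for segment, state, action, rtg in dataset:
            cond = torch.cat([state, action, rtg], dim=-1)  # conditioning vector
            loss = diffusion_loss(denoiser, segment, cond)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```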

In the second stage, an offline reinforcement learning policy is trained using the actor-critic algorithm and the diffusion world model. Using offline RL eliminates the need for online interactions during training, which increases speed and reduces costs and risks.

For each step, the agent uses the DWM to generate future trajectories and estimate the return of its actions. The researchers call this “Diffusion Model Value Expansion” (Diffusion-MVE). While the RL system uses the DWM during training, the resulting policy is model-free, which has the benefit of faster inference.
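Here is a hedged sketch of the idea, building on the hypothetical `sample_future` function above and an assumed `critic` network: the value target combines the discounted imagined rewards with the critic's estimate of the final imagined state.

```python
import torch

def diffusion_mve_target(dwm, critic, state, action, target_return,
                         horizon=5, gamma=0.99):
    """Illustrative value-expansion target in the spirit of Diffusion-MVE.

    Generate an imagined future with the diffusion world model, sum the
    discounted imagined rewards, and bootstrap with the critic's value of
    the final imagined state. Names and defaults are assumptions.
    """
    states, rewards = sample_future(dwm, state, action, target_return, horizon)
    discounts = torch.tensor([gamma ** i for i in range(horizon)])
    value_target = (discounts * rewards).sum() + gamma ** horizon * critic(states[-1])
    return value_target
```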

“Diffusion-MVE can be interpreted as a value regularization for offline RL through generative modeling, or alternatively, a way to conduct offline Q-learning with synthetic data,” the researchers write.

At a high level, the main idea behind DWM is to predict multiple world states into the future, which means the diffusion model could, in principle, be replaced with another sequence model. The researchers also experimented with transformer-based world models but found DWM to be more efficient.

DWM in action

To test the effectiveness of DWM, the researchers compared it to both model-based and model-free reinforcement learning systems. They experimented with three variants of algorithms and nine locomotion tasks from the D4RL dataset.

Their results show that DWM delivers a 44% performance gain over one-step world models. When model-free RL algorithms are turned into model-based versions with one-step world models, their performance usually degrades. When combined with DWMs, however, they outperform the original model-free versions, the researchers found.

“This is attributed to the strong expressivity of diffusion models and the prediction of entire sequences all at once, which circumvents the compounding error issue in multistep rollout of traditional one-step dynamics models,” the researchers write. “Our method achieves state-of-the-art (SOTA) performance, eliminating the gap between MB and MF algorithms.”

DWM is part of a broader trend of generative models being used for non-generative tasks. In the past year, robotics research has advanced by leaps and bounds thanks to generative models. Language models are helping bridge the gap between natural language commands and robot motion commands. Transformers are also helping researchers bring together data gathered from different morphologies and settings and train models that can generalize across robots and tasks.
