Diffusion models have recently emerged as the de facto standard for generating complex, high-dimensional outputs. You may know them for their ability to produce stunning AI art and hyper-realistic synthetic images, but they have also found success in other applications such as drug design and continuous control. The key idea behind diffusion models is to iteratively transform random noise into a sample, such as an image or protein structure. This is typically motivated as a maximum likelihood estimation problem, where the model is trained to generate samples that match the training data as closely as possible.
However, most use cases of diffusion models are not directly concerned with matching the training data, but instead with a downstream objective. We don’t just want an image that looks like existing images; we want one with a particular type of appearance. We don’t just want a physically plausible drug molecule; we want one that is as effective as possible. This post shows how you can train diffusion models directly on these downstream objectives using reinforcement learning (RL). To do this, we fine-tune Stable Diffusion on a variety of objectives, including image compressibility, human-perceived aesthetic quality, and prompt-image alignment. The last of these objectives uses feedback from a large vision-language model to improve the model’s performance on unusual prompts, demonstrating how powerful AI models can be used to improve each other without any humans in the loop.
Diagram of the prompt-image alignment objective. We use LLaVA, a large vision-language model, to evaluate the generated images.
Denoising diffusion policy optimization
When framing diffusion as an RL problem, we make only the most basic assumptions: given a sample (e.g. an image), we have access to a reward function that we can evaluate to tell us how “good” that sample is. Our goal is for the diffusion model to generate samples that maximize this reward function.
Diffusion models are typically trained using a loss function derived from maximum likelihood estimation (MLE), meaning the model is encouraged to generate samples that make the training data look more likely. In the RL setting, we no longer have training data, only samples from the diffusion model and their associated rewards. One way to keep using the same MLE-motivated loss function is to treat the samples as training data and incorporate the rewards by weighting the loss for each sample by its reward. This gives us an algorithm that we call reward-weighted regression (RWR), after existing algorithms from the RL literature.
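To make the weighting concrete, here is a minimal PyTorch sketch of an RWR-style objective. It assumes a standard epsilon-prediction denoising loss and a batch of samples generated by the current model along with their rewards; the exponential (softmax) weighting and all names here are illustrative rather than the exact formulation from the paper.

```python
import torch

def denoising_loss(model, noisy_latents, timesteps, noise):
    """Standard MLE-style diffusion loss: predict the added noise (one value per sample)."""
    pred = model(noisy_latents, timesteps)
    return ((pred - noise) ** 2).flatten(1).mean(dim=1)  # shape: [batch]

def rwr_loss(model, noisy_latents, timesteps, noise, rewards, temperature=1.0):
    """Reward-weighted regression (sketch): treat self-generated samples as training
    data, but weight each sample's denoising loss by its exponentiated reward."""
    per_sample = denoising_loss(model, noisy_latents, timesteps, noise)
    weights = torch.softmax(rewards / temperature, dim=0)  # higher reward -> larger weight
    return (weights * per_sample).sum()
```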
However, there are a few problems with this approach. One is that RWR is not a particularly exact algorithm: it maximizes the reward only approximately (see Nair et al., Appendix A). The MLE-inspired loss for diffusion is also not exact; it is instead derived from a variational bound on the true likelihood of each sample. This means that RWR maximizes the reward through two levels of approximation, which we find significantly hurts its performance.
We evaluate two variants of DDPO and two variants of RWR for three reward functions and find that DDPO consistently achieves the best performance.
The key insight of our algorithm, which we call denoising diffusion policy optimization (DDPO), is that we can better maximize the reward of the final sample if we pay attention to the entire sequence of denoising steps that produced it. To do this, we reformulate the diffusion process as a multi-step Markov decision process (MDP). In MDP terminology, each denoising step is an action, and the agent only receives a reward on the final step of each denoising trajectory, when the final sample is produced. This framework allows us to apply many powerful algorithms from the RL literature that are designed specifically for multi-step MDPs. Instead of using the approximate likelihood of the final sample, these algorithms use the exact likelihood of each denoising step, which is extremely easy to compute.
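As a rough illustration of this MDP view, the sketch below collects a single denoising trajectory, recording the exact log-probability of each step, and computes the reward only on the final decoded image. The `denoise_step_distribution` and `decode` methods are hypothetical stand-ins for a sampler’s per-step Gaussian and the VAE decoder, not an actual Stable Diffusion API.

```python
import torch

@torch.no_grad()
def collect_trajectory(pipeline, prompt, reward_fn, num_steps=50):
    """One rollout of the denoising MDP. State: (current latent, timestep, prompt).
    Action: the next, slightly less noisy latent. Reward: given only at the end."""
    latents = torch.randn(1, 4, 64, 64)                 # initial state: pure Gaussian noise
    steps = []
    for t in reversed(range(num_steps)):
        # Hypothetical helper: the Gaussian over the next latent at step t.
        dist = pipeline.denoise_step_distribution(latents, t, prompt)
        next_latents = dist.sample()                     # the "action"
        log_prob = dist.log_prob(next_latents).sum()     # exact per-step likelihood
        steps.append((latents, t, next_latents, log_prob))
        latents = next_latents
    image = pipeline.decode(latents)                     # hypothetical latent -> image decode
    reward = reward_fn(image, prompt)                    # reward only on the final sample
    return steps, reward
```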
We chose to apply policy gradient algorithms due to their ease of implementation and past success in language model fine-tuning. This led to two variants of DDPO: DDPO_SF, which uses the simple score function estimator of the policy gradient, also known as REINFORCE; and DDPO_IS, which uses a more powerful importance sampled estimator. DDPO_IS is our best-performing algorithm, and its implementation closely follows that of proximal policy optimization (PPO).
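For concreteness, here are minimal sketches of the two estimators, assuming per-step log-probabilities of shape [batch, num_steps] collected from trajectories like the one above. The advantage normalization and clip range are illustrative placeholders, not the exact hyperparameters from the paper.

```python
import torch

def ddpo_sf_loss(log_probs, rewards):
    """DDPO_SF: score-function (REINFORCE) estimator. Increase the log-likelihood of
    every denoising step in a trajectory in proportion to that trajectory's reward."""
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)  # simple baseline
    return -(advantages[:, None] * log_probs).sum(dim=1).mean()

def ddpo_is_loss(log_probs, old_log_probs, advantages, clip_range=0.1):
    """DDPO_IS: importance-sampled estimator with a PPO-style clip, which lets
    trajectories collected under older parameters be reused for multiple updates."""
    ratio = torch.exp(log_probs - old_log_probs)                      # per-step likelihood ratio
    unclipped = ratio * advantages[:, None]
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages[:, None]
    return -torch.minimum(unclipped, clipped).sum(dim=1).mean()
```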
Fine-tuning Stable Diffusion using DDPO
For the main results, we fine-tuned Stable Diffusion v1-4 using DDPO_IS. We have four tasks, each defined by a different reward function:
- Compressibility: How easy is the image to compress using the JPEG algorithm? The reward is the negative file size of the image (in kB) when saved as a JPEG. (A rough sketch of this reward, along with the prompt-image alignment reward, follows this list.)
- Incompressibility: How hard is the image to compress using the JPEG algorithm? The reward is the positive file size of the image (in kB) when saved as a JPEG.
- Aesthetic quality: How aesthetically appealing is the image to the human eye? The reward is the output of the LAION aesthetic predictor, a neural network trained on human preferences.
- Prompt-Image Alignment: How well does the image represent what was asked for in the prompt? This one is a bit more complicated: we feed the image into LLaVA, ask it to describe the image, and then compute the similarity between that description and the original prompt using BERTScore.
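Below are rough sketches of two of these reward functions: the JPEG-based compressibility reward and the LLaVA-plus-BERTScore alignment reward. The JPEG quality setting, the kilobyte conversion, and the use of BERTScore recall are assumptions for illustration, and the step where LLaVA actually describes the image is omitted.

```python
import io

from PIL import Image
from bert_score import score as bert_score  # pip install bert-score

def compressibility_reward(image: Image.Image, jpeg_quality: int = 95) -> float:
    """Compressibility: negative JPEG file size in kB (smaller file -> higher reward).
    Incompressibility simply flips the sign."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=jpeg_quality)
    return -buffer.tell() / 1000.0

def alignment_reward(llava_description: str, prompt: str) -> float:
    """Prompt-image alignment: similarity between LLaVA's description of the
    generated image (assumed to be computed elsewhere) and the original prompt."""
    precision, recall, f1 = bert_score([llava_description], [prompt], lang="en")
    return recall.item()
```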
Since Stable Diffusion is a text-to-image model, we also need to choose a set of prompts to give it during fine-tuning. For the first three tasks, we use simple prompts of the form “a(n) [animal]”. For prompt-image alignment, we use prompts of the form “a(n) [animal] [activity]”, where the activities are “washing dishes”, “playing chess”, and “riding a bike”. We found that Stable Diffusion often struggled to produce images that matched the prompt for these unusual scenarios, leaving plenty of room for improvement with RL fine-tuning.
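Here is a small illustrative sketch of how such prompts might be sampled during fine-tuning; the three animals below stand in for the full 45-animal list.

```python
import random

ANIMALS = ["cat", "dog", "otter"]  # stand-ins for the full 45-animal list
ACTIVITIES = ["washing dishes", "playing chess", "riding a bike"]

def sample_prompt(with_activity: bool = False) -> str:
    """'a(n) [animal]' for the simple rewards; 'a(n) [animal] [activity]' for alignment."""
    animal = random.choice(ANIMALS)
    article = "an" if animal[0] in "aeiou" else "a"
    prompt = f"{article} {animal}"
    if with_activity:
        prompt += f" {random.choice(ACTIVITIES)}"
    return prompt
```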
We first describe the performance of DDPO on the simple rewards (compressibility, incompressibility, and aesthetic quality). All of the images are generated with the same random seed. The top left quadrant shows what “vanilla” Stable Diffusion generates for nine different animals; all of the RL fine-tuned models show clear qualitative differences. Interestingly, the aesthetic quality model (top right) tends toward minimalist black-and-white line drawings, revealing the kinds of images that the LAION aesthetic predictor considers “more aesthetic”.
Next, we demonstrate DDPO on the more complex prompt-image alignment task. Here are several snapshots from the training process: each series of three images shows samples for the same prompt and random seed over the course of training, with the first sample coming from vanilla Stable Diffusion. Interestingly, the model shifts toward a more cartoon-like style, which was not intended. We hypothesize that this is because animals doing human-like activities are more likely to appear in a cartoon-like style in the pretraining data, so the model leverages what it already knows and shifts toward this style to more easily align with the prompt.
Unexpected generalization
Surprising generalization has been shown to arise when fine-tuning large language models with RL: for example, models fine-tuned on instruction following only in English often improve in other languages as well. We find that the same phenomenon occurs with text-to-image diffusion models. For example, our aesthetic quality model was fine-tuned using prompts selected from a list of 45 common animals. We found that it generalizes not only to unseen animals but also to everyday objects.
Our prompt-image alignment model used the same list of 45 common animals during training, along with only three activities. We found that it generalizes not only to unseen animals, but also to unseen activities, and even to novel combinations of the two.
over-optimization
It is well known that fine-tuning on a reward function, especially a learned one, can lead to reward over-optimization, where the model exploits the reward function to achieve high reward in a way that is not useful. Our setting is no exception: in all of the tasks, the model eventually destroys any meaningful image content in order to maximize reward.
We also found that LLaVA is susceptible to typographic attacks: when optimizing for alignment with respect to prompts of the form “[n] animals”, DDPO was able to successfully fool LLaVA by instead generating text loosely resembling the correct number.
Currently, there is no universal method to prevent overoptimization, highlighting this issue as an important area for future work.
conclusion
Diffusion models are unrivaled at generating complex, high-dimensional outputs. However, so far most of their success has come in applications where the goal is to learn patterns from large amounts of data (e.g., image-caption pairs). What we have found is a way to effectively train diffusion models that goes beyond pattern matching, and that does not necessarily require any training data. The possibilities are limited only by the quality and creativity of your reward function.
The use of DDPO in this work is inspired by recent successes in language model fine-tuning. OpenAI’s GPT models, like Stable Diffusion, are first trained on massive amounts of internet data and then fine-tuned with RL to produce useful tools such as ChatGPT. Their reward functions are typically learned from human preferences, but others have more recently figured out how to produce powerful chatbots using reward functions based on AI feedback instead. Compared to the chatbot regime, our experiments are small in scale and limited in scope. But considering the enormous success of this “pre-training + fine-tuning” paradigm in language modeling, it seems worth pursuing further in the world of diffusion models. We hope that others can build on our work to improve large diffusion models, not only for text-to-image generation, but also for many other interesting applications such as video generation, music generation, image editing, protein synthesis, and robotics.
Moreover, the “pre-training + fine-tuning” paradigm is not the only way to use DDPO. As long as you have a good reward function, there is nothing stopping you from training with RL from the start. This setting is as yet unexplored, but it is where the strengths of DDPO could really shine. Pure RL has long been applied to a wide variety of domains, from game playing to robot manipulation to nuclear fusion to chip design. Adding the expressive power of diffusion models to the mix has the potential to take existing RL applications to the next level, or even to discover new ones.
This post is based on the following paper:
If you want to learn more about DDPO, you can check out the paper, website, original code, or the model weights on Hugging Face. If you want to use DDPO in your own project, check out our PyTorch + LoRA implementation, which can fine-tune Stable Diffusion with less than 10GB of GPU memory!
If DDPO has inspired your work, please cite it:
@misc{black2023ddpo,
  title={Training Diffusion Models with Reinforcement Learning},
  author={Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
  year={2023},
  eprint={2305.13301},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}