akashsonowal/ddpo-pytorch

RLHF for Diffusion Models

Citing: This codebase is a clone of work by Tanishq Abraham. I cloned it for my own learning, but it may be useful if you are looking for somewhat more organized code (modular, with intuitive variable names) than the original repo.

This is an implementation of "Training Diffusion Models with Reinforcement Learning" with only basic features. It currently implements only the LAION aesthetic classifier as a reward function.

Installation

git clone https://github.com/akashsonowal/ddpo-pytorch.git 
cd ddpo-pytorch
pip install -r requirements.txt

Usage

It's as simple as running:

python main.py

To save memory (you'll likely need it), use the arguments --enable_attention_slicing, --enable_xformers_memory_efficient_attention, and --enable_grad_checkpointing.
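As a convenience, the three memory-saving flags above can be wrapped in a small shell function so they do not have to be retyped on every run. This is only a sketch: the name `train_low_memory` is hypothetical and not part of the repo, and it assumes the three flags can be passed together in a single invocation.

```shell
# Hypothetical helper (not part of the repo): runs main.py with all three
# memory-saving flags from the README, assuming they can be combined.
# Any extra arguments are forwarded unchanged to main.py.
train_low_memory() {
  python main.py \
    --enable_attention_slicing \
    --enable_xformers_memory_efficient_attention \
    --enable_grad_checkpointing \
    "$@"
}
```

After defining the function in your shell, calling `train_low_memory` from the repository root is equivalent to typing out `python main.py` with the three flags.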

Results

Original samples: image

After training for 50 episodes: image