Skip to content

chaennlee/Diffusion-Policies-for-Offline-RL

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Diffusion Policies for Offline RL — Official PyTorch Implementation

Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning
Zhendong Wang, Jonathan J Hunt and Mingyuan Zhou
https://arxiv.org/abs/2208.06193

Abstract: Offline reinforcement learning (RL), which aims to learn an optimal policy using a previously collected static dataset, is an important paradigm of RL. Standard RL methods often perform poorly at this task due to the function approximation errors on out-of-distribution actions. While a variety of regularization methods have been proposed to mitigate this issue, they are often constrained by policy classes with limited expressiveness that can lead to highly suboptimal solutions. In this paper, we propose representing the policy as a diffusion model, a recent class of highly-expressive deep generative models. We introduce Diffusion Q-learning (Diffusion-QL) that utilizes a conditional diffusion model for behavior cloning and policy regularization. In our approach, we learn an action-value function and we add a term maximizing action-values into the training loss of the conditional diffusion model, which results in a loss that seeks optimal actions that are near the behavior policy. We show the expressiveness of the diffusion model-based policy, and the coupling of the behavior cloning and policy improvement under the diffusion model both contribute to the outstanding performance of Diffusion-QL. We illustrate the superiority of our method compared to prior works in a simple 2D bandit example with a multimodal behavior policy. We further show that our method can achieve state-of-the-art performance on the majority of the D4RL benchmark tasks for offline RL.

Experiments

Requirements

Installations of PyTorch, MuJoCo, and D4RL are needed. Please see the requirements.txt for environment set up details.

Running

Running experiments based our code could be quite easy, so below we use walker2d-medium-expert-v2 dataset as an example.

For reproducing the optimal results, we recommend running with 'online model selection' as follows. The best_score will be stored in the best_score_online.txt file.

python main.py --env_name walker2d-medium-expert-v2 --device 0 --ms online --lr_decay

For conducting 'offline model selection', run the code below. The best_score will be stored in the best_score_offline.txt file.

python main.py --env_name walker2d-medium-expert-v2 --device 0 --ms offline --lr_decay --early_stop

Hyperparameters for Diffusion-QL have been hard coded in main.py for easily reproducing our reported results. Definitely, there could exist better hyperparameter settings. Feel free to have your own modifications.

Citation

If you find this open source release useful, please cite in your paper:

@article{wang2022diffusion,
  title={Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning},
  author={Wang, Zhendong and Hunt, Jonathan J and Zhou, Mingyuan},
  journal={arXiv preprint arXiv:2208.06193},
  year={2022}
}

About

Cartpole environment (DQN, A2C, PPO)

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Jupyter Notebook 72.2%
  • Python 27.8%