
A JAX implementation of "Is Conditional Generative Modeling all you need for Decision-Making?"

zbzhu99/decision-diffuser-jax


Decision Diffuser JAX

This is a JAX implementation of Decision Diffuser. The code is built upon EDP, another diffusion-based offline RL algorithm, which is also included in this repo.
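In the paper named above, Decision Diffuser samples state trajectories from a return-conditioned diffusion model using classifier-free guidance, then recovers actions with an inverse dynamics model a_t = f(s_t, s_{t+1}). The sketch below illustrates only those two pieces in plain JAX, under stated assumptions: guided_noise, init_inverse_dynamics and inverse_dynamics are hypothetical helpers, not functions from this repo, and the small two-layer MLP stands in for whatever network the repo actually uses.

```python
import jax
import jax.numpy as jnp


def guided_noise(eps_cond, eps_uncond, w):
    """Classifier-free guidance: push the unconditional noise estimate toward
    the return-conditioned one, scaled by the guidance weight w."""
    return eps_uncond + w * (eps_cond - eps_uncond)


def init_inverse_dynamics(key, obs_dim, act_dim, hidden=64):
    """Random parameters for a hypothetical two-layer MLP a_t = f(s_t, s_{t+1})."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (2 * obs_dim, hidden)),
        "b1": jnp.zeros(hidden),
        "w2": 0.1 * jax.random.normal(k2, (hidden, act_dim)),
        "b2": jnp.zeros(act_dim),
    }


@jax.jit
def inverse_dynamics(params, s_t, s_next):
    """Predict the action taking s_t to s_next; tanh keeps it in [-1, 1]."""
    x = jnp.concatenate([s_t, s_next], axis=-1)
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return jnp.tanh(h @ params["w2"] + params["b2"])


# Toy noise estimates: a weight w > 1 extrapolates past the conditional estimate.
eps_c = jnp.array([1.0, 2.0])
eps_u = jnp.array([0.0, 1.0])
guided = guided_noise(eps_c, eps_u, 2.0)  # values: 2.0, 3.0

# Hopper has 11-dim observations and 3-dim actions; these planned states are dummy zeros.
params = init_inverse_dynamics(jax.random.PRNGKey(0), obs_dim=11, act_dim=3)
planned_states = jnp.zeros((4, 11))
actions = inverse_dynamics(params, planned_states[:-1], planned_states[1:])
print(actions.shape)  # (3, 3)
```

In the paper's formulation, guidance is applied at each denoising step of the trajectory sampler, and actions are read off consecutive planned states; the guidance weight, network sizes and normalization used by this repo may differ.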

Set up the environment

Create a Python environment with conda:

conda env create -f environment.yml
conda activate diffuser
pip install -e .

In addition, you will need to set up MuJoCo and its license key.
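After installation, a quick sanity check can confirm that the key packages are importable before launching a long run. This is a minimal sketch, assuming the usual D4RL/MuJoCo stack (jax, gym, d4rl, mujoco_py); those module names are assumptions, so check environment.yml for what this repo actually pins. check_modules is a hypothetical helper.

```python
import importlib.util


def check_modules(names):
    """Return {module_name: bool} telling whether each top-level module can be
    found, without importing it (so a broken MuJoCo build will not crash here)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed module names for a typical D4RL MuJoCo setup.
    status = check_modules(["jax", "gym", "d4rl", "mujoco_py"])
    for name, ok in status.items():
        print(f"{name:10s} {'ok' if ok else 'MISSING'}")
    if status["jax"]:
        import jax

        # Lists the accelerators JAX can see (falls back to CPU).
        print("JAX devices:", jax.devices())
```

find_spec only locates modules, so a MISSING entry points to an install problem, while an "ok" entry can still fail at import time if the MuJoCo binaries or key are misconfigured.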

Run Experiments

Run Decision Diffuser on the D4RL Hopper task:

python train.py --config configs/diffuser_inv_hopper/diffuser_inv_hopper_mdexpert.py
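Decision Diffuser conditions its generated trajectories on returns, so training needs a return target for each trajectory segment. The following JAX sketch shows one common way to compute discounted returns-to-go with jax.lax.scan and to slice fixed-horizon state windows; it is an illustration under assumptions, not this repo's data pipeline, and both helpers are hypothetical. The repo may, for example, normalize returns or pad windows differently.

```python
import jax
import jax.numpy as jnp


def discounted_returns_to_go(rewards, gamma):
    """R_t = r_t + gamma * R_{t+1}, computed right to left with lax.scan.
    Expects float rewards so the carry dtype stays consistent."""
    def step(carry, r):
        ret = r + gamma * carry
        return ret, ret

    init = jnp.zeros((), dtype=rewards.dtype)
    # reverse=True scans from the last step; outputs stay aligned with inputs.
    _, rtg = jax.lax.scan(step, init, rewards, reverse=True)
    return rtg


def make_training_windows(states, rtg, horizon):
    """Hypothetical helper: slice every length-`horizon` window of states and
    pair it with the return-to-go at the window's first step."""
    n = states.shape[0] - horizon + 1
    idx = jnp.arange(n)[:, None] + jnp.arange(horizon)[None, :]  # (n, horizon)
    return states[idx], rtg[:n]


rtg = discounted_returns_to_go(jnp.array([1.0, 1.0, 1.0]), 0.5)  # values: 1.75, 1.5, 1.0

states = jnp.arange(10.0).reshape(5, 2)  # dummy 5-step trajectory, 2-dim states
windows, cond = make_training_windows(states, discounted_returns_to_go(jnp.ones(5), 1.0), 3)
print(windows.shape)  # (3, 3, 2)
```

Using a reverse scan keeps the recursion on-device and jit-friendly, instead of looping over time steps in Python.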

Run EDP on the D4RL Hopper task:

python train.py --config configs/dql_hopper/dql_hopper_mdexpert.py

Current results on D4RL datasets

Weights and Biases Online Visualization Integration

This codebase can also log to the W&B online visualization platform. To enable this, first set your W&B API key in the WANDB_API_KEY environment variable, or simply run wandb login.
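To fail fast before a long training run, a script could check whether W&B credentials are visible. This is a best-effort sketch, assuming that wandb login stores the key in the default ~/.netrc under the host api.wandb.ai; wandb_credentials_available is a hypothetical helper, not part of this repo or of the wandb package.

```python
import netrc
import os


def wandb_credentials_available(netrc_path=None):
    """Best-effort check for W&B credentials.

    Looks for the WANDB_API_KEY environment variable first, then for an
    api.wandb.ai entry in a .netrc file (default: ~/.netrc), which is where
    `wandb login` is assumed to store the key.
    """
    if os.environ.get("WANDB_API_KEY"):
        return True
    try:
        auth = netrc.netrc(netrc_path).authenticators("api.wandb.ai")
    except (OSError, netrc.NetrcParseError):
        # Missing or unreadable .netrc: treat as "no credentials".
        return False
    return auth is not None


if __name__ == "__main__":
    if not wandb_credentials_available():
        print("No W&B credentials found: set WANDB_API_KEY or run `wandb login`.")
```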

Credits

This repo is mainly built upon EDP. We also refer to the official PyTorch implementation of Decision Diffuser. The vectorized RL environment is borrowed from Tianshou.
