
# RL framework in JAX (Flax)

An offline RL framework in JAX that supports:

- Distributed multi-device training on GPU/TPU (see the sketch below)
- Multiprocessed dataloaders (similar to PyTorch)
- Modular experiment management and logging

I currently use this framework for my own research, so features may change often.
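
The distributed-training bullet refers to JAX's standard single-host data parallelism. Below is a minimal, self-contained sketch of that pattern using `jax.pmap`; it is illustrative only (the toy model and names are assumptions), not this framework's actual API.

```python
# Illustrative sketch of single-host data-parallel training with jax.pmap.
# NOT this framework's API; the toy linear model and names are assumptions.
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    preds = batch["x"] @ params                 # toy linear model
    return jnp.mean((preds - batch["y"]) ** 2)  # mean squared error

def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name="devices")  # average grads across devices
    return params - 1e-3 * grads, loss                 # plain SGD update

# Replicate the step across all local devices (GPU cards / TPU cores).
p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
params = jax.device_put_replicated(jnp.zeros((4,)), jax.local_devices())
batch = {
    "x": jnp.ones((n_dev, 32, 4)),  # leading axis indexes devices
    "y": jnp.ones((n_dev, 32)),
}
params, loss = p_train_step(params, batch)
```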

## Install dependencies

```bash
conda env create -f environment.yml
conda activate rjax

# Installs GPU dependencies; skip this step for a CPU-only install.
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

pip install -e .
```
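
To check which devices JAX can see after installation (a generic JAX sanity check, not specific to this repo):

```python
import jax
# A GPU/TPU install should list accelerator devices;
# a CPU-only install prints something like [CpuDevice(id=0)].
print(jax.devices())
```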

To install on TPU:

```bash
./gcp/setup_tpu.sh
```

There are also two Dockerfiles provided for CPU and GPU.

## Documentation

There is no official documentation page yet, but the code is extensively documented and the configs are explained.

## Examples

IQL (Implicit Q-Learning): implicit_q_learning.

```bash
python examples/train_iql.py --prefix iql --env antmaze-medium-play-v2 --num_workers 2
```
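
For context, the value-function update in IQL is an expectile regression toward the Q-values of dataset actions. A minimal sketch of that loss in JAX follows; it is based on the IQL paper, not this repository's exact code, and the expectile 0.7 is just a common choice.

```python
import jax.numpy as jnp

def expectile_loss(diff, expectile=0.7):
    # Asymmetric squared error: positive errors (diff > 0) are weighted by
    # `expectile`, negative errors by (1 - expectile).
    weight = jnp.where(diff > 0, expectile, 1.0 - expectile)
    return weight * diff ** 2

def value_loss(q_values, v_values, expectile=0.7):
    # Train V(s) toward an upper expectile of Q(s, a) over dataset actions,
    # so the update never queries out-of-distribution actions.
    return jnp.mean(expectile_loss(q_values - v_values, expectile))
```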