Welcome to the official repository for our paper: "Scaling Offline Model-Based RL via Jointly-Optimized World-Action Model Pretraining".
TL;DR: A single JOWA-150M agent masters 15 Atari games at 84.7% human-level and 119.5% DQN-level performance, and can adapt to novel games with roughly 4 expert demonstrations.
*Training used 84x84 grayscale images; the RGB demos are shown here for better visualization. For Atlantis (second in the first row), only the first 7 minutes of the 2-hour gameplay are displayed. More demos are here.
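For reference, the 84x84 grayscale format follows the standard Atari preprocessing convention. Below is a minimal sketch using OpenCV; the function name is illustrative and not part of this repo:

```python
import cv2
import numpy as np

def to_training_format(rgb_frame: np.ndarray) -> np.ndarray:
    """Convert an RGB Atari frame to the 84x84 grayscale format used for training."""
    gray = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2GRAY)                # drop color channels
    return cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)   # downsample to 84x84
```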
- Create a Python 3.8 environment:
conda create -n jowa python=3.8 && conda activate jowa
- Install basic dependencies for compatibility:
pip install setuptools==65.5.0 wheel==0.38.4 packaging pip==24.0
- Install other dependencies:
pip install -r requirements.txt
- (Optional but recommended) Install Flash-Attention:
pip install flash-attn --no-build-isolation
- If evaluating, install jq to parse JSON files:
apt install jq
- If training, log in to your wandb account (paste your API key when prompted):
wandb login
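(Optional) A quick, illustrative sanity check that the environment is usable; this snippet is not part of the repo:

```python
import torch

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
try:
    import flash_attn  # only present if the optional Flash-Attention install succeeded
    print("flash-attn import OK")
except ImportError:
    print("flash-attn not installed (optional)")
```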
Download the model weights from here or Google Drive.
Use beam search as the planning algorithm. Set hyperparameters in eval/eval.sh, such as model_name, game, num_rollouts, and num_gpus. Then run this script:
bash eval/eval.sh
Use MCTS as the planning algorithm. Set hyperparameters in eval/eval_MCTS.sh. The extra tunable hyperparameters for MCTS are horizon, temperature, and num_simulations. Then run this script:
bash eval/eval_MCTS.sh
Evaluate without a planning algorithm:
bash eval/eval_wo_planning.sh
The above commands run the specified number of rollouts (episodes) in parallel. After all evaluation processes have finished, compute the aggregate scores with the following command:
python eval/aggregate_scores.py --name JOWA_150M --env BeamRider
Reproduce the results of all JOWA variants and baselines with the following command:
python results/results.py
Note
We found that some hyperparameters in a few groups of earlier evaluation experiments (reported in the early version of the paper) were set incorrectly. After the correction (already applied in the current version of the code), all JOWA variants achieve higher IQM HNS, and the scaling trend still holds. We will update the paper with the corrected results.
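For reference, the human-normalized score (HNS) and its interquartile mean (IQM) mentioned above are computed in the standard way. A minimal numpy sketch follows; the reference scores used in the example are placeholders, not actual JOWA results:

```python
import numpy as np

def hns(score, random_score, human_score):
    """Human-normalized score: 0 = random policy, 1 = human-level."""
    return (score - random_score) / (human_score - random_score)

def iqm(values):
    """Interquartile mean: the mean of the middle 50% of sorted values."""
    v = np.sort(np.asarray(values, dtype=float).ravel())
    n = len(v)
    return float(v[n // 4 : n - n // 4].mean())

# Placeholder numbers, not actual JOWA results:
per_game_hns = np.array([hns(1200, 100, 1500), hns(4000, 200, 3500), hns(90, 10, 120)])
print("IQM HNS:", iqm(per_game_hns))
```

DQN-normalized scores use the same formula with the DQN reference score in place of the human score.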
- Download the raw dataset
python src/datasets/download.py
This uses multiple processes to download the original DQN-Replay dataset (~1TB for the default 20 games, 3.1TB for all 60 games).
- Process and downsample the data
python src/datasets/downsample.py
This command processes the raw data into two formats: (i) a trajectory-level dataset: an HDF5 file containing the transitions of each trajectory (total size ~1TB for the 15 pretraining games); (ii) a segment-level dataset: CSV files containing segment indices into the corresponding trajectory.
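If you want to sanity-check the output, a minimal sketch for peeking at one game's trajectory file and segment index is below; the paths and key names are assumptions for illustration and may differ from what downsample.py actually writes:

```python
import h5py
import pandas as pd

# Hypothetical paths -- adjust to the actual output locations of src/datasets/downsample.py.
with h5py.File("data/trajectories/Breakout.h5", "r") as f:
    # Print every group/dataset name and its shape to see how transitions are laid out.
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))

segments = pd.read_csv("data/segments/Breakout.csv")
print(segments.head())  # segment indices into the corresponding trajectory
```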
Start pre-training (stages 1 and 2) with the following torchrun command:
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=39500 src/train.py hydra/job_logging=disabled hydra/hydra_logging=disabled
However, after stage 1, we highly recommend the following steps to speed up training:
- Use the pre-trained VQ-VAE to encode all images into tokens in advance (a sketch of this step follows this list):
python src/save_obs_in_token.py
- Use the AtariTrajWithObsTokenInMemory class as the dataset (uncomment lines 161~170 in the src/train.py file) to speed up data loading, since the VQ-VAE is frozen in the second stage.
- Modify the config config/train_xxM.yaml: set initialization.load_tokenizer to the path of the stage-1 checkpoint, then set training.action.train_critic_after_n_steps: -1 to train the critic immediately.
- Re-run the above torchrun command.
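For intuition, the first step above caches the VQ-VAE token indices of every observation so that stage 2 never re-runs the frozen encoder. Below is a minimal sketch of that idea; the "observations" key and the tokenizer.encode call are placeholders, and src/save_obs_in_token.py remains the authoritative implementation:

```python
import h5py
import torch

@torch.no_grad()
def cache_obs_tokens(tokenizer, obs_h5_path, out_h5_path, batch_size=256, device="cuda"):
    """Encode stored 84x84 frames into discrete token indices with the frozen VQ-VAE."""
    tokenizer.eval().to(device)
    with h5py.File(obs_h5_path, "r") as src, h5py.File(out_h5_path, "w") as dst:
        obs = src["observations"]                       # hypothetical key: (N, 84, 84) uint8
        n, token_ds = obs.shape[0], None
        for start in range(0, n, batch_size):
            batch = torch.from_numpy(obs[start:start + batch_size]).to(device)
            batch = batch.float().unsqueeze(1) / 255.0  # (B, 1, 84, 84) in [0, 1]
            tokens = tokenizer.encode(batch).cpu().numpy()  # placeholder API: (B, K) indices
            if token_ds is None:
                token_ds = dst.create_dataset("obs_tokens", shape=(n, tokens.shape[1]), dtype="int64")
            token_ds[start:start + len(tokens)] = tokens
```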
The fine-tuning dataset is here. Unzip the compressed file, then run the command:
DEVICE_ID=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=39500 src/fine_tune.py hydra/job_logging=disabled hydra/hydra_logging=disabled common.env=Pong training.action.use_imagination=True training.action.planning_horizon=2
After fine-tuning, modify and run eval/eval_wo_planning.sh to evaluate the checkpoint.
- Release model weights.
- Optimize evaluation code.
- Optimize data preprocessing.
- Merge and optimize pretraining & fine-tuning code.
- Multi-env evaluation.