Lirui Wang, Xinlei Chen, Jialiang Zhao, Kaiming He
Neural Information Processing Systems (Spotlight), 2024
This is a PyTorch implementation for pre-training Heterogeneous Pre-trained Transformers (HPTs). The pre-training procedure trains on a mixture of embodiment datasets with a supervised learning objective. Pre-training can take some time, so we also provide pre-trained checkpoints below. You can find more details on our project page. An alternative clean implementation of HPT in Hugging Face can also be found here.
TL;DR: HPT aligns different embodiments to a shared latent space and investigates the scaling behaviors in policy learning. Put a scalable transformer in the middle of your policy and don’t train from scratch!
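The routing behind "aligning embodiments to a shared latent space" can be illustrated with a toy, dependency-free sketch. It is not the HPT implementation: the real perception stems, trunk transformer, and action heads are PyTorch networks, whereas here each is a small hand-written linear map or ReLU, and the names ToyHPT, linear, relu, and add_embodiment are hypothetical. What it shows is the layout: each embodiment gets its own stem (observation to shared latent) and head (shared latent to action), while one trunk is shared by all of them.

```python
from typing import Callable, Dict, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def linear(weights: List[List[float]]) -> Layer:
    """Hypothetical helper: a bias-free linear map computing weights @ x."""
    def apply(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return apply


def relu(x: Vector) -> Vector:
    return [max(0.0, xi) for xi in x]


class ToyHPT:
    """Toy stand-in for a stem -> shared trunk -> head policy (not the real model)."""

    def __init__(self, trunk: Layer) -> None:
        self.trunk = trunk                 # one trunk shared by every embodiment
        self.stems: Dict[str, Layer] = {}  # per embodiment: observation -> shared latent
        self.heads: Dict[str, Layer] = {}  # per embodiment: shared latent -> action

    def add_embodiment(self, name: str, stem: Layer, head: Layer) -> None:
        self.stems[name] = stem
        self.heads[name] = head

    def act(self, name: str, obs: Vector) -> Vector:
        latent = self.stems[name](obs)   # project into the shared latent space
        latent = self.trunk(latent)      # same trunk for all embodiments
        return self.heads[name](latent)  # decode into this embodiment's action space


# Two embodiments with different observation sizes (3 vs 2) and action sizes
# (1 vs 3), both mapped into a 2-dimensional shared latent space.
model = ToyHPT(trunk=relu)
model.add_embodiment("arm_a", stem=linear([[1, 0, 0], [0, 1, 0]]), head=linear([[1, 1]]))
model.add_embodiment("arm_b", stem=linear([[0, 1], [1, 0]]),
                     head=linear([[2, 0], [0, 3], [1, 1]]))

print(model.act("arm_a", [1.0, -2.0, 5.0]))  # [1.0]
print(model.act("arm_b", [3.0, 4.0]))        # [8.0, 9.0, 7.0]
```

Adding arm_b leaves model.trunk untouched, which mirrors, in miniature, reusing a pre-trained trunk for a new embodiment instead of training the whole policy from scratch.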
pip install -e .
Install the old version of MuJoCo (2.1.0):
mkdir ~/.mujoco
cd ~/.mujoco
wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco210.tar.gz --no-check-certificate
tar -xvzf mujoco210.tar.gz
# add the following lines to ~/.bashrc if needed
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
export MUJOCO_GL=egl
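Because these exports only apply to shells that source them, a quick check that a given environment actually contains them can save some debugging. The stdlib-only sketch below uses a hypothetical function name, missing_mujoco_settings, that is not part of the repository; it does exact string matching on LD_LIBRARY_PATH entries and does not verify that the directories exist on disk.

```python
import os
from typing import List, Mapping


def missing_mujoco_settings(env: Mapping[str, str], home: str) -> List[str]:
    """List the settings from the install steps above that `env` lacks (hypothetical helper)."""
    problems: List[str] = []
    # LD_LIBRARY_PATH is a colon-separated list of directories on Linux;
    # entries are compared as exact strings (a trailing slash counts as different).
    lib_dirs = env.get("LD_LIBRARY_PATH", "").split(":")
    required_dirs = [
        os.path.join(home, ".mujoco", "mujoco210", "bin"),
        "/usr/lib/nvidia",
        "/usr/local/cuda/lib64",
    ]
    for directory in required_dirs:
        if directory not in lib_dirs:
            problems.append(f"LD_LIBRARY_PATH is missing {directory}")
    if env.get("MUJOCO_GL") != "egl":
        problems.append("MUJOCO_GL is not set to egl")
    return problems


if __name__ == "__main__":
    issues = missing_mujoco_settings(os.environ, os.path.expanduser("~"))
    print("\n".join(issues) if issues else "MuJoCo environment variables look set.")
```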
- Check out quickstart.ipynb for how to use the pre-trained HPTs.
- Run python -m hpt.run to train policies on each environment. Add +mode=debug for debugging.
- See bash experiments/scripts/metaworld/train_test_metaworld_1task.sh test test 1 +mode=debug for an example script.
- Change train.pretrained_dir to load a pre-trained trunk transformer. The model can be loaded either from a local checkpoint folder or from a Hugging Face repository.
- Run the following scripts for MuJoCo experiments.
Metaworld 20 Task Experiments
bash experiments/scripts/metaworld/train_test_metaworld_20task_finetune.sh hf://liruiw/hpt-base
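The Metaworld script passes the trunk location as hf://liruiw/hpt-base (presumably forwarded to train.pretrained_dir), while a local checkpoint folder is the other option. The sketch below shows one way to tell the two apart and to assemble a training command. It is an assumption-laden illustration: parse_pretrained_dir and build_run_command are hypothetical helpers, output/my_ckpt is a made-up local path, and it assumes hpt.run accepts Hydra-style key=value overrides, as the +mode=debug flag suggests.

```python
from typing import List, Tuple

HF_PREFIX = "hf://"  # prefix used by the Metaworld script above


def parse_pretrained_dir(spec: str) -> Tuple[str, str]:
    """Classify a pretrained_dir value as a Hugging Face repo id or a local folder.

    Hypothetical helper: assumes 'hf://<repo_id>' means Hugging Face and that
    anything else is a local checkpoint path.
    """
    if spec.startswith(HF_PREFIX):
        repo_id = spec[len(HF_PREFIX):].strip("/")
        if not repo_id:
            raise ValueError(f"empty Hugging Face repo id in {spec!r}")
        return "huggingface", repo_id
    return "local", spec


def build_run_command(pretrained_dir: str, debug: bool = False) -> List[str]:
    """Assemble a 'python -m hpt.run' call using Hydra-style key=value overrides."""
    parse_pretrained_dir(pretrained_dir)  # fail early on an empty hf:// spec
    cmd = ["python", "-m", "hpt.run", f"train.pretrained_dir={pretrained_dir}"]
    if debug:
        cmd.append("+mode=debug")  # Hydra '+' adds an entry absent from the defaults
    return cmd


print(parse_pretrained_dir("hf://liruiw/hpt-base"))  # ('huggingface', 'liruiw/hpt-base')
print(parse_pretrained_dir("output/my_ckpt"))        # ('local', 'output/my_ckpt')
print(" ".join(build_run_command("hf://liruiw/hpt-base", debug=True)))
```

To launch it, the list could be passed to subprocess.run(cmd, check=True). Fetching a Hugging Face checkpoint yourself would typically go through huggingface_hub.snapshot_download(repo_id=...), though whether HPT's loader does exactly that is not stated here.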
- To train on your own data, write a dataset conversion function, convert_dataset, that packs your dataset into the training format. Check this for an example.
- For evaluation, provide a rollout_runner.py file for each benchmark and a learner_trajectory_generator evaluation function that provides rollouts.
- If needed, modify the config to change the perception stem networks and action head networks in the models. Take a look at realrobot_image.yaml for an example configuration for a real-world robot.
- Add dataset.use_disk=True to save the dataset to disk and load it from there.
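The two integration points above, packing a dataset and generating evaluation rollouts, can be pictured with the following stdlib-only sketch. Everything in it is hypothetical: toy_convert_dataset stands in for convert_dataset, toy_rollouts plays the role described for learner_trajectory_generator, and the field names (joint_pos, camera, command, observation, state, image, action) and the ToyEnv reset/step interface are illustrative assumptions, not HPT's actual schema or API. The repository's own example referenced above is the authority on the real format.

```python
from typing import Callable, Dict, Iterator, List, Sequence, Tuple

Step = Dict[str, object]


def toy_convert_dataset(raw_episodes: Sequence[Sequence[dict]]) -> Iterator[List[Step]]:
    """Hypothetical stand-in for convert_dataset: repack raw logs step by step.

    Assumes each raw record has 'joint_pos', 'camera' and 'command' keys; the
    output layout (observation/state/image, action) is illustrative only.
    """
    for raw in raw_episodes:
        yield [
            {
                "observation": {"state": rec["joint_pos"], "image": rec["camera"]},
                "action": rec["command"],
            }
            for rec in raw
        ]


class ToyEnv:
    """Toy 1-D benchmark: an episode succeeds once the position reaches 3."""

    def __init__(self) -> None:
        self.pos = 0

    def reset(self) -> int:
        self.pos = 0
        return self.pos

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.pos += action
        done = self.pos >= 3
        return self.pos, (1.0 if done else 0.0), done


def toy_rollouts(env: ToyEnv, policy: Callable[[int], int],
                 episodes: int, max_steps: int) -> Iterator[bool]:
    """Hypothetical rollout generator: yields True for each successful episode."""
    for _ in range(episodes):
        obs, done = env.reset(), False
        for _ in range(max_steps):
            obs, _reward, done = env.step(policy(obs))
            if done:
                break
        yield done


# Usage: pack one two-step episode, then evaluate an always-move-right policy.
packed = list(toy_convert_dataset([[
    {"joint_pos": [0.1], "camera": "frame0.png", "command": [0.5]},
    {"joint_pos": [0.2], "camera": "frame1.png", "command": [0.4]},
]]))
print(packed[0][1]["action"])  # [0.4]
results = list(toy_rollouts(ToyEnv(), policy=lambda obs: 1, episodes=5, max_steps=4))
print(sum(results) / len(results))  # 1.0
```

Making the rollout function a generator lets an evaluation loop aggregate success rates episode by episode; that is a design choice of this sketch, not a claim about how HPT's evaluation code is written.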
You can find pretrained HPT checkpoints here. At the moment we provide the following model versions:
Model | Size |
---|---|
HPT-XLarge | 226.8M Params |
HPT-Large | 50.5M Params |
HPT-Base | 12.6M Params |
HPT-Small | 3.1M Params |
HPT-Base (With Language) | 50.6M Params |
├── ...
├── HPT
| ├── data # cached datasets
| ├── output # trained models and figures
| ├── env # environment wrappers
| ├── hpt # model training and dataset source code
| | ├── models # network models
| | ├── datasets # dataset related
| | ├── run # transfer learning main loop
| | ├── run_eval # evaluation main loop
| | └── ...
| ├── experiments # training configs
| | ├── configs # modular configs
└── ...
If you find HPT useful in your research, please consider citing:
@inproceedings{wang2024hpt,
author = {Lirui Wang and Xinlei Chen and Jialiang Zhao and Kaiming He},
title = {Scaling Proprioceptive-Visual Learning with Heterogeneous Pre-trained Transformers},
booktitle = {NeurIPS},
year = {2024}
}
If you have any questions, feel free to contact me by email. Enjoy!