This project provides a PyTorch implementation of MONet with proposed extensions for incorporating temporal information across sequential frames in applicable domains (e.g. games, videos), as well as placing additional spatial constraints on the produced segmentations to improve their spatial consistency.
The base implementation of MONet is provided by the MONet-pytorch project.
This repository is developed and tested with Python 3.7.3 on Ubuntu 18.04. If you are missing the Python header files, install python3.7-dev.
Run the setup script:
./setup.sh
By default, models are configured for running on CUDA GPUs. To run on CPU, set the gpu_ids option to -1 as below:
python train.py --gpu_ids -1
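To illustrate how a value such as -1 can select the CPU, here is a minimal sketch of one common way to interpret a GPU id option. It assumes the option is a comma-separated string of indices in which negative values mean "no GPU". The function names `parse_gpu_ids` and `device_name` are hypothetical, and the option handling in this repository may differ.

```python
def parse_gpu_ids(gpu_ids: str) -> list:
    """Turn a string such as "0,1" or "-1" into a list of non-negative GPU indices.

    Assumption: negative indices mean "do not use a GPU" and are dropped.
    """
    ids = []
    for token in gpu_ids.split(","):
        token = token.strip()
        if not token:
            continue
        index = int(token)
        if index >= 0:
            ids.append(index)
    return ids


def device_name(gpu_ids: str) -> str:
    """Return a device string: the first listed GPU, or "cpu" if none is listed."""
    ids = parse_gpu_ids(gpu_ids)
    return f"cuda:{ids[0]}" if ids else "cpu"


if __name__ == "__main__":
    print(device_name("-1"))
    print(device_name("0,1"))
```

The returned string is in the form accepted by `torch.device`, so a caller could pass it straight through when moving a model or tensor.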
- Download a CLEVR dataset:
wget -cN https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip
- Train a model:
python train.py --gpu_ids -1 --batch_size 50 --dataroot /path/to/clevr_dataset --dataset_mode clevr --num_slots 4
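The downloaded file is a zip archive, so it has to be extracted before `--dataroot` can point at it. The following is a minimal sketch using only the Python standard library. The helper name `extract_archive` is hypothetical, and the directory layout that `--dataset_mode clevr` expects is not specified here, so check it against the dataset loader.

```python
import zipfile
from pathlib import Path


def extract_archive(archive_path, target_dir):
    """Extract a zip archive into target_dir and return the top-level entry names."""
    archive_path = Path(archive_path)
    target_dir = Path(target_dir)
    # Create the destination directory if it does not exist yet.
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target_dir)
    return sorted(p.name for p in target_dir.iterdir())
```

For example, `extract_archive("CLEVR_v1.0.zip", "/path/to/clevr_dataset")` would unpack the archive fetched by the wget command above into the directory passed as `--dataroot`.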
For training on Atari, a static dataset can be used as input, or batches can be generated dynamically from a random or pretrained agent.
To generate batches dynamically, run the train command with the following options:
# Collect frames from episode rollouts executed with a pretrained ppo agent
python train.py --game RiverraidNoFrameskip-v4 --dynamic_datagen --collect_mode pretrained_ppo
# Collect frames from episode rollouts executed with a random agent
python train.py --game RiverraidNoFrameskip-v4 --dynamic_datagen --collect_mode random_agent
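The idea of generating batches dynamically from a random agent can be sketched as follows. This is not the code behind `--dynamic_datagen`; it is an illustrative example under stated assumptions. `ToyEnv` is a hypothetical stand-in for an Atari environment with a Gym-like `reset`/`step` interface, and `random_agent_batches` is a hypothetical helper that steps the environment with uniformly random actions and yields fixed-size batches of frames.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for an Atari environment with a Gym-like interface."""

    def __init__(self, episode_length=5, num_actions=4):
        self.episode_length = episode_length
        self.num_actions = num_actions
        self._t = 0

    def reset(self):
        self._t = 0
        return self._frame()

    def step(self, action):
        # The action is ignored here; a real environment would react to it.
        self._t += 1
        done = self._t >= self.episode_length
        return self._frame(), 0.0, done, {}

    def _frame(self):
        # A tiny 2x2 "image" whose values encode the current time step.
        return [[self._t, self._t], [self._t, self._t]]


def random_agent_batches(env, batch_size, num_batches, seed=0):
    """Yield num_batches lists of batch_size frames collected with random actions."""
    rng = random.Random(seed)
    frame = env.reset()
    batch = []
    produced = 0
    while produced < num_batches:
        batch.append(frame)
        if len(batch) == batch_size:
            yield batch
            batch = []
            produced += 1
        action = rng.randrange(env.num_actions)
        frame, _reward, done, _info = env.step(action)
        if done:
            # Start a new episode; the terminal frame is replaced by the reset frame.
            frame = env.reset()


if __name__ == "__main__":
    for batch in random_agent_batches(ToyEnv(), batch_size=4, num_batches=2):
        print(len(batch))
```

Because frames are produced on demand, no dataset has to be written to disk first. A pretrained agent would replace the random action choice with a call to its policy.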
To train on a static dataset, run the train command with the following options:
python train.py --dataroot /path/to/atari_dataset
Note: See the atari-representation-learning repository for generating frames to build an Atari dataset. generate_atari_frames.py contains examples for generating a dataset.
We use Visdom to produce visuals during training. Launch the Visdom server before starting training. By default, the server is hosted on port 8097.
python -m visdom.server
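Since the server has to be running before training starts, it can be useful to check that the port accepts connections first. The following is a minimal sketch using only the standard library. The helper name `is_server_reachable` is hypothetical, and the check only confirms that something is listening on the port, not that it is a Visdom server.

```python
import socket


def is_server_reachable(host="localhost", port=8097, timeout=1.0):
    """Return True if a TCP connection to host:port can be opened, else False."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unresolved host names.
        return False


if __name__ == "__main__":
    if is_server_reachable():
        print("A server is listening on port 8097.")
    else:
        print("Nothing is listening on port 8097; start the Visdom server first.")
```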