conda create -y -n FastSSL python=3.10 cupy pkg-config libjpeg-turbo=3.0.0 opencv numba pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -c conda-forge
conda activate FastSSL
conda install
pip install -r requirements.txt
Create the FFCV dataset with:
python dataset/write_dataset.py \
--cfg.dataset=imagenet \
--cfg.write_mode=jpg \
--cfg.jpeg_quality=95 --cfg.max_resolution=500 \
--cfg.data_dir=$IMAGENET_DIR/train \
--cfg.write_path=$write_path
Alternatively, run bin/build_data.sh to build the datasets.
For details on loading FFCV datasets, please refer to ffcv.
Set FFCV_DEFAULT_CACHE_PROCESS=0 to enable the cache process.
Before training, run python bin/benchmark_data.py --data_path=<*.ffcv> to check the data loading speed.
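To make concrete what a data loading benchmark measures, here is a minimal sketch of a throughput timer. It is not the code of bin/benchmark_data.py; the function name measure_throughput and the assumption that each batch looks like (images, labels) are hypothetical and only serve the illustration. A plain list stands in for a real loader.

```python
import time
from itertools import islice
from typing import Any, Dict, Iterable, Optional


def measure_throughput(loader: Iterable[Any],
                       max_batches: Optional[int] = None) -> Dict[str, float]:
    """Iterate over `loader` once and report how fast batches arrive."""
    n_batches = 0
    n_samples = 0
    start = time.perf_counter()
    # islice(loader, None) walks the whole loader; an int stops early.
    for batch in islice(loader, max_batches):
        # Assumption: a batch is a sequence such as (images, labels), or a
        # single sized object; the first field gives the batch size.
        first = batch[0] if isinstance(batch, (tuple, list)) else batch
        n_samples += len(first)
        n_batches += 1
    # Guard against a zero interval on very fast, tiny loaders.
    elapsed = max(time.perf_counter() - start, 1e-12)
    return {
        "batches": n_batches,
        "samples": n_samples,
        "seconds": elapsed,
        "samples_per_sec": n_samples / elapsed,
    }


# Stand-in loader: 5 batches of 4 samples each (hypothetical data).
fake_loader = [([0] * 4, [1] * 4) for _ in range(5)]
stats = measure_throughput(fake_loader)
print(f"{stats['samples']} samples in {stats['batches']} batches, "
      f"{stats['samples_per_sec']:.0f} samples/s")
```

If the measured rate is well below what the GPUs consume per second, the data pipeline is likely the bottleneck and is worth tuning before starting a long run.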
Run the following command to start pretraining:
torchrun --nproc_per_node 8 main_pretrain.py --data_path=${train_path} --data_set=ffcv \
--epochs 800 --warmup_epochs 40 --blr 1.5e-4 --weight_decay 0.05 --batch_size 512 \
--cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline --ckpt_freq=100 --output_dir outputs/IN1K_base
Optional arguments: --compile to compile the model, --ckpt_freq to save a checkpoint every ckpt_freq epochs, and --online_prob to evaluate a linear classifier during training.
The original settings for ViT-Large are a batch size of 4096 and 800 epochs, which takes about 42 hours on 64 V100 GPUs.
WANDB_NAME=mae_1k python submitit_pretrain.py \
--job_dir ${JOB_DIR} \
-p gpu --ngpus 8 --nodes 8 \
--batch_size 64 \
--epochs 800 \
--warmup_epochs 40 \
--blr 1.5e-4 --weight_decay 0.05 \
--cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline \
--data_path=${train_path} --data_set=ffcv --online_prob --dynamic_resolution
WANDB_TAGS=dres WANDB_NAME=mae_s1 python submitit_pretrain.py -p long --data_path=${train_path} --data_set=ffcv \
--epochs 800 --warmup_epochs 40 --blr 1.5e-4 --weight_decay 0.05 --batch_size 128 --accum_iter=4 \
--cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline DynamicMasking.start_ramp=0 DynamicMasking.end_ramp=400 DynamicMasking.scheme=1 \
--ckpt_freq=100 --online_prob --dynamic_resolution
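The commands above combine --batch_size, the number of GPUs, and --accum_iter in different ways. The sketch below shows how these could add up to the same effective batch size, assuming this repository follows the convention of the reference MAE implementation: --batch_size is per GPU, the effective batch size is batch_size × number of GPUs × accum_iter, and the absolute learning rate is blr × effective batch size / 256. The helper names are hypothetical, and the GPU count for the last command is not stated in the command, so the value used here is only an example.

```python
def effective_batch_size(batch_size: int, world_size: int, accum_iter: int = 1) -> int:
    """Samples contributing to one optimizer step (assumed MAE convention)."""
    return batch_size * world_size * accum_iter


def absolute_lr(blr: float, eff_batch_size: int, base: int = 256) -> float:
    """Linear scaling rule: lr = blr * eff_batch_size / 256 (assumed)."""
    return blr * eff_batch_size / base


# torchrun command: 8 GPUs, --batch_size 512
single_node = effective_batch_size(batch_size=512, world_size=8)

# submitit command: 8 nodes x 8 GPUs, --batch_size 64
multi_node = effective_batch_size(batch_size=64, world_size=8 * 8)

# --batch_size 128 --accum_iter=4; 8 GPUs is an assumed example value
accumulated = effective_batch_size(batch_size=128, world_size=8, accum_iter=4)

lr = absolute_lr(blr=1.5e-4, eff_batch_size=single_node)
print(single_node, multi_node, accumulated, lr)
```

Under these assumptions all three configurations reach an effective batch size of 4096, so the same --blr gives the same absolute learning rate in each case.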
We use wandb to track training runs.
gin is used to configure the model hyperparameters. The program accepts configuration files and key=value pairs that override hyperparameters, passed as --cfgs *.gin --gin k=v.
The main script also accepts parameters such as --batch_size, --epochs, --data_path, --data_set, and --output_dir to control the training process.
Use the FFCV loader by passing --data_set=ffcv, or the PyTorch dataloader by passing --data_set=IMNET.