erow/FastSSL

Install

conda create -y -n FastSSL python=3.10 cupy pkg-config libjpeg-turbo=3.0.0 opencv numba pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -c conda-forge
conda activate FastSSL
pip install -r requirements.txt
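
As a quick sanity check that the environment resolved correctly (a minimal sketch; assumes ffcv is pulled in by requirements.txt):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import ffcv"  # should import without error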

Build dataset

Create the FFCV dataset with:

python dataset/write_dataset.py \
    --cfg.dataset=imagenet \
    --cfg.write_mode=jpg \
    --cfg.jpeg_quality=95 --cfg.max_resolution=500 \
    --cfg.data_dir=$IMAGENET_DIR/train \
    --cfg.write_path=$write_path 

Or run bin/build_data.sh to build the datasets.
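
For example, the validation split can be written the same way (a sketch; $val_write_path is a hypothetical output path):

python dataset/write_dataset.py \
    --cfg.dataset=imagenet \
    --cfg.write_mode=jpg \
    --cfg.jpeg_quality=95 --cfg.max_resolution=500 \
    --cfg.data_dir=$IMAGENET_DIR/val \
    --cfg.write_path=$val_write_path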

For details on loading FFCV datasets, refer to the ffcv documentation. Set FFCV_DEFAULT_CACHE_PROCESS=0 to enable the cache process.

Before training, run python bin/benchmark_data.py --data_path=<*.ffcv> to check the data loading speed.
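
For example, combining the two commands above to benchmark the file written during the build step with the cache process enabled ($write_path is the output of dataset/write_dataset.py):

FFCV_DEFAULT_CACHE_PROCESS=0 python bin/benchmark_data.py --data_path=$write_path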

Train

torchrun

Launch pretraining with:

torchrun --nproc_per_node 8 main_pretrain.py --data_path=${train_path} --data_set=ffcv \
    --epochs 800 --warmup_epochs 40 --blr 1.5e-4 --weight_decay 0.05 --batch_size 512 \
    --cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline \
    --ckpt_freq=100 --output_dir outputs/IN1K_base

Optional arguments: --compile compiles the model, --ckpt_freq saves a checkpoint every ckpt_freq epochs, and --online_prob evaluates an online linear classifier during training.
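
For example, the torchrun command above with all three options enabled (a sketch reusing the same configuration):

torchrun --nproc_per_node 8 main_pretrain.py --data_path=${train_path} --data_set=ffcv \
    --epochs 800 --warmup_epochs 40 --blr 1.5e-4 --weight_decay 0.05 --batch_size 512 \
    --cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline \
    --compile --ckpt_freq=100 --online_prob --output_dir outputs/IN1K_base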

HPC

The original settings for ViT-Large are a batch size of 4096 and 800 epochs, which takes about 42 hours on 64 V100 GPUs. The command below reproduces this: with 8 nodes of 8 GPUs each and --batch_size 64 per GPU, the effective batch size is 8 × 8 × 64 = 4096.

WANDB_NAME=mae_1k python submitit_pretrain.py \
    --job_dir ${JOB_DIR} \
    -p gpu --ngpus 8 --nodes 8 \
    --batch_size 64 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline \
    --data_path=${train_path} --data_set=ffcv --online_prob --dynamic_resolution 
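
${JOB_DIR} and ${train_path} are placeholders, for example (hypothetical values):

export JOB_DIR=outputs/mae_1k   # hypothetical: where submitit writes logs and checkpoints
export train_path=$write_path   # the .ffcv file produced by dataset/write_dataset.py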

Progressive Training

WANDB_TAGS=dres WANDB_NAME=mae_s1  python submitit_pretrain.py -p long --data_path=${train_path} --data_set=ffcv \
     --epochs 800 --warmup_epochs 40 --blr 1.5e-4 --weight_decay 0.05 --batch_size 128 --accum_iter=4 \
     --cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline DynamicMasking.start_ramp=0 DynamicMasking.end_ramp=400 DynamicMasking.scheme=1   \
     --ckpt_freq=100 --online_prob --dynamic_resolution 

Tracking

We use wandb to track training runs.
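
The run name and tags are set through wandb's standard environment variables, as in the training commands above:

export WANDB_NAME=mae_1k   # display name of the run
export WANDB_TAGS=dres     # comma-separated list of tags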

Advanced Usage

Configure

gin is used to configure the model hyperparameters. The program accepts config files and key=value overrides via --cfgs *.gin --gin k=v. The main script also accepts command-line parameters, such as --batch_size, --epochs, --data_path, --data_set, and --output_dir, to control the training process.
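
For example, overriding a single gin binding from the command line while keeping the rest of the config file (a sketch reusing keys from the commands above):

python main_pretrain.py --cfgs configs/mae_ffcv.gin \
    --gin build_model.model_fn=@base/MaskedAutoencoderViT DynamicMasking.end_ramp=400 \
    --batch_size 512 --epochs 800 --data_path=${train_path} --data_set=ffcv --output_dir outputs/IN1K_base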

Data

Use the FFCV loader by passing --data_set=ffcv, or the PyTorch dataloader by passing --data_set=IMNET.
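
The two modes expect different --data_path values (a sketch; the IMNET layout is assumed to be a standard ImageFolder-style directory):

--data_set=ffcv  --data_path=${train_path}    # path to the .ffcv file
--data_set=IMNET --data_path=$IMAGENET_DIR    # root of an ImageFolder directory (assumption)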
