Cold PAWS: Unsupervised class discovery and addressing the cold-start problem for semi-supervised learning
This is the PyTorch code for fitting the models from the paper. For the label selection strategy code, see this repo.
To train the SimCLR model on CIFAR-10, run
python main.py --config config/miscdata_paws_small.py
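SimCLR pretraining is built around the NT-Xent contrastive loss: two augmented views of the same image should have similar embeddings, and every other embedding in the batch acts as a negative. The following is a minimal pure-Python sketch of that loss for illustration only. It is not this repository's PyTorch implementation, and the names `cosine` and `nt_xent` and the default temperature of 0.5 are assumptions.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)


def nt_xent(views_a, views_b, temperature=0.5):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss.

    views_a[k] and views_b[k] are embeddings of two augmentations of image k.
    The temperature default is an illustrative assumption, not the repo's value.
    """
    z = list(views_a) + list(views_b)  # 2N embeddings in one batch
    n = len(views_a)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)  # index of the positive partner of sample i
        # Denominator runs over every k != i (the positive included, self excluded).
        logits = [cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denominator - cosine(z[i], z[j]) / temperature
    return total / (2 * n)
```

With identical views the loss is small, and it grows when each image's positive partner is swapped for a different image, which is exactly what the contrastive objective penalises.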
For convenience, we have provided a CIFAR-10 pretrained model in the pretrained_models folder. You can download it from here along with some sample output files.
To train a supervised model from these pretrained weights, run
python main.py --config config/miscdata_small_pyz_finetune_simclr_nodist.py
To write the SimCLR encodings to file for this pretrained model, run
python main.py --config config/miscdata_small_pyz_finetune_simclr_nodist_writedata.py
The indices folder contains example input files that select small labelled subsets for training.
To fine-tune a model on one of these labelled subsets, run
python main.py --config config/miscdata_small_pyz_finetune_simclr_nodist_some.py
To fit PAWS on these labelled subsets, run
python main.py --config config/miscdata_paws_small.py
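PAWS classifies an unlabelled sample by comparing its embedding with the embeddings of a labelled support set, which here would come from one of the subsets in the indices folder. The following is a minimal pure-Python sketch of that soft nearest-neighbour step, not this repository's code. The one-index-per-line file format is a hypothetical stand-in (check the actual files in the indices folder), and `read_indices`, `paws_soft_labels` and the temperature `tau` are illustrative names.

```python
import math


def read_indices(lines):
    """Parse a labelled-subset file.

    ASSUMPTION: one integer dataset index per line; blank lines are skipped.
    """
    return [int(line) for line in (s.strip() for s in lines) if line]


def normalise(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def paws_soft_labels(query, support, support_labels, num_classes, tau=0.1):
    """PAWS-style soft nearest-neighbour classifier.

    Returns the weighted average of the support samples' one-hot labels, with
    weights given by a softmax over cosine similarities divided by tau.
    """
    q = normalise(query)
    sims = [sum(a * b for a, b in zip(q, normalise(s))) / tau for s in support]
    m = max(sims)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in sims]
    total = sum(weights)
    probs = [0.0] * num_classes
    for w, y in zip(weights, support_labels):
        probs[y] += w / total  # accumulate softmax weight onto each label
    return probs


# Toy usage: four embedded samples, two of which are selected as labelled support.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]
idx = read_indices(["0\n", "2\n"])  # hypothetical contents of an indices file
support = [embeddings[i] for i in idx]
support_labels = [labels[i] for i in idx]
probs = paws_soft_labels([0.8, 0.2], support, support_labels, num_classes=2)
```

PAWS additionally sharpens the predictions for one view and uses them as training targets for another view; only the classification step is shown here.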