From ae52201af5a635c9bfa4a0d5c719c0fbf3898caf Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Tue, 17 Oct 2023 14:12:48 -0400 Subject: [PATCH] download=True --- .../finetuning/finetune_multiclass_classification.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py b/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py index 3b16e1b7..02dde83e 100644 --- a/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py +++ b/zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py @@ -18,13 +18,14 @@ num_workers= 8 n_blocks = 1 # EffnetB0 is divided into 7 blocks. set 0 to only fit the head weights. Set 1, 2, etc to finetune deeper. max_epochs = 6 # 6 epochs should get you ~93% accuracy. Set much higher (e.g. 1000) for harder problems, to use Zoobot's default early stopping. + # the remaining key parameters for high accuracy are weight_decay, learning_rate, and lr_decay. You might like to tinker with these. # load in catalogs of images and labels to finetune on # each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels # here I'm using galaxy-datasets to download some premade data - check it out for examples - train_catalog, _ = galaxy_mnist(root=data_dir, download=False, train=True) - test_catalog, _ = galaxy_mnist(root=data_dir, download=False, train=False) + train_catalog, _ = galaxy_mnist(root=data_dir, download=True, train=True) + test_catalog, _ = galaxy_mnist(root=data_dir, download=True, train=False) # wondering about "label_cols"? # This is a list of catalog columns which should be used as labels