From 302ebe8e6a8c7083e9db0ee74f36b45bf031e405 Mon Sep 17 00:00:00 2001 From: Jo Walsh Date: Tue, 12 Nov 2024 17:31:04 +0000 Subject: [PATCH] Add a num_classes option to test runner, improve docs --- test/README.md | 20 ++++++++++++++++---- test/test.py | 38 ++++++++++++++++++++++++++++++++++---- tools/loss.py | 0 3 files changed, 50 insertions(+), 8 deletions(-) mode change 100644 => 100755 test/test.py mode change 100644 => 100755 tools/loss.py diff --git a/test/README.md b/test/README.md index 5510f82..5703209 100644 --- a/test/README.md +++ b/test/README.md @@ -1,20 +1,32 @@ ## Notes on reproducing -With the model weights downloaded from Google Drive and unpacked into the `vit_model` directory below: +With the model weights downloaded from Google Drive (see [this issue](https://github.com/alan-turing-institute/ViT-LASNet/issues/2) for a query about where they could be held in a repository with model card to support reuse). + +Images are a small test set from [this project on freshwater plankton](https://github.com/NERC-CEH/plankton_ml/). + +### ViT model (version = 2, num_classes = 18) + +Note this is the version in the subdirectory with `model.safetensors` included, not the single-file `.pt` version. ``` -python test.py -w ../vit_model/vit_finetuned_Bal_CE_lr5e-05_epochs10 -o out -m 2 -f ../../plankton_ml/tests/fixtures/test_images/ +python test.py -w ~/vit_finetuned_MiSLAS_vit_lr5e-05_epochs30/ -o out -n 18 -m 2 -f ../../plankton_ml/tests/fixtures/test_images/ ``` -OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory ../vit_model/vit_finetuned_Bal_CE_lr5e-05_epochs10. +### 3 class ResNet18 With the 3-class Resnet18 weights and the model_version 1 ("combined) we see output: ``` python test.py -w ../ResNet_18_3classes_RGB.pth -o out.csv -m 1 -f ../../plankton_ml/tests/fixtures/test_images/ ``` -With `-m 0` different `size mismatch` errors on model loading for both flavours +### 18 class ResNet18 +As above but with an `-n` option to specify the model classes (used both to initialise and for predictions). +Defaults to 3, will raise an error with values other than 18 + +``` +python test.py -w ../ResNet_18_18classes_RGB.pth -o out.csv -m 1 -n 18 -f ../../plankton_ml/tests/fixtures/test_images/ +``` diff --git a/test/test.py b/test/test.py old mode 100644 new mode 100755 index ddf048a..37f24d3 --- a/test/test.py +++ b/test/test.py @@ -16,13 +16,34 @@ from transformers import AutoImageProcessor, ViTForImageClassification, BeitImageProcessor, BeitForImageClassification device = "mps" if torch.backends.mps.is_available() else "cpu" -print(f"Using device: {device}") +logging.info(f"Using device: {device}") -LABELS = [ +CLASS_LABELS = { 3 : [ "copepod", "detritus", "noncopepod", +], + 18 : [ + "Detritus", + "Phyto_diatom", + "Phyto_diatom_chaetocerotanae_Chaetoceros", + "Phyto_diatom_rhisoleniales_Guinardia flaccida", + "Phyto_diatom_rhisoleniales_Rhizosolenia", + "Phyto_dinoflagellate_gonyaulacales_Tripos", + "Phyto_dinoflagellate_gonyaulacales_Tripos macroceros", + "Phyto_dinoflagellate_gonyaulacales_Tripos muelleri", + "Zoo_cnidaria", + "Zoo_crustacea_copepod", + "Zoo_crustacea_copepod_calanoida", + "Zoo_crustacea_copepod_calanoida_Acartia", + "Zoo_crustacea_copepod_calanoida_Centropages", + "Zoo_crustacea_copepod_cyclopoida", + "Zoo_crustacea_copepod_cyclopoida_Oithona", + "Zoo_crustacea_copepod_nauplii", + "Zoo_other", + "Zoo_tintinnidae" ] +} def resnet50(num_classes): model = torchvision.models.resnet50() @@ -140,9 +161,18 @@ def classify_batch(image_list, device, model, processor=None, gray=False, batch_ parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size for processing images") parser.add_argument("-o", "--output_csv", type=str, help="CSV file to save output in") parser.add_argument("-w", "--weights", type=str, help="Optional path to model weights") + parser.add_argument("-n", "--num_classes", type=int, default="3", help="Optional number of class labels (default 3)") + args = parser.parse_args() device = get_device() + + num_classes = args.num_classes + try: + LABELS = CLASS_LABELS[num_classes] + except IndexError: + raise(f"Can't find a set of {num_classes} labels") + weights = args.weights if args.model_version == 2: if not weights: @@ -193,7 +223,7 @@ def classify_batch(image_list, device, model, processor=None, gray=False, batch_ writer.writeheader() for i in range(len(filenames_list)): writer.writerow({"Filename": filenames_list[i], "Predicted Class": results[i]}) - print(f"Results saved to {output_csv_file}") + logging.info(f"Results saved to {output_csv_file}") else: with open(args.filename, "rb") as file: @@ -205,4 +235,4 @@ def classify_batch(image_list, device, model, processor=None, gray=False, batch_ writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() writer.writerow({"Filename": args.filename, "Predicted Class": result}) - print(f"Result saved to {output_csv_file}") + logging.info(f"Result saved to {output_csv_file}") diff --git a/tools/loss.py b/tools/loss.py old mode 100644 new mode 100755