Add a num_classes option to test script, update README #3

Merged
20 changes: 16 additions & 4 deletions test/README.md
@@ -1,20 +1,32 @@

## Notes on reproducing

With the model weights downloaded from Google Drive and unpacked into the `vit_model` directory below:
The model weights are downloaded from Google Drive (see [this issue](https://github.com/alan-turing-institute/ViT-LASNet/issues/2) for a question about whether they could be held in a repository with a model card to support reuse).

Images are a small test set from [this project on freshwater plankton](https://github.com/NERC-CEH/plankton_ml/).

### ViT model (version = 2, num_classes = 18)

Note this is the version in the subdirectory with `model.safetensors` included, not the single-file `.pt` version.

```
python test.py -w ../vit_model/vit_finetuned_Bal_CE_lr5e-05_epochs10 -o out -m 2 -f ../../plankton_ml/tests/fixtures/test_images/
python test.py -w ~/vit_finetuned_MiSLAS_vit_lr5e-05_epochs30/ -o out -n 18 -m 2 -f ../../plankton_ml/tests/fixtures/test_images/
```
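
The note about `model.safetensors` can be checked before running the script. The following is a minimal sketch, assuming the weights directory must contain one of the file names that the Hugging Face loading error lists; `has_hf_weights` is a hypothetical helper and is not part of `test.py`.

```python
from pathlib import Path

# File names taken from the OSError message quoted in these notes.
EXPECTED_WEIGHT_FILES = (
    "pytorch_model.bin",
    "model.safetensors",
    "tf_model.h5",
    "model.ckpt.index",
    "flax_model.msgpack",
)


def has_hf_weights(weights_dir):
    """Return True if weights_dir is a directory holding a recognised weights file."""
    path = Path(weights_dir).expanduser()
    if not path.is_dir():
        # A single-file checkpoint such as a .pt file is not a model directory.
        return False
    return any((path / name).is_file() for name in EXPECTED_WEIGHT_FILES)
```

Calling `has_hf_weights` on the path passed to `-w` should return `True` for the subdirectory version of the ViT weights and `False` for the single-file `.pt` version.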

OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory ../vit_model/vit_finetuned_Bal_CE_lr5e-05_epochs10.
### 3 class ResNet18

With the 3-class ResNet18 weights and `model_version` 1 ("combined") we see output:
```
python test.py -w ../ResNet_18_3classes_RGB.pth -o out.csv -m 1 -f ../../plankton_ml/tests/fixtures/test_images/
```

With `-m 0` different `size mismatch` errors on model loading for both flavours
### 18 class ResNet18

As above but with an `-n` option to specify the number of model classes (used both to initialise the model and to label predictions).
Defaults to 3; any value other than 3 or 18 raises an error.

```
python test.py -w ../ResNet_18_18classes_RGB.pth -o out.csv -m 1 -n 18 -f ../../plankton_ml/tests/fixtures/test_images/
```
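
The `-n` lookup described above can be sketched as follows. This is an illustration under stated assumptions rather than the code in `test.py`: `labels_for` and `label_for_index` are hypothetical helpers, and the 18-class names are replaced with placeholders.

```python
# Label sets keyed by class count. The 3-class names come from the script;
# the 18-class entries are placeholders standing in for the full list.
CLASS_LABELS = {
    3: ["copepod", "detritus", "noncopepod"],
    18: [f"class_{i}" for i in range(18)],
}


def labels_for(num_classes):
    """Return the label list for num_classes, or raise ValueError if unsupported."""
    try:
        return CLASS_LABELS[num_classes]
    except KeyError:
        # Dictionary lookups raise KeyError, not IndexError.
        supported = sorted(CLASS_LABELS)
        raise ValueError(
            f"Can't find a set of {num_classes} labels; supported: {supported}"
        ) from None


def label_for_index(num_classes, predicted_index):
    """Map a predicted class index to its label name."""
    return labels_for(num_classes)[predicted_index]
```

Keying the label sets by class count keeps the model head size and the label list in step, so an unsupported `-n` value fails early with a readable message instead of producing mislabelled predictions.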


38 changes: 34 additions & 4 deletions test/test.py
100644 → 100755
@@ -16,13 +16,34 @@
from transformers import AutoImageProcessor, ViTForImageClassification, BeitImageProcessor, BeitForImageClassification

device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
logging.info(f"Using device: {device}")

LABELS = [
CLASS_LABELS = { 3 : [
"copepod",
"detritus",
"noncopepod",
],
18 : [
"Detritus",
"Phyto_diatom",
"Phyto_diatom_chaetocerotanae_Chaetoceros",
"Phyto_diatom_rhisoleniales_Guinardia flaccida",
"Phyto_diatom_rhisoleniales_Rhizosolenia",
"Phyto_dinoflagellate_gonyaulacales_Tripos",
"Phyto_dinoflagellate_gonyaulacales_Tripos macroceros",
"Phyto_dinoflagellate_gonyaulacales_Tripos muelleri",
"Zoo_cnidaria",
"Zoo_crustacea_copepod",
"Zoo_crustacea_copepod_calanoida",
"Zoo_crustacea_copepod_calanoida_Acartia",
"Zoo_crustacea_copepod_calanoida_Centropages",
"Zoo_crustacea_copepod_cyclopoida",
"Zoo_crustacea_copepod_cyclopoida_Oithona",
"Zoo_crustacea_copepod_nauplii",
"Zoo_other",
"Zoo_tintinnidae"
]
}

def resnet50(num_classes):
model = torchvision.models.resnet50()
@@ -140,9 +161,18 @@ def classify_batch(image_list, device, model, processor=None, gray=False, batch_
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size for processing images")
parser.add_argument("-o", "--output_csv", type=str, help="CSV file to save output in")
parser.add_argument("-w", "--weights", type=str, help="Optional path to model weights")
parser.add_argument("-n", "--num_classes", type=int, default=3, help="Optional number of class labels (default 3)")


args = parser.parse_args()
device = get_device()

num_classes = args.num_classes
try:
    LABELS = CLASS_LABELS[num_classes]
except KeyError:
    # CLASS_LABELS is a dict, so a missing key raises KeyError
    raise ValueError(f"Can't find a set of {num_classes} labels")

weights = args.weights
if args.model_version == 2:
if not weights:
@@ -193,7 +223,7 @@ def classify_batch(image_list, device, model, processor=None, gray=False, batch_
writer.writeheader()
for i in range(len(filenames_list)):
writer.writerow({"Filename": filenames_list[i], "Predicted Class": results[i]})
print(f"Results saved to {output_csv_file}")
logging.info(f"Results saved to {output_csv_file}")

else:
with open(args.filename, "rb") as file:
@@ -205,4 +235,4 @@ def classify_batch(image_list, device, model, processor=None, gray=False, batch_
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
writer.writerow({"Filename": args.filename, "Predicted Class": result})
print(f"Result saved to {output_csv_file}")
logging.info(f"Result saved to {output_csv_file}")
Empty file modified tools/loss.py
100644 → 100755
Empty file.