From 880112434a8869ca62d13c5aaa1fa7b9388a47bf Mon Sep 17 00:00:00 2001 From: Ethan Li Date: Mon, 13 May 2024 13:23:00 -0700 Subject: [PATCH] Try to make the torchvision model architecture selectable by env var --- README.md | 10 ++++++++-- app_model.py | 16 ++++++++++++---- deployments/apps/ps/streamlit-demo.deploy.yml | 1 + ...mpose-torchvision-model-efficientnet-v2-s.yml | 4 ++++ pkg/forklift-package.yml | 5 +++++ 5 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 pkg/compose-torchvision-model-efficientnet-v2-s.yml diff --git a/README.md b/README.md index 882ecc5..bd90718 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,9 @@ To avoid errors later in running the docker container, please import your pretra models/ +By default, models must match Torchvision's `efficientnet_v2_m` architecture; to use models with a +different architecture, you must specify the architecture with a `TORCHVISION_MODEL_TYPE` +environment variable passed into the app. #### Building the Docker Image @@ -121,9 +124,12 @@ forklift stage apply After you have applied the pallet so that the streamlit demo app's container is running, you can access the streamlit demo app from your web browser at . -Before you can use the streamlit demo app, you will need to download a classification model file +Before you can use the demo app, you will need to download a classification model weights file (e.g. ) -into `~/.local/share/planktoscope/models`. +into `~/.local/share/planktoscope/models`; by default the model weights file must be for the +`efficientnet_v2_s` model architecture, but you can use the `efficientnet_v2_m` model architecture +instead by disabling the `torchvision-model-efficientnet-v2-s` feature flag of the pallet's +`apps/ps/streamlit-demo` package deployment. ## License This project is licensed under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0). diff --git a/app_model.py b/app_model.py index e5c7c70..c988da3 100644 --- a/app_model.py +++ b/app_model.py @@ -15,7 +15,12 @@ import torch import torch.nn as nn import torchvision.transforms as transforms -from torchvision.models import efficientnet_v2_m +from torchvision.models import efficientnet_v2_m, efficientnet_v2_s + +model_types = { + "efficientnet_v2_m": efficientnet_v2_m, + "efficientnet_v2_s": efficientnet_v2_s, +} ############################################################################################ # Functions/variables to be used in the Streamlit app @@ -42,11 +47,11 @@ def set_theme(theme): st.markdown(light, unsafe_allow_html=True) # Define the model loading function -def load_model(model_path): +def load_model(model_type, model_path): # Load the model checkpoint (remove map_location if you have a GPU) loaded_cpt = torch.load(model_path, map_location=torch.device('cpu')) # Define the EfficientNet_V2_M model (by default, no pre-trained weights are used) - model = efficientnet_v2_m() + model = model_types[model_type]() # Modify the classifier to match the number of classes in the dataset model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 5) # Load the state_dict in order to load the trained parameters @@ -143,7 +148,10 @@ def main(): image_size = int(re.search(pattern, selected_model).group().split("x")[0]) # Load the selected model in pytorch - model = load_model(os.path.join("models", selected_model)) + model = load_model( + os.getenv("TORCHVISION_MODEL_TYPE", "efficientnet_v2_m"), + os.path.join("models", selected_model), + ) # Load the class labels class_labels = ["Acantharia", "Calanoida", "Neoceratium_petersii", "Ptychodiscus_noctiluca", "Undella"] diff --git a/deployments/apps/ps/streamlit-demo.deploy.yml b/deployments/apps/ps/streamlit-demo.deploy.yml index 8e04969..6625a39 100644 --- a/deployments/apps/ps/streamlit-demo.deploy.yml +++ b/deployments/apps/ps/streamlit-demo.deploy.yml @@ -1,4 +1,5 @@ package: /pkg features: - frontend + - torchvision-model-efficientnet-v2-s disabled: false diff --git a/pkg/compose-torchvision-model-efficientnet-v2-s.yml b/pkg/compose-torchvision-model-efficientnet-v2-s.yml new file mode 100644 index 0000000..b518427 --- /dev/null +++ b/pkg/compose-torchvision-model-efficientnet-v2-s.yml @@ -0,0 +1,4 @@ +services: + server: + environment: + TORCHVISION_MODEL_TYPE: efficientnet_v2_s diff --git a/pkg/forklift-package.yml b/pkg/forklift-package.yml index dac0195..802c01c 100644 --- a/pkg/forklift-package.yml +++ b/pkg/forklift-package.yml @@ -31,3 +31,8 @@ features: paths: - /ps/streamlit-demo - /ps/streamlit-demo/* + torchvision-model-efficientnet-v2-s: + description: + Loads model weights for the efficientnet_v2_s model architecture instead of the default + model architecture (efficientnet_v2_m). + compose-files: [compose-torchvision-model-efficientnet-v2-s.yml]