diff --git a/README.md b/README.md index 882ecc5..9ad2ea3 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,9 @@ To avoid errors later in running the docker container, please import your pretra models/ +By default, models must match Torchvision's `efficientnet_v2_m` architecture; to use models with a +different architecture, you must specify the architecture with a `TORCHVISION_MODEL_TYPE` +environment variable passed into the app. #### Building the Docker Image @@ -121,9 +124,16 @@ forklift stage apply After you have applied the pallet so that the streamlit demo app's container is running, you can access the streamlit demo app from your web browser at . -Before you can use the streamlit demo app, you will need to download a classification model file +Before you can use the demo app, you will need to download a classification model weights file (e.g. ) -into `~/.local/share/planktoscope/models`. +into `~/.local/share/planktoscope/models`; by default the model weights file must be for the +`efficientnet_v2_s` model architecture, but you can use the `efficientnet_v2_m` model architecture +instead by disabling the `torchvision-model-efficientnet-v2-s` feature flag of the pallet's +`apps/ps/streamlit-demo` package deployment. + +Then you can upload input images +(e.g. ) +to the demo app. ## License This project is licensed under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0). diff --git a/app_model.py b/app_model.py index e5c7c70..c988da3 100644 --- a/app_model.py +++ b/app_model.py @@ -15,7 +15,12 @@ import torch import torch.nn as nn import torchvision.transforms as transforms -from torchvision.models import efficientnet_v2_m +from torchvision.models import efficientnet_v2_m, efficientnet_v2_s + +model_types = { + "efficientnet_v2_m": efficientnet_v2_m, + "efficientnet_v2_s": efficientnet_v2_s, +} ############################################################################################ # Functions/variables to be used in the Streamlit app @@ -42,11 +47,11 @@ def set_theme(theme): st.markdown(light, unsafe_allow_html=True) # Define the model loading function -def load_model(model_path): +def load_model(model_type, model_path): # Load the model checkpoint (remove map_location if you have a GPU) loaded_cpt = torch.load(model_path, map_location=torch.device('cpu')) # Define the EfficientNet_V2_M model (by default, no pre-trained weights are used) - model = efficientnet_v2_m() + model = model_types[model_type]() # Modify the classifier to match the number of classes in the dataset model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 5) # Load the state_dict in order to load the trained parameters @@ -143,7 +148,10 @@ def main(): image_size = int(re.search(pattern, selected_model).group().split("x")[0]) # Load the selected model in pytorch - model = load_model(os.path.join("models", selected_model)) + model = load_model( + os.getenv("TORCHVISION_MODEL_TYPE", "efficientnet_v2_m"), + os.path.join("models", selected_model), + ) # Load the class labels class_labels = ["Acantharia", "Calanoida", "Neoceratium_petersii", "Ptychodiscus_noctiluca", "Undella"] diff --git a/deployments/apps/ps/streamlit-demo.deploy.yml b/deployments/apps/ps/streamlit-demo.deploy.yml index 8e04969..6625a39 100644 --- a/deployments/apps/ps/streamlit-demo.deploy.yml +++ b/deployments/apps/ps/streamlit-demo.deploy.yml @@ -1,4 +1,5 @@ package: /pkg features: - frontend + - torchvision-model-efficientnet-v2-s disabled: false diff --git a/pkg/compose-torchvision-model-efficientnet-v2-s.yml b/pkg/compose-torchvision-model-efficientnet-v2-s.yml new file mode 100644 index 0000000..b518427 --- /dev/null +++ b/pkg/compose-torchvision-model-efficientnet-v2-s.yml @@ -0,0 +1,4 @@ +services: + server: + environment: + TORCHVISION_MODEL_TYPE: efficientnet_v2_s diff --git a/pkg/compose.yml b/pkg/compose.yml index 58dc10f..6b6dcac 100644 --- a/pkg/compose.yml +++ b/pkg/compose.yml @@ -1,6 +1,6 @@ services: server: - image: ghcr.io/planktoscope/streamlit-classification-app:sha-15ff307 + image: ghcr.io/planktoscope/streamlit-classification-app:sha-7083e51 volumes: - ~/.local/share/planktoscope/models/:/app/models diff --git a/pkg/forklift-package.yml b/pkg/forklift-package.yml index dac0195..802c01c 100644 --- a/pkg/forklift-package.yml +++ b/pkg/forklift-package.yml @@ -31,3 +31,8 @@ features: paths: - /ps/streamlit-demo - /ps/streamlit-demo/* + torchvision-model-efficientnet-v2-s: + description: + Loads model weights for the efficientnet_v2_s model architecture instead of the default + model architecture (efficientnet_v2_m). + compose-files: [compose-torchvision-model-efficientnet-v2-s.yml]