Skip to content

Commit

Permalink
Try to make the torchvision model architecture selectable by env var
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanjli committed May 13, 2024
1 parent 8892bb3 commit 8801124
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 6 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ To avoid errors later in running the docker container, please import your pretra

models/<model_name>

By default, models must match Torchvision's `efficientnet_v2_m` architecture; to use models with a
different architecture, you must specify the architecture with a `TORCHVISION_MODEL_TYPE`
environment variable passed into the app.

#### Building the Docker Image

Expand Down Expand Up @@ -121,9 +124,12 @@ forklift stage apply
After you have applied the pallet so that the streamlit demo app's container is running, you can
access the streamlit demo app from your web browser at <http://localhost/ps/streamlit-demo>.

Before you can use the streamlit demo app, you will need to download a classification model file
Before you can use the demo app, you will need to download a classification model weights file
(e.g. <https://github.com/PlanktoScope/streamlit-classification-app/releases/download/models%2Fdemo-1/effv2s_no_norm_DA+sh_20patience_256x256_50ep_loss.pth>)
into `~/.local/share/planktoscope/models`.
into `~/.local/share/planktoscope/models`; by default the model weights file must be for the
`efficientnet_v2_s` model architecture, but you can use the `efficientnet_v2_m` model architecture
instead by disabling the `torchvision-model-efficientnet-v2-s` feature flag of the pallet's
`apps/ps/streamlit-demo` package deployment.

## License
This project is licensed under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0).
Expand Down
16 changes: 12 additions & 4 deletions app_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_m
from torchvision.models import efficientnet_v2_m, efficientnet_v2_s

model_types = {
"efficientnet_v2_m": efficientnet_v2_m,
"efficientnet_v2_s": efficientnet_v2_s,
}

############################################################################################
# Functions/variables to be used in the Streamlit app
Expand All @@ -42,11 +47,11 @@ def set_theme(theme):
st.markdown(light, unsafe_allow_html=True)

# Define the model loading function
def load_model(model_path):
def load_model(model_type, model_path):
# Load the model checkpoint (remove map_location if you have a GPU)
loaded_cpt = torch.load(model_path, map_location=torch.device('cpu'))
# Define the EfficientNet_V2_M model (by default, no pre-trained weights are used)
model = efficientnet_v2_m()
model = model_types[model_type]()
# Modify the classifier to match the number of classes in the dataset
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 5)
# Load the state_dict in order to load the trained parameters
Expand Down Expand Up @@ -143,7 +148,10 @@ def main():
image_size = int(re.search(pattern, selected_model).group().split("x")[0])

# Load the selected model in pytorch
model = load_model(os.path.join("models", selected_model))
model = load_model(
os.getenv("TORCHVISION_MODEL_TYPE", "efficientnet_v2_m"),
os.path.join("models", selected_model),
)

# Load the class labels
class_labels = ["Acantharia", "Calanoida", "Neoceratium_petersii", "Ptychodiscus_noctiluca", "Undella"]
Expand Down
1 change: 1 addition & 0 deletions deployments/apps/ps/streamlit-demo.deploy.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package: /pkg
features:
- frontend
- torchvision-model-efficientnet-v2-s
disabled: false
4 changes: 4 additions & 0 deletions pkg/compose-torchvision-model-efficientnet-v2-s.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
services:
server:
environment:
TORCHVISION_MODEL_TYPE: efficientnet_v2_s
5 changes: 5 additions & 0 deletions pkg/forklift-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ features:
paths:
- /ps/streamlit-demo
- /ps/streamlit-demo/*
torchvision-model-efficientnet-v2-s:
description:
Loads model weights for the efficientnet_v2_s model architecture instead of the default
model architecture (efficientnet_v2_m).
compose-files: [compose-torchvision-model-efficientnet-v2-s.yml]

0 comments on commit 8801124

Please sign in to comment.