Skip to content

Commit

Permalink
Merge pull request #3 from PlanktoScope/feature/selectable-model-types
Browse files Browse the repository at this point in the history
Make the torchvision model architecture selectable by env var
  • Loading branch information
ethanjli authored May 14, 2024
2 parents 8892bb3 + 735a2f6 commit 4432ff1
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 7 deletions.
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ To avoid errors later in running the docker container, please import your pretra

models/<model_name>

By default, models must match Torchvision's `efficientnet_v2_m` architecture; to use models with a
different architecture, you must specify the architecture with a `TORCHVISION_MODEL_TYPE`
environment variable passed into the app.

#### Building the Docker Image

Expand Down Expand Up @@ -121,9 +124,16 @@ forklift stage apply
After you have applied the pallet so that the streamlit demo app's container is running, you can
access the streamlit demo app from your web browser at <http://localhost/ps/streamlit-demo>.

Before you can use the streamlit demo app, you will need to download a classification model file
Before you can use the demo app, you will need to download a classification model weights file
(e.g. <https://github.com/PlanktoScope/streamlit-classification-app/releases/download/models%2Fdemo-1/effv2s_no_norm_DA+sh_20patience_256x256_50ep_loss.pth>)
into `~/.local/share/planktoscope/models`.
into `~/.local/share/planktoscope/models`; by default the model weights file must be for the
`efficientnet_v2_s` model architecture, but you can use the `efficientnet_v2_m` model architecture
instead by disabling the `torchvision-model-efficientnet-v2-s` feature flag of the pallet's
`apps/ps/streamlit-demo` package deployment.

Then you can upload input images
(e.g. <https://github.com/PlanktoScope/streamlit-classification-app/releases/download/models%2Fdemo-1/example-input-tots-ps-acq-20-02_49_37_288982.jpg>)
to the demo app.

## License
This project is licensed under the [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0).
Expand Down
16 changes: 12 additions & 4 deletions app_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_m
from torchvision.models import efficientnet_v2_m, efficientnet_v2_s

model_types = {
"efficientnet_v2_m": efficientnet_v2_m,
"efficientnet_v2_s": efficientnet_v2_s,
}

############################################################################################
# Functions/variables to be used in the Streamlit app
Expand All @@ -42,11 +47,11 @@ def set_theme(theme):
st.markdown(light, unsafe_allow_html=True)

# Define the model loading function
def load_model(model_path):
def load_model(model_type, model_path):
# Load the model checkpoint (remove map_location if you have a GPU)
loaded_cpt = torch.load(model_path, map_location=torch.device('cpu'))
# Define the EfficientNet_V2_M model (by default, no pre-trained weights are used)
model = efficientnet_v2_m()
model = model_types[model_type]()
# Modify the classifier to match the number of classes in the dataset
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 5)
# Load the state_dict in order to load the trained parameters
Expand Down Expand Up @@ -143,7 +148,10 @@ def main():
image_size = int(re.search(pattern, selected_model).group().split("x")[0])

# Load the selected model in pytorch
model = load_model(os.path.join("models", selected_model))
model = load_model(
os.getenv("TORCHVISION_MODEL_TYPE", "efficientnet_v2_m"),
os.path.join("models", selected_model),
)

# Load the class labels
class_labels = ["Acantharia", "Calanoida", "Neoceratium_petersii", "Ptychodiscus_noctiluca", "Undella"]
Expand Down
1 change: 1 addition & 0 deletions deployments/apps/ps/streamlit-demo.deploy.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package: /pkg
features:
- frontend
- torchvision-model-efficientnet-v2-s
disabled: false
4 changes: 4 additions & 0 deletions pkg/compose-torchvision-model-efficientnet-v2-s.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
services:
server:
environment:
TORCHVISION_MODEL_TYPE: efficientnet_v2_s
2 changes: 1 addition & 1 deletion pkg/compose.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
server:
image: ghcr.io/planktoscope/streamlit-classification-app:sha-15ff307
image: ghcr.io/planktoscope/streamlit-classification-app:sha-7083e51
volumes:
- ~/.local/share/planktoscope/models/:/app/models

Expand Down
5 changes: 5 additions & 0 deletions pkg/forklift-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ features:
paths:
- /ps/streamlit-demo
- /ps/streamlit-demo/*
torchvision-model-efficientnet-v2-s:
description:
Loads model weights for the efficientnet_v2_s model architecture instead of the default
model architecture (efficientnet_v2_m).
compose-files: [compose-torchvision-model-efficientnet-v2-s.yml]

0 comments on commit 4432ff1

Please sign in to comment.