Skip to content

Commit

Permalink
Merge pull request #178 from Wiebke/retrieve-segmentation-results
Browse files Browse the repository at this point in the history
Retrieve segmentation results and populate job parameters
  • Loading branch information
Wiebke authored Mar 11, 2024
2 parents fbab404 + b7d5f2f commit e94fefc
Show file tree
Hide file tree
Showing 13 changed files with 461 additions and 114 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.env
.git
.gitignore
*-env/
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ USER_PASSWORD=<to-be-specified-per-deployment>
PREFECT_API_URL=http://prefect:4200/api
FLOW_NAME="Parent flow/launch_parent_flow"
TIMEZONE="US/Pacific"

# Environment variables for conda-based Prefect flows
CONDA_ENV_NAME="dlsia"
TRAIN_SCRIPT_PATH="src/train.py"
SEGMENT_SCRIPT_PATH="src/segment.py"
47 changes: 31 additions & 16 deletions assets/models.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"contents": [
{
"model_name": "DSLIA MSDNet",
"model_name": "MSDNet",
"version": "0.0.1",
"type": "supervised",
"user": "mlexchange team",
Expand Down Expand Up @@ -48,6 +48,7 @@
"name": "dilation_array",
"title": "Dilation Array",
"param_key": "dilation_array",
"value": "[1, 2, 4]",
"placeholder": "e.g. [1, 2, 4]",
"error": "Provide a list of ints for dilation",
"debounce": 1000,
Expand Down Expand Up @@ -230,6 +231,7 @@
"name": "weights",
"title": "Class Weights",
"param_key": "weights",
"value": "[1]",
"placeholder": "e.g [0.1, 0.4, 0.5]",
"error": "Provide a list with a float for each class",
"debounce": 1000,
Expand Down Expand Up @@ -312,20 +314,21 @@
"title": "Validation %",
"param_key": "val_pct",
"min": 0,
"max": 100,
"step": 5,
"value": 20,
"max": 1,
"step": 0.05,
"value": 0.2,
"precision": 2,
"marks": [
{
"value": 0,
"label": "0%"
},
{
"value": 50,
"value": 0.5,
"label": "50%"
},
{
"value": 100,
"value": 1,
"label": "100%"
}
],
Expand Down Expand Up @@ -434,7 +437,7 @@
"reference": "https://dlsia.readthedocs.io/en/latest/"
},
{
"model_name": "DSLIA TUNet",
"model_name": "TUNet",
"version": "0.0.1",
"type": "supervised",
"user": "mlexchange team",
Expand Down Expand Up @@ -653,6 +656,7 @@
"name": "weights",
"title": "Class Weights",
"param_key": "weights",
"value": "[1]",
"placeholder": "e.g [0.1, 0.4, 0.5]",
"error": "Provide a list with a float for each class",
"debounce": 1000,
Expand Down Expand Up @@ -735,16 +739,21 @@
"title": "Validation %",
"param_key": "val_pct",
"min": 0,
"max": 100,
"step": 5,
"value": 20,
"max": 1,
"step": 0.05,
"value": 0.2,
"precision": 2,
"marks": [
{
"value": 0,
"label": "0%"
},
{
"value": 100,
"value": 0.5,
"label": "50%"
},
{
"value": 1,
"label": "100%"
}
],
Expand Down Expand Up @@ -853,7 +862,7 @@
"reference": "https://dlsia.readthedocs.io/en/latest/"
},
{
"model_name": "DSLIA TUNet3+",
"model_name": "TUNet3+",
"version": "0.0.1",
"type": "supervised",
"user": "mlexchange team",
Expand Down Expand Up @@ -1080,6 +1089,7 @@
"name": "weights",
"title": "Class Weights",
"param_key": "weights",
"value": "[1]",
"placeholder": "e.g [0.1, 0.4, 0.5]",
"error": "Provide a list with a float for each class",
"debounce": 1000,
Expand Down Expand Up @@ -1162,16 +1172,21 @@
"title": "Validation %",
"param_key": "val_pct",
"min": 0,
"max": 100,
"step": 5,
"value": 20,
"max": 1,
"step": 0.05,
"value": 0.2,
"precision": 2,
"marks": [
{
"value": 0,
"label": "0%"
},
{
"value": 100,
"value": 0.5,
"label": "50%"
},
{
"value": 1,
"label": "100%"
}
],
Expand Down
90 changes: 42 additions & 48 deletions callbacks/control_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from components.parameter_items import ParameterItems
from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS
from utils.annotations import Annotations
from utils.data_utils import models, tiled_datasets, tiled_masks, tiled_results
from utils.data_utils import models, tiled_datasets, tiled_masks
from utils.plot_utils import generate_notification, generate_notification_bg_icon_col

# TODO - temporary local file path and user for annotation saving and exporting
Expand Down Expand Up @@ -230,16 +230,18 @@ def annotation_mode(
patched_figure["layout"]["dragmode"] = "drawrect"
annotation_store["dragmode"] = "drawrect"
styles[trigger] = active

elif trigger == "pan-and-zoom" and pan_and_zoom > 0:
patched_figure["layout"]["dragmode"] = "pan"
annotation_store["dragmode"] = "pan"
styles[trigger] = active

# disable shape editing when in pan/zoom mode
for shape in fig["layout"]["shapes"]:
shape["editable"] = trigger != "pan-and-zoom" and pan_and_zoom > 0
patched_figure["layout"]["shapes"] = fig["layout"]["shapes"]
# if no shapes have been added yet,
# none need to be set to not editable
if "shapes" in fig["layout"]:
for shape in fig["layout"]["shapes"]:
shape["editable"] = trigger != "pan-and-zoom" and pan_and_zoom > 0
patched_figure["layout"]["shapes"] = fig["layout"]["shapes"]
return (
patched_figure,
styles["closed-freeform"],
Expand Down Expand Up @@ -853,64 +855,46 @@ def open_controls_drawer(n_clicks, is_opened):
return no_update, no_update


@callback(Output("project-name-src", "data"), Input("refresh-tiled", "n_clicks"))
def refresh_data_client(refresh_tiled):
if refresh_tiled:
tiled_datasets.refresh_data_client()
data_options = [
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
]
return data_options


@callback(
Output("result-selector", "data"),
Output("result-selector", "value"),
Output("result-selector", "disabled"),
Output("show-result-overlay-toggle", "checked"),
Output("show-result-overlay-toggle", "disabled"),
Output("seg-result-opacity-slider", "disabled"),
Output("project-name-src", "data"),
Input("project-name-src", "value"),
Input("refresh-tiled", "n_clicks"),
Input("show-result-overlay-toggle", "checked"),
State("result-selector", "disabled"),
Input("seg-results-train-store", "data"),
Input("seg-results-inference-store", "data"),
State("seg-result-opacity-slider", "disabled"),
)
def populate_classification_results(
image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled
def update_result_controls(
toggle, seg_result_train, seg_result_inference, slider_disabled
):
if refresh_tiled:
tiled_datasets.refresh_data_client()

data_options = [
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
]
results = []
value = None
checked = False
disabled_dropdown = True
disabled_toggle = True
disabled_slider = True
disable_toggle = True
disable_slider = True
# Disable opacity slider if result overlay is unchecked
if ctx.triggered_id == "show-result-overlay-toggle":
results = no_update
value = no_update
checked = no_update
disabled_dropdown = dropdown_enabled
disabled_toggle = False
disabled_slider = slider_enabled
# Must have been enabled to be source of trigger
disable_toggle = no_update
disable_slider = not slider_disabled
else:
# TODO: Match by mask uid instead of image_src
results = [
item
for item in tiled_results.get_data_project_names()
if ("seg" in item and image_src in item)
]
if results:
value = results[0]
disabled_dropdown = False
if seg_result_train or seg_result_inference:
checked = False
disabled_toggle = False
disabled_slider = False

disable_toggle = False
disable_slider = False
return (
results,
value,
disabled_dropdown,
checked,
disabled_toggle,
disabled_slider,
data_options,
disable_toggle,
disable_slider,
)


Expand Down Expand Up @@ -961,6 +945,10 @@ def update_model_parameters(model_name):
),
)
def validate_class_weights(all_annotation_classes, weights):

if weights is None:
return "Provide a list with a float for each class"

parsed_weights = weights.strip("[]").split(",")
try:
parsed_weights = [float(weight.strip()) for weight in parsed_weights]
Expand Down Expand Up @@ -996,11 +984,17 @@ def validate_class_weights(all_annotation_classes, weights):
),
)
def validate_dilation_array(dilation_array):

if dilation_array is None:
return "Provide a list of ints for dilation"

parsed_dilation_array = dilation_array.strip("[]").split(",")
try:
parsed_dilation_array = [
int(array_entry.strip()) for array_entry in parsed_dilation_array
]
if len(parsed_dilation_array) == 0:
return "Provide a list of ints for dilation"
# Check if all elements in the list are floats
return False
except ValueError:
Expand Down
34 changes: 24 additions & 10 deletions callbacks/image_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dash.exceptions import PreventUpdate

from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS
from utils.data_utils import tiled_datasets, tiled_masks, tiled_results
from utils.data_utils import tiled_datasets, tiled_results
from utils.plot_utils import (
create_viewfinder,
downscale_view,
Expand Down Expand Up @@ -70,7 +70,8 @@ def hide_show_segmentation_overlay(toggle_seg_result, opacity):
State("image-metadata", "data"),
State("screen-size", "data"),
State("current-class-selection", "data"),
State("result-selector", "value"),
State("seg-results-train-store", "data"),
State("seg-results-inference-store", "data"),
State("seg-result-opacity-slider", "value"),
State("image-viewer", "figure"),
prevent_initial_call=True,
Expand All @@ -85,7 +86,8 @@ def render_image(
image_metadata,
screen_size,
current_color,
seg_result_selection,
seg_result_train,
seg_result_inference,
opacity,
fig,
):
Expand Down Expand Up @@ -118,13 +120,25 @@ def render_image(
and ctx.triggered_id == "show-result-overlay-toggle"
):
return [dash.no_update] * 7 + ["hidden"]
annotation_indices = tiled_masks.get_annotated_segmented_results()
if str(image_idx + 1) in annotation_indices:
# Will not return an error since we already checked if image_idx+1 is in the list
mapped_index = annotation_indices.index(str(image_idx + 1))
result = tiled_results.get_data_sequence_by_name(seg_result_selection)[
mapped_index
]
# Check if the stored results are for the current project and image
if seg_result_train or seg_result_inference:
seg_result = (
seg_result_inference if seg_result_inference else seg_result_train
)
if "mask_idx" in seg_result and seg_result["mask_idx"] is not None:
annotation_indices = seg_result["mask_idx"]
if str(image_idx) in annotation_indices:
# Will not return an error since we already checked if image_idx is in the list
mapped_index = annotation_indices.index(str(image_idx))
result = tiled_results.get_data_by_trimmed_uri(
seg_result["seg_result_trimmed_uri"], slice=mapped_index
)
else:
result = None
else:
result = tiled_results.get_data_by_trimmed_uri(
seg_result["seg_result_trimmed_uri"], slice=image_idx
)
else:
result = None
else:
Expand Down
Loading

0 comments on commit e94fefc

Please sign in to comment.