Skip to content

Commit

Permalink
fixed export to have mask with labels from 0 to num. classes - 1; som…
Browse files Browse the repository at this point in the history
…e clean up
  • Loading branch information
mese79 committed Nov 1, 2024
1 parent dcab98c commit eca80e6
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/featureforest/_feature_extractor_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def save_storage(self):
storage_name = f"{image_layer_name}_{model_name}.hdf5"
# open the save dialog
selected_file, _filter = QFileDialog.getSaveFileName(
self, "Jug Lab", storage_name, "Embeddings Storage(*.hdf5)"
self, "FeatureForest", storage_name, "Embeddings Storage(*.hdf5)"
)
if selected_file is not None and len(selected_file) > 0:
if not selected_file.endswith(".hdf5"):
Expand Down
12 changes: 6 additions & 6 deletions src/featureforest/_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
postprocess_with_sam_auto,
get_sam_auto_masks
)
from .exports import EXPORTERS
from .exports import EXPORTERS, reset_mask_labels
from .utils.pipeline_prediction import (
extract_predict
)
Expand Down Expand Up @@ -529,7 +529,7 @@ def sam_auto_post_checked(self, checked: bool):

def select_storage(self):
selected_file, _filter = QFileDialog.getOpenFileName(
self, "Jug Lab", ".", "Feature Storage(*.hdf5)"
self, "FeatureForest", ".", "Feature Storage(*.hdf5)"
)
if selected_file is not None and len(selected_file) > 0:
self.storage_textbox.setText(selected_file)
Expand Down Expand Up @@ -685,7 +685,7 @@ def train_model(self):

def load_rf_model(self):
selected_file, _filter = QFileDialog.getOpenFileName(
self, "Jug Lab", ".", "model(*.bin)"
self, "FeatureForest", ".", "model(*.bin)"
)
if len(selected_file) > 0:
# to suppress the sklearn InconsistentVersionWarning
Expand Down Expand Up @@ -722,7 +722,7 @@ def save_rf_model(self):
notif.show_info("There is no trained model!")
return
selected_file, _filter = QFileDialog.getSaveFileName(
self, "Jug Lab", ".", "model(*.bin)"
self, "FeatureForest", ".", "model(*.bin)"
)
if len(selected_file) > 0:
if not selected_file.endswith(".bin"):
Expand Down Expand Up @@ -947,7 +947,7 @@ def export_segmentation(self):
exporter = EXPORTERS[self.export_format_combo.currentText()]
# export_format = self.export_format_combo.currentText()
selected_file, _filter = QFileDialog.getSaveFileName(
self, "Jug Lab", ".", f"Segmentation(*.{exporter.extension})"
self, "FeatureForest", ".", f"Segmentation(*.{exporter.extension})"
)
if selected_file is None or len(selected_file) == 0:
return # user canceled export
Expand All @@ -967,7 +967,7 @@ def export_segmentation(self):

def select_stack(self):
selected_file, _filter = QFileDialog.getOpenFileName(
self, "Jug Lab", ".", "TIFF stack (*.tiff *.tif)"
self, "FeatureForest", ".", "TIFF stack (*.tiff *.tif)"
)
if selected_file is not None and len(selected_file) > 0:
# get stack info
Expand Down
8 changes: 7 additions & 1 deletion src/featureforest/exports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import (
TiffExporter, NRRDExporter, NumpyExporter
TiffExporter, NRRDExporter, NumpyExporter,
reset_mask_labels
)


Expand All @@ -8,3 +9,8 @@
"nrrd": NRRDExporter(),
"numpy": NumpyExporter(),
}

__all__ = [
reset_mask_labels,
EXPORTERS
]
41 changes: 32 additions & 9 deletions src/featureforest/exports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@
from napari.layers import Layer


def reset_mask_labels(mask_data: np.ndarray) -> np.ndarray:
"""Reset label values in the given mask:
for a binary mask values will be 0 and 255;
for a multi-class mask values will be reduced by one to match class index.
Args:
mask_data (np.ndarray): input mask
Returns:
np.ndarray: fixed mask
"""
mask_values = np.unique(mask_data)
if len(mask_values) == 2:
# this is a binary mask
mask_data[mask_data == min(mask_values)] = 0
mask_data[mask_data == max(mask_values)] = 255
else:
# reduce one from non-background pixels to match class index
mask_data[mask_data > 0] -= 1
assert (mask_data < 0).sum() == 0

return mask_data

class BaseExporter:
"""Base Exporter Class: all exporters should be a subclass of this class."""
def __init__(self, name: str = "Base Exporter", extension: str = "bin") -> None:
Expand All @@ -29,13 +52,9 @@ def __init__(self, name: str = "TIFF", extension: str = "tiff") -> None:
super().__init__(name, extension)

def export(self, layer: Layer, export_file: str) -> None:
tiff_data = layer.data.astype(np.uint8)
mask_values = np.unique(tiff_data)
if len(mask_values) == 2:
# this is a binary mask
tiff_data[tiff_data == min(mask_values)] = 0
tiff_data[tiff_data == max(mask_values)] = 255
imwrite(export_file, tiff_data)
mask_data = layer.data.copy().astype(np.uint8)
mask_data = reset_mask_labels(mask_data)
imwrite(export_file, mask_data)


class NRRDExporter(BaseExporter):
Expand All @@ -44,7 +63,9 @@ def __init__(self, name: str = "NRRD", extension: str = "nrrd") -> None:
super().__init__(name, extension)

def export(self, layer: Layer, export_file: str) -> None:
nrrd.write(export_file, np.transpose(layer.data))
mask_data = layer.data.copy().astype(np.uint8)
mask_data = reset_mask_labels(mask_data)
nrrd.write(export_file, np.transpose(mask_data))


class NumpyExporter(BaseExporter):
Expand All @@ -53,4 +74,6 @@ def __init__(self, name: str = "Numpy", extension: str = "npy") -> None:
super().__init__(name, extension)

def export(self, layer: Layer, export_file: str) -> None:
return np.save(export_file, layer.data)
mask_data = layer.data.copy().astype(np.uint8)
mask_data = reset_mask_labels(mask_data)
return np.save(export_file, mask_data)
2 changes: 1 addition & 1 deletion src/featureforest/postprocess/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def postprocess(
np.ndarray: post-processed segmentation image
"""
final_mask = np.zeros_like(segmentation_image, dtype=np.uint8)
# postprocessing gets done for each label's segments.
# postprocessing gets done for each label's mask separately.
class_labels = [c for c in np.unique(segmentation_image) if c > 0]
for label in class_labels:
# make a binary image for the label
Expand Down
2 changes: 1 addition & 1 deletion src/featureforest/postprocess/postprocess_with_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def postprocess_with_sam(
predictor = SamPredictor(get_light_hq_sam())

final_mask = np.zeros_like(segmentation_image, dtype=np.uint8)
# postprocessing gets done for each class segmentation.
# postprocessing gets done for each label's mask separately.
bg_label = 0
class_labels = [c for c in np.unique(segmentation_image) if c > bg_label]
for label in np_progress(
Expand Down
1 change: 0 additions & 1 deletion src/featureforest/utils/pipeline_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def extract_predict(
model_adapter: BaseModelAdapter,
storage_group: h5py.Group,
rf_model: RandomForestClassifier,
# result_dir: Path
) -> np.ndarray:
img_height, img_width = image.shape[:2]
patch_size = model_adapter.patch_size
Expand Down

0 comments on commit eca80e6

Please sign in to comment.