diff --git a/pyproject.toml b/pyproject.toml index ed5d696..eb16aaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,10 @@ classifiers = [ dependencies = [ "cellpose", "confuse", + "pyimagej", "rich", "scikit-image", + "scyjava", "tqdm", "typer" ] diff --git a/src/faim_wako_searchfirst/segment.py b/src/faim_wako_searchfirst/segment.py index 09c4bbe..8fe15f0 100644 --- a/src/faim_wako_searchfirst/segment.py +++ b/src/faim_wako_searchfirst/segment.py @@ -72,3 +72,33 @@ def cellpose( **kwargs, ) return mask + + +def weka_classifier( + img, + classifier_path: Path, + logger=logging, +): + """Apply a classifier.model file. + + Use Trainable Weka Segmentation (via pyimagej) + to apply a trained classifier to the input image. + """ + import imagej + from scyjava import jimport + + logger.info("Initializing ImageJ...") + ij = imagej.init("sc.fiji:fiji:2.15.0") + + WekaSegmentation = jimport("trainableSegmentation.WekaSegmentation") + + logger.info("Starting Trainable Segmentation plugin...") + segmenter = WekaSegmentation(ij.py.to_imageplus(img)) + logger.info(f"Loading classifier from {classifier_path}.") + segmenter.loadClassifier(str(classifier_path)) + segmenter.applyClassifier(False) + # get result imp + # convert imp from_java + result = ij.py.from_java(segmenter.getClassifiedImage()) + ij.dispose() + return label(result, background=1).astype(np.uint16) diff --git a/tests/resources/classifier.model b/tests/resources/classifier.model new file mode 100644 index 0000000..8662e54 Binary files /dev/null and b/tests/resources/classifier.model differ diff --git a/tests/test_weka_classifier.py b/tests/test_weka_classifier.py new file mode 100644 index 0000000..d391602 --- /dev/null +++ b/tests/test_weka_classifier.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: 2024 Friedrich Miescher Institute for Biomedical Research (FMI), Basel (Switzerland) +# +# SPDX-License-Identifier: MIT +"""Test segment.weka_classifier functionality.""" +from pathlib import Path + +import numpy as np +import pytest +from faim_wako_searchfirst.segment import weka_classifier +from skimage.io import imread + +SAMPLE_IMAGE_NAME = "1649_1109_0003_Amp5-1_B_20070424_D11_w2_E2E74170-1089-4BE7-B8A3-6EC6852579B7.png" + + +@pytest.fixture +def _classifier_path(): + return Path("tests/resources/classifier.model") + + +@pytest.fixture +def _sample_image(): + return imread(Path("tests/resources") / SAMPLE_IMAGE_NAME) + + +def test_weka_classifier(_sample_image, _classifier_path): + """Directly call weka_classifier.""" + result = weka_classifier(_sample_image, _classifier_path) + assert result.shape == (442, 442) + assert np.count_nonzero(result == 1) == 3 + assert np.count_nonzero(result == 2) == 8 + assert np.count_nonzero(result == 3) == 831