# --- dacapo/experiments/architectures/cellpose_unet.py ---
#
# Example of constructing the underlying cellpose network directly:
#
#     nbase = [in_chan, 32, 64, 128, 256]
#     nout = 4
#     sz = 3
#     net = CPnet(
#         nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0
#     )
#
# Currently the number of input channels is embedded as the first entry of
# ``nbase``; it should eventually be passed as a separate parameter.
class CellposeUnet(Architecture):
    """DaCapo architecture wrapping cellpose's ``CPnet``.

    The network is built from the ``nbase``, ``nout``, ``mkldnn``,
    ``conv_3D``, ``max_pool`` and ``diam_mean`` fields of the given
    architecture config. The forward pass runs the downsample, style and
    upsample stages of ``CPnet``; the ``CPnet`` output head is not applied.
    """

    def __init__(self, architecture_config):
        # Local import so this block does not depend on a module-level import.
        import logging

        log = logging.getLogger(__name__)

        super().__init__()
        self._input_shape = Coordinate(architecture_config.input_shape)
        self._nbase = architecture_config.nbase
        # NOTE(review): this value (number of spatial dimensions) is passed to
        # CPnet in the position the example above calls ``sz = 3``. For 3D
        # input the two coincide; confirm this is intended for other ranks.
        self._sz = self._input_shape.dims
        self._eval_shape_increase = Coordinate((0,) * self._sz)
        self._nout = architecture_config.nout
        log.debug("conv_3D: %s", architecture_config.conv_3D)
        self.unet = CPnet(
            architecture_config.nbase,
            architecture_config.nout,
            self._sz,
            architecture_config.mkldnn,
            architecture_config.conv_3D,
            architecture_config.max_pool,
            architecture_config.diam_mean,
        )
        log.debug("%s", self.unet)

    def forward(self, data):
        """Run the CPnet encoder/decoder without its output head.

        Args:
            data (torch.Tensor): Input data.

        Returns:
            torch.Tensor: The result of ``CPnet.upsample`` (converted back to
            a dense tensor when MKL-DNN is enabled).
        """
        use_mkldnn = self.unet.mkldnn
        if use_mkldnn:
            data = data.to_mkldnn()
        T0 = self.unet.downsample(data)
        # The style vector is computed from the deepest downsampled feature map.
        if use_mkldnn:
            style = self.unet.make_style(T0[-1].to_dense())
        else:
            style = self.unet.make_style(T0[-1])
        if not self.unet.style_on:
            # Style conditioning disabled: zero it out but keep the shape.
            style = style * 0
        T1 = self.unet.upsample(style, T0, use_mkldnn)
        # The head layer (``self.unet.output``) is not applied here.
        if use_mkldnn:
            T1 = T1.to_dense()
        return T1

    @property
    def input_shape(self):
        """Input shape taken from the architecture config, as a Coordinate."""
        return self._input_shape

    @property
    def num_in_channels(self) -> int:
        """Number of input channels, i.e. the first entry of ``nbase``."""
        return self._nbase[0]

    @property
    def num_out_channels(self) -> int:
        """Number of output channels (``nout`` from the config)."""
        return self._nout

    @property
    def eval_shape_increase(self):
        """Shape increase used at evaluation time (all zeros)."""
        return self._eval_shape_increase


# --- dacapo/experiments/architectures/cellpose_unet_config.py ---
@attr.s
class CellposUNetConfig(ArchitectureConfig):
    """Configuration for the Cellpose-based U-Net architecture.

    Based on
    https://github.com/MouseLand/cellpose/blob/main/cellpose/resnet_torch.py
    """

    architecture_type = CellposeUnet

    input_shape: Coordinate = attr.ib(
        metadata={
            "help_text": "The shape of the data passed into the network during training."
        }
    )
    nbase: List[int] = attr.ib(
        metadata={
            "help_text": "List of integers representing the number of channels in each layer of the downsample path."
        }
    )
    nout: int = attr.ib(metadata={"help_text": "Number of output channels."})
    mkldnn: Optional[bool] = attr.ib(
        default=False, metadata={"help_text": "Whether to use MKL-DNN acceleration."}
    )
    conv_3D: bool = attr.ib(
        default=False, metadata={"help_text": "Whether to use 3D convolution."}
    )
    max_pool: Optional[bool] = attr.ib(
        default=True, metadata={"help_text": "Whether to use max pooling."}
    )
    diam_mean: Optional[float] = attr.ib(
        default=30.0, metadata={"help_text": "Mean diameter of the cells."}
    )


# Correctly spelled alias; the original (misspelled) name is kept so that
# existing imports continue to work.
CellposeUNetConfig = CellposUNetConfig


# --- dacapo/experiments/tasks/cellpose_task.py ---
class CellposeTask(Task):
    """Task bundling the Cellpose predictor and loss.

    Uses ``CellposePredictor`` and ``CellposeLoss`` together with the
    threshold post-processor and the binary segmentation evaluator.
    """

    def __init__(self, task_config):
        self.predictor = CellposePredictor(
            channels=task_config.channels,
            scale_factor=task_config.scale_factor,
            mask_distances=task_config.mask_distances,
            clipmin=task_config.clipmin,
            clipmax=task_config.clipmax,
        )
        self.loss = CellposeLoss()
        self.post_processor = ThresholdPostProcessor()
        self.evaluator = BinarySegmentationEvaluator(
            clip_distance=task_config.clip_distance,
            tol_distance=task_config.tol_distance,
            channels=task_config.channels,
        )
# --- dacapo/experiments/tasks/losses/cellpose_loss.py ---
# TODO check support weights


class CellposeLoss(Loss):
    """Loss combining a flow regression term and a cell-probability term."""

    def compute(self, prediction, target, weights=None):
        """Return the combined loss between ``prediction`` and ``target``.

        All but the last prediction channel are compared (mean squared error,
        halved) against target channels 1.. scaled by 5. The last prediction
        channel is compared (binary cross entropy on logits) against target
        channel 0 thresholded at 0.5. ``weights`` is currently not used.
        """
        predicted_flows = prediction[:, :-1]
        predicted_logits = prediction[:, -1]

        scaled_flows = 5.0 * target[:, 1:]
        foreground = (target[:, 0] > 0.5).float()

        flow_term = nn.functional.mse_loss(
            predicted_flows, scaled_flows, reduction="mean"
        )
        flow_term = flow_term / 2.0
        prob_term = nn.functional.binary_cross_entropy_with_logits(
            predicted_logits, foreground, reduction="mean"
        )
        return flow_term + prob_term


# --- dacapo/experiments/tasks/post_processors/cellpose_post_processor.py ---
# https://github.com/MouseLand/cellpose/blob/54b14fe567d885db293280b9b8d68dc50703d219/cellpose/models.py#L608


class CellposePostProcessor(PostProcessor):
    """Post-processor turning a stored prediction into masks via cellpose."""

    def __init__(self, detection_threshold: float):
        self.detection_threshold = detection_threshold

    def enumerate_parameters(self) -> Iterable[CellposePostProcessorParameters]:
        """Yield the parameter sets to try: ``min_size`` 1..10 with ids 0..9."""
        yield from (
            CellposePostProcessorParameters(id=index, min_size=index + 1)
            for index in range(10)
        )

    def set_prediction(self, prediction_array_identifier: LocalArrayIdentifier):
        """Open the prediction array that ``process`` will read from."""
        opened = ZarrArray.open_from_identifier(prediction_array_identifier)
        self.prediction_array = opened

    def process(self, parameters, output_array_identifier):
        """Compute masks with ``compute_masks`` and write them to the output.

        All but the last prediction channel (divided by 5) and the last
        channel are handed to ``compute_masks``; the first element of its
        result is stored in the output dataset.
        """
        container = zarr.open(str(output_array_identifier.container), "a")
        flow_channels = self.prediction_array.data[:-1] / 5.0
        cellprob = self.prediction_array.data[-1]
        result = compute_masks(
            flow_channels,
            cellprob,
            niter=200,
            cellprob_threshold=self.detection_threshold,
            do_3D=True,
            min_size=parameters.min_size,
        )
        container[output_array_identifier.dataset] = result[0]


# --- dacapo/experiments/tasks/post_processors/cellpose_post_processor_parameters.py ---
# TODO
@attr.s(frozen=True)
class CellposePostProcessorParameters(PostProcessorParameters):
    """Immutable parameter set for ``CellposePostProcessor``."""

    # Passed to ``compute_masks`` as ``min_size``.
    min_size: int = attr.ib()
# --- dacapo/experiments/tasks/predictors/cellpose_predictor.py ---
# TODO currently CPnet has nout which is the head of the network, check how to
# change it in the predictor
class CellposePredictor(Predictor):
    """Predictor producing Cellpose flow targets from ground-truth labels.

    Targets are computed with cellpose's ``masks_to_flows_gpu_3d``. The model
    is the given architecture followed by an identity head. The names of the
    predicted channels are passed as a list of strings in ``channels``.

    The class also carries a signed-distance helper (``process``) that is not
    currently used by ``create_target``.
    """

    def __init__(
        self,
        channels: List[str],
        scale_factor: float,
        mask_distances: bool,
        clipmin: float = 0.05,
        clipmax: float = 0.95,
    ):
        self.channels = channels
        self.norm = "tanh"
        self.dt_scale_factor = scale_factor
        self.mask_distances = mask_distances

        self.max_distance = 1 * scale_factor
        self.epsilon = 5e-2
        self.threshold = 0.8
        self.clipmin = clipmin
        self.clipmax = clipmax

    @property
    def embedding_dims(self):
        """Number of predicted channels."""
        return len(self.channels)

    def create_model(self, architecture):
        """Wrap ``architecture`` in a ``Model`` with an identity head.

        The previous ``isinstance(architecture, CellposeUnet)`` check was
        removed: ``CellposeUnet`` is not imported in this module (NameError)
        and the branch only assigned an unused variable.
        """
        return Model(architecture, torch.nn.Identity())

    def create_target(self, gt):
        """Compute flow targets for ``gt`` and wrap them as a ``NumpyArray``."""
        # NOTE(review): ``gt`` itself is handed to cellpose here, whereas the
        # (unused) distance path works on ``gt.data`` — confirm which one
        # ``masks_to_flows_gpu_3d`` expects.
        flows, _ = masks_to_flows_gpu_3d(gt)
        return NumpyArray.from_np_array(
            flows,
            gt.roi,
            gt.voxel_size,
            gt.axes,
        )

    def create_weight(self, gt, target, mask, moving_class_counts=None):
        """Compute balanced weights and the updated moving class counts."""
        # balance weights independently for each channel
        weights, moving_class_counts = balance_weights(
            gt[target.roi],
            2,
            slab=tuple(1 if c == "c" else -1 for c in gt.axes),
            masks=[mask[target.roi]],
            moving_counts=moving_class_counts,
            clipmin=self.clipmin,
            clipmax=self.clipmax,
        )
        return (
            NumpyArray.from_np_array(
                weights,
                gt.roi,
                gt.voxel_size,
                gt.axes,
            ),
            moving_class_counts,
        )

    @property
    def output_array_type(self):
        return DistanceArray(self.embedding_dims)

    def process(
        self,
        labels: np.ndarray,
        voxel_size: Coordinate,
        normalize=None,
        normalize_args=None,
    ):
        """Compute per-channel signed distances to label boundaries.

        Args:
            labels: Array iterated channel by channel along its first axis.
            voxel_size: Physical size of a voxel, one entry per spatial axis.
            normalize: Optional normalization name (only ``"tanh"``).
            normalize_args: Scale passed to the normalization.

        Returns:
            A float32 array shaped like ``labels``; values where the channel
            is zero are negated.
        """
        all_distances = np.zeros(labels.shape, dtype=np.float32) - 1
        for ii, channel in enumerate(labels):
            boundaries = self.__find_boundaries(channel)

            # mark boundaries with 0 (not 1)
            boundaries = 1.0 - boundaries

            if not np.any(boundaries == 0):
                # No boundary in this channel: saturate at half the smallest
                # physical extent, negative if the channel is all background.
                max_distance = min(
                    dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size)
                )
                sign = -1.0 if np.sum(channel) == 0 else 1.0
                distances = (
                    sign * np.ones(channel.shape, dtype=np.float32) * max_distance
                )
            else:
                # get distances (voxel_size/2 because image is doubled)
                distances = distance_transform_edt(
                    boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
                )
                distances = distances.astype(np.float32)

                # restore original shape
                downsample = (slice(None, None, 2),) * len(voxel_size)
                distances = distances[downsample]

                # todo: inverted distance
                distances[channel == 0] = -distances[channel == 0]

            if normalize is not None:
                distances = self.__normalize(distances, normalize, normalize_args)

            all_distances[ii] = distances

        return all_distances

    def __find_boundaries(self, labels):
        """Return a boolean boundary map on a doubled grid (shape ``2n - 1``).

        This helper was missing although ``process`` calls it. A position
        between two neighbouring voxels is marked when their labels differ:

            labels: 1 1 1 1 0 0 2 2 2 2 3 3       n
            diff  :  0 0 0 1 0 1 0 0 0 1 0        n - 1
            bound.: 00000001000100000001000       2n - 1
        """
        dims = labels.ndim
        in_shape = labels.shape
        out_shape = tuple(2 * s - 1 for s in in_shape)

        boundaries = np.zeros(out_shape, dtype=bool)

        for d in range(dims):
            upper = [slice(None)] * dims
            upper[d] = slice(1, in_shape[d])

            lower = [slice(None)] * dims
            lower[d] = slice(0, in_shape[d] - 1)

            # Comparing directly avoids wrap-around on unsigned label types.
            differs = labels[tuple(upper)] != labels[tuple(lower)]

            # Even positions are voxels; odd positions along axis d lie
            # between two neighbouring voxels.
            target = [slice(None, None, 2)] * dims
            target[d] = slice(1, out_shape[d], 2)

            boundaries[tuple(target)] = differs

        return boundaries

    def __normalize(self, distances, norm, normalize_args):
        """Apply ``tanh(distances / scale)``; other norms raise ValueError."""
        if norm == "tanh":
            scale = normalize_args
            return np.tanh(distances / scale)
        raise ValueError("Only tanh is supported for normalization")

    def gt_region_for_roi(self, target_spec):
        """Return the ground-truth spec needed to compute ``target_spec``.

        When ``mask_distances`` is set, the ROI is grown by ``max_distance``
        on both sides and snapped to the voxel grid (shrinking).
        """
        gt_spec = target_spec.copy()
        if self.mask_distances:
            margin = Coordinate((self.max_distance,) * gt_spec.voxel_size.dims)
            gt_spec.roi = gt_spec.roi.grow(margin, margin).snap_to_grid(
                gt_spec.voxel_size, mode="shrink"
            )
        return gt_spec

    def padding(self, gt_voxel_size: Coordinate) -> Coordinate:
        """Padding of ``max_distance`` along every axis."""
        return Coordinate((self.max_distance,) * gt_voxel_size.dims)
b/tests/components/test_arrays.py index d62dcb973..d91863ad7 100644 --- a/tests/components/test_arrays.py +++ b/tests/components/test_arrays.py @@ -3,15 +3,15 @@ from dacapo.store.create_store import create_config_store import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf @pytest.mark.parametrize( "array_config", [ - lazy_fixture("cellmap_array"), - lazy_fixture("zarr_array"), - lazy_fixture("dummy_array"), + lf("cellmap_array"), + lf("zarr_array"), + lf("dummy_array"), ], ) def test_array_api(options, array_config): diff --git a/tests/components/test_gp_arraysource.py b/tests/components/test_gp_arraysource.py index 7ff626036..69fee515f 100644 --- a/tests/components/test_gp_arraysource.py +++ b/tests/components/test_gp_arraysource.py @@ -5,15 +5,15 @@ import gunpowder as gp import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf @pytest.mark.parametrize( "array_config", [ - lazy_fixture("cellmap_array"), - lazy_fixture("zarr_array"), - lazy_fixture("dummy_array"), + lf("cellmap_array"), + lf("zarr_array"), + lf("dummy_array"), ], ) def test_gp_dacapo_array_source(array_config): diff --git a/tests/components/test_trainers.py b/tests/components/test_trainers.py index 172a89b75..f3f9b07ac 100644 --- a/tests/components/test_trainers.py +++ b/tests/components/test_trainers.py @@ -3,14 +3,14 @@ from dacapo.store.create_store import create_config_store import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf @pytest.mark.parametrize( "trainer_config", [ - lazy_fixture("dummy_trainer"), - lazy_fixture("gunpowder_trainer"), + lf("dummy_trainer"), + lf("gunpowder_trainer"), ], ) def test_trainer( diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 3ea282acc..ca5ef6365 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,11 +1,11 @@ from .db import options -from .architectures import dummy_architecture +from 
# --- tests/fixtures/architectures.py ---
@pytest.fixture
def cellpose_architecture():
    """Small 3D Cellpose U-Net config used by the tests."""
    # A larger alternative: nbase=[1, 32, 64, 128, 256], nout=32, conv_3D=True
    config = CellposUNetConfig(
        name="cellpose_architecture",
        input_shape=(216, 216, 216),
        nbase=[1, 12, 24, 48, 96],
        nout=12,
        conv_3D=True,
    )
    yield config


# --- tests/fixtures/runs.py ---
@pytest.fixture()
def cellpose_run(
    dummy_datasplit,
    cellpose_architecture,
    dummy_task,
    dummy_trainer,
):
    """Run config pairing the Cellpose architecture with the dummy task,
    trainer and datasplit fixtures."""
    run_config = RunConfig(
        name="cellpose_run",
        task_config=dummy_task,
        architecture_config=cellpose_architecture,
        trainer_config=dummy_trainer,
        datasplit_config=dummy_datasplit,
        repetition=0,
        num_iterations=100,
    )
    yield run_config
# --- tests/operations/test_cellpose.py ---
# Silence the warning raised on the Apple Paravirtual device, which does not
# support Metal 2.0.
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
@pytest.mark.parametrize("run_config", [lf("cellpose_run")])
def test_train(run_config):
    """Store the run config and build the run; training is not exercised yet."""
    print("Test train")

    # Create the stores.
    config_store = create_config_store()
    stats_store = create_stats_store()
    weights_store = create_weights_store()

    # Persist the config, then instantiate the run (this builds the model).
    config_store.store_run_config(run_config)
    run = Run(run_config)
    print("Run created ")
    print(run.model)

    # TODO: the training part (store initial weights, call train_run, compare
    # initial and final weights, check the training stats) is still disabled.
weights_store.retrieve_weights(run.name, run.train_until) + + # for name, weight in init_weights.model.items(): + # weight_diff = (weight - final_weights.model[name]).sum() + # assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff + + # # assert train_stats and validation_scores are available + + # training_stats = stats_store.retrieve_training_stats(run_config.name) + + # assert training_stats.trained_until() == run_config.num_iterations diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py index cd8f6a6c1..4a4b53478 100644 --- a/tests/operations/test_predict.py +++ b/tests/operations/test_predict.py @@ -8,7 +8,7 @@ from dacapo import predict import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf import logging @@ -18,9 +18,9 @@ @pytest.mark.parametrize( "run_config", [ - # lazy_fixture("distance_run"), - lazy_fixture("dummy_run"), - # lazy_fixture("onehot_run"), + # lf("distance_run"), + lf("dummy_run"), + # lf("onehot_run"), ], ) def test_predict(options, run_config, zarr_array, tmp_path): diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index d36655ea7..a852101be 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -7,7 +7,7 @@ from dacapo.train import train_run import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf import logging @@ -20,9 +20,9 @@ @pytest.mark.parametrize( "run_config", [ - lazy_fixture("distance_run"), - lazy_fixture("dummy_run"), - lazy_fixture("onehot_run"), + lf("distance_run"), + lf("dummy_run"), + lf("onehot_run"), ], ) def test_train( diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index fa2cc6b9a..b293bcbf7 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -8,7 +8,7 @@ from dacapo import validate import pytest -from pytest_lazyfixture import lazy_fixture +from 
pytest_lazy_fixtures import lf import logging @@ -18,8 +18,8 @@ @pytest.mark.parametrize( "run_config", [ - lazy_fixture("distance_run"), - # lazy_fixture("onehot_run"), + lf("distance_run"), + # lf("onehot_run"), ], ) def test_validate(