diff --git a/dacapo/experiments/architectures/__init__.py b/dacapo/experiments/architectures/__init__.py
index 6125893c1..f21fe05f1 100644
--- a/dacapo/experiments/architectures/__init__.py
+++ b/dacapo/experiments/architectures/__init__.py
@@ -5,3 +5,4 @@
     DummyArchitecture,
 )  # noqa
 from .cnnectome_unet_config import CNNectomeUNetConfig, CNNectomeUNet  # noqa
+from .cellpose_unet_config import CellposUNetConfig, CellposeUnet  # noqa
diff --git a/dacapo/experiments/architectures/cellpose_unet.py b/dacapo/experiments/architectures/cellpose_unet.py
new file mode 100644
index 000000000..cbdf707ad
--- /dev/null
+++ b/dacapo/experiments/architectures/cellpose_unet.py
@@ -0,0 +1,75 @@
+from cellpose.resnet_torch import CPnet
+from .architecture import Architecture
+from funlib.geometry import Coordinate
+
+
+# example
+# nout = 4
+# sz = 3
+# self.net = CPnet(
+#     nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0
+# )
+# Currently the input channels are embedded in nbase, but they should be passed as a separate parameter: nbase = [in_chan, 32, 64, 128, 256]
+class CellposeUnet(Architecture):
+    def __init__(self, architecture_config):
+        super().__init__()
+        self._input_shape = Coordinate(architecture_config.input_shape)
+        self._nbase = architecture_config.nbase
+        self._sz = self._input_shape.dims
+        self._eval_shape_increase = Coordinate((0,) * self._sz)
+        self._nout = architecture_config.nout
+        print("conv_3D:", architecture_config.conv_3D)
+        self.unet = CPnet(
+            architecture_config.nbase,
+            architecture_config.nout,
+            self._sz,
+            architecture_config.mkldnn,
+            architecture_config.conv_3D,
+            architecture_config.max_pool,
+            architecture_config.diam_mean,
+        )
+        print(self.unet)
+
+    def forward(self, data):
+        """
+        Forward pass of the CPnet model.
+
+        Args:
+            data (torch.Tensor): Input data.
+
+        Returns:
+            torch.Tensor: The output tensor of the upsampling path.
+        """
+        if self.unet.mkldnn:
+            data = data.to_mkldnn()
+        T0 = self.unet.downsample(data)
+        if self.unet.mkldnn:
+            style = self.unet.make_style(T0[-1].to_dense())
+        else:
+            style = self.unet.make_style(T0[-1])
+        # style0 = style
+        if not self.unet.style_on:
+            style = style * 0
+        T1 = self.unet.upsample(style, T0, self.unet.mkldnn)
+        # head layer
+        # T1 = self.unet.output(T1)
+        if self.unet.mkldnn:
+            T0 = [t0.to_dense() for t0 in T0]
+            T1 = T1.to_dense()
+        return T1
+
+    @property
+    def input_shape(self):
+        return self._input_shape
+
+    @property
+    def num_in_channels(self) -> int:
+        return self._nbase[0]
+
+    @property
+    def num_out_channels(self) -> int:
+        return self._nout
+
+    @property
+    def eval_shape_increase(self):
+        return self._eval_shape_increase
diff --git a/dacapo/experiments/architectures/cellpose_unet_config.py b/dacapo/experiments/architectures/cellpose_unet_config.py
new file mode 100644
index 000000000..63d71c83d
--- /dev/null
+++ b/dacapo/experiments/architectures/cellpose_unet_config.py
@@ -0,0 +1,41 @@
+import attr
+
+from .architecture_config import ArchitectureConfig
+from .cellpose_unet import CellposeUnet
+
+from funlib.geometry import Coordinate
+
+from typing import List, Optional
+
+
+@attr.s
+class CellposUNetConfig(ArchitectureConfig):
+    """This class configures the CellPose UNet based on
+    https://github.com/MouseLand/cellpose/blob/main/cellpose/resnet_torch.py
+    """
+
+    architecture_type = CellposeUnet
+
+    input_shape: Coordinate = attr.ib(
+        metadata={
+            "help_text": "The shape of the data passed into the network during training."
+        }
+    )
+    nbase: List[int] = attr.ib(
+        metadata={
+            "help_text": "List of integers representing the number of channels in each layer of the downsample path."
+        }
+    )
+    nout: int = attr.ib(metadata={"help_text": "Number of output channels."})
+    mkldnn: Optional[bool] = attr.ib(
+        default=False, metadata={"help_text": "Whether to use MKL-DNN acceleration."}
+    )
+    conv_3D: bool = attr.ib(
+        default=False, metadata={"help_text": "Whether to use 3D convolution."}
+    )
+    max_pool: Optional[bool] = attr.ib(
+        default=True, metadata={"help_text": "Whether to use max pooling."}
+    )
+    diam_mean: Optional[float] = attr.ib(
+        default=30.0, metadata={"help_text": "Mean diameter of the cells."}
+    )
diff --git a/pyproject.toml b/pyproject.toml
index 8b32fe890..9474f996f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,7 @@ dependencies = [
     "scipy",
     "upath",
     "boto3",
+    "cellpose",
 ]

 # extras
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 3ea282acc..ca5ef6365 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -1,11 +1,11 @@
 from .db import options
-from .architectures import dummy_architecture
+from .architectures import dummy_architecture, cellpose_architecture
 from .arrays import dummy_array, zarr_array, cellmap_array
 from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
 from .evaluators import binary_3_channel_evaluator
 from .losses import dummy_loss
 from .post_processors import argmax, threshold
 from .predictors import distance_predictor, onehot_predictor
-from .runs import dummy_run, distance_run, onehot_run
+from .runs import dummy_run, distance_run, onehot_run, cellpose_run
 from .tasks import dummy_task, distance_task, onehot_task
 from .trainers import dummy_trainer, gunpowder_trainer
diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py
index 6980c8f6b..0c67ae15d 100644
--- a/tests/fixtures/architectures.py
+++ b/tests/fixtures/architectures.py
@@ -1,4 +1,4 @@
-from dacapo.experiments.architectures import DummyArchitectureConfig
+from dacapo.experiments.architectures import DummyArchitectureConfig, CellposUNetConfig

 import pytest

@@ -8,3 +8,15 @@ def dummy_architecture():
     yield DummyArchitectureConfig(
         name="dummy_architecture", num_in_channels=1, num_out_channels=12
     )
+
+
+@pytest.fixture
+def cellpose_architecture():
+    yield CellposUNetConfig(
+        name="cellpose_architecture",
+        input_shape=(216, 216, 216),
+        nbase=[1, 12, 24, 48, 96],
+        nout=12,
+        conv_3D=True,
+        # nbase=[1, 32, 64, 128, 256], nout=32, conv_3D=True
+    )
diff --git a/tests/fixtures/runs.py b/tests/fixtures/runs.py
index 99c4d3269..16b00d746 100644
--- a/tests/fixtures/runs.py
+++ b/tests/fixtures/runs.py
@@ -55,3 +55,21 @@ def onehot_run(
         repetition=0,
         num_iterations=100,
     )
+
+
+@pytest.fixture()
+def cellpose_run(
+    dummy_datasplit,
+    cellpose_architecture,
+    dummy_task,
+    dummy_trainer,
+):
+    yield RunConfig(
+        name="cellpose_run",
+        task_config=dummy_task,
+        architecture_config=cellpose_architecture,
+        trainer_config=dummy_trainer,
+        datasplit_config=dummy_datasplit,
+        repetition=0,
+        num_iterations=100,
+    )
diff --git a/tests/operations/test_cellpose.py b/tests/operations/test_cellpose.py
new file mode 100644
index 000000000..e55cc3321
--- /dev/null
+++ b/tests/operations/test_cellpose.py
@@ -0,0 +1,61 @@
+import numpy as np
+from dacapo.store.create_store import create_stats_store
+from ..fixtures import *
+
+from dacapo.experiments import Run
+from dacapo.store.create_store import create_config_store, create_weights_store
+from dacapo.train import train_run
+from pytest_lazy_fixtures import lf
+import pytest
+
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+
+# skip the test for the Apple Paravirtual device
+# that does not support Metal 2.0
+@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
+@pytest.mark.parametrize(
+    "run_config",
+    [
+        lf("cellpose_run"),
+    ],
+)
+def test_train(
+    run_config,
+):
+    print("Test train")
+    # create a store
+
+    store = create_config_store()
+    stats_store = create_stats_store()
+    weights_store = create_weights_store()
+
+    # store the configs
+
+    store.store_run_config(run_config)
+    run = Run(run_config)
+    print("Run created")
+    print(run.model)
+
+    # # -------------------------------------
+
+    # # train
+
+    # weights_store.store_weights(run, 0)
+    # print("Weights stored")
+    # train_run(run)
+
+    # init_weights = weights_store.retrieve_weights(run.name, 0)
+    # final_weights = weights_store.retrieve_weights(run.name, run.train_until)
+
+    # for name, weight in init_weights.model.items():
+    #     weight_diff = (weight - final_weights.model[name]).sum()
+    #     assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff
+
+    # # assert train_stats and validation_scores are available
+
+    # training_stats = stats_store.retrieve_training_stats(run_config.name)
+
+    # assert training_stats.trained_until() == run_config.num_iterations
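
For reference, a minimal usage sketch of the new config class (not part of the diff). It reuses the parameter values from the cellpose_architecture test fixture; the direct config.architecture_type(config) call and the variable names are illustrative assumptions that mirror how the diff wires the config (architecture_type = CellposeUnet, and CellposeUnet.__init__ takes the config), not an established DaCapo entry point.

    from dacapo.experiments.architectures import CellposUNetConfig

    # Values copied from the cellpose_architecture fixture above.
    config = CellposUNetConfig(
        name="cellpose_architecture",
        input_shape=(216, 216, 216),
        nbase=[1, 12, 24, 48, 96],  # nbase[0] is the number of input channels
        nout=12,
        conv_3D=True,
    )

    # architecture_type points at CellposeUnet, which wraps cellpose's CPnet.
    model = config.architecture_type(config)
    print(model.num_in_channels, model.num_out_channels)  # 1, 12
    print(model)

This only constructs and inspects the model, matching what test_train currently exercises; the training portion of the test remains commented out in the diff.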