From 48de078c6e49e5fb3388b42b578e8f7a28c745c4 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 13 Feb 2024 07:00:42 -0500 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20Add=20predict=20functionali?= =?UTF-8?q?ty=20to=20dacapo=20CLI=20module?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/__init__.py | 1 + dacapo/apply.py | 2 +- dacapo/blockwise/predict_worker.py | 32 +++---- dacapo/cli.py | 142 +++++++++++++++++++++++++++-- dacapo/compute_context/bsub.py | 2 +- 5 files changed, 155 insertions(+), 24 deletions(-) diff --git a/dacapo/__init__.py b/dacapo/__init__.py index db3a662ed..45ce3a835 100644 --- a/dacapo/__init__.py +++ b/dacapo/__init__.py @@ -3,3 +3,4 @@ from .apply import apply # noqa from .train import train # noqa from .validate import validate # noqa +from .predict import predict # noqa diff --git a/dacapo/apply.py b/dacapo/apply.py index ecaedb62e..696832cd6 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -35,7 +35,7 @@ def apply( parameters: Optional[PostProcessorParameters or str] = None, roi: Optional[Roi or str] = None, num_cpu_workers: int = 30, - output_dtype: Optional[np.dtype or str] = np.uint8, # type: ignore + output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore compute_context: ComputeContext = LocalTorch(), overwrite: bool = True, file_format: str = "zarr", diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index c8e62da2d..321c54d00 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -237,22 +237,22 @@ def spawn_worker( """ # Make the command for the worker to run command = [ # TODO - "dacapo", - "predict", - "--name", - model.name, - "--criterion", - "best", - "--channels", - "0", - "--out_container", - prediction_array_identifier.container, - "--out_dataset", - prediction_array_identifier.dataset, - "--in_container", - raw_array.container, - "--in_dataset", - raw_array.dataset, + # "dacapo", + # "predict", + # "--name", + # model.name, + # "--criterion", + # "best", + # "--channels", + # "0", + # "--out_container", + # prediction_array_identifier.container, + # "--out_dataset", + # prediction_array_identifier.dataset, + # "--in_container", + # raw_array.container, + # "--in_dataset", + # raw_array.dataset, ] def run_worker(): diff --git a/dacapo/cli.py b/dacapo/cli.py index f97906508..d8b0f3074 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -1,8 +1,22 @@ +from pathlib import Path from typing import Optional +import numpy as np + import dacapo import click import logging +from daisy import Roi, Coordinate +from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray +from dacapo.experiments.datasplits.datasets.dataset import Dataset +from dacapo.experiments.run import Run +from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( + PostProcessorParameters, +) +from dacapo.store.array_store import LocalArrayIdentifier +from dacapo.store.create_store import create_config_store, create_weights_store +from dacapo import compute_context +from dacapo.compute_context import ComputeContext, LocalTorch @click.group() @@ -65,19 +79,26 @@ def validate(run_name, iteration): ) @click.option("-w", "--num_cpu_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option("-ow", "--overwrite", is_flag=True) +@click.option("-cc", "--compute_context", type=str, default="LocalTorch") def apply( run_name: str, - input_container: str, + input_container: Path or str, input_dataset: str, - output_path: str, - validation_dataset: Optional[str] = None, + output_path: Path or str, + validation_dataset: Optional[Dataset or str] = None, criterion: Optional[str] = "voi", iteration: Optional[int] = None, - parameters: Optional[str] = None, - roi: Optional[str] = None, + parameters: Optional[PostProcessorParameters or str] = None, + roi: Optional[Roi or str] = None, num_cpu_workers: int = 30, - output_dtype: Optional[str] = "uint8", + output_dtype: Optional[np.dtype | str] = "uint8", + overwrite: bool = True, + compute_context: Optional[ComputeContext | str] = LocalTorch(), ): + if isinstance(compute_context, str): + compute_context = getattr(compute_context, compute_context)() + dacapo.apply( run_name, input_container, @@ -90,4 +111,113 @@ def apply( roi, num_cpu_workers, output_dtype, + overwrite=overwrite, + compute_context=compute_context, # type: ignore + ) + + +@cli.command() +@click.option( + "-r", "--run-name", required=True, type=str, help="The name of the run to apply." +) +@click.option( + "-i", + "--iteration", + required=True, + type=int, + help="The training iteration of the model to use for prediction.", +) +@click.option( + "-ic", + "--input_container", + required=True, + type=click.Path(exists=True, file_okay=False), +) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) +@click.option( + "-roi", + "--roi", + type=str, + required=False, + help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]", +) +@click.option("-w", "--num_cpu_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option( + "-cc", + "--compute_context", + type=str, + default="LocalTorch", + help="The compute context to use for prediction. Must be the name of a subclass of ComputeContext.", +) +@click.option("-ow", "--overwrite", is_flag=True) +def predict( + run_name: str, + iteration: int, + input_container: Path or str, + input_dataset: str, + output_path: Path or str, + roi: Optional[str | Roi] = None, + num_cpu_workers: int = 30, + output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore + compute_context: Optional[ComputeContext | str] = LocalTorch(), + overwrite: bool = True, +): + # retrieving run + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + # create weights store + weights_store = create_weights_store() + + # load weights + weights_store.retrieve_weights(run_name, iteration) + + # get arrays + input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + input_array = ZarrArray.open_from_array_identifier(input_array_identifier) + output_container = Path( + output_path, + "".join(Path(input_container).name.split(".")[:-1]) + f".zarr", + ) # TODO: zarr hardcoded + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}" + ) + + if isinstance(roi, str): + start, end = zip( + *[ + tuple(int(coord) for coord in axis.split(":")) + for axis in roi.strip("[]").split(",") + ] + ) + roi = Roi( + Coordinate(start), + Coordinate(end) - Coordinate(start), + ) + + if roi is None: + roi = input_array.roi + else: + roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( + input_array.roi + ) + + if isinstance(output_dtype, str): + output_dtype = np.dtype(output_dtype) + + if isinstance(compute_context, str): + compute_context = getattr(compute_context, compute_context)() + + dacapo.predict( + run.model, + input_array, + prediction_array_identifier, + output_roi=roi, + num_cpu_workers=num_cpu_workers, + output_dtype=output_dtype, + compute_context=compute_context, # type: ignore + overwrite=overwrite, ) diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index 89422bf6c..a3b46f7cc 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -7,7 +7,7 @@ @attr.s -class Bsub(ComputeContext): +class Bsub(ComputeContext): # TODO: Load defaults from dacapo.yaml queue: str = attr.ib(default="local", metadata={"help_text": "The queue to run on"}) num_gpus: int = attr.ib( default=1,