diff --git a/dacapo/apply.py b/dacapo/apply.py index 696832cd6..a701d9272 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -192,7 +192,7 @@ def apply_run( input_array, prediction_array_identifier, output_roi=roi, - num_cpu_workers=num_cpu_workers, + num_workers=num_cpu_workers, output_dtype=output_dtype, compute_context=compute_context, overwrite=overwrite, diff --git a/dacapo/blockwise/__init__.py b/dacapo/blockwise/__init__.py index 4c7bc0c13..876db03d0 100644 --- a/dacapo/blockwise/__init__.py +++ b/dacapo/blockwise/__init__.py @@ -1 +1,2 @@ from .blockwise_task import DaCapoBlockwiseTask +from .scheduler import run_blockwise diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py index 01d1a0c15..a30eecec8 100644 --- a/dacapo/blockwise/blockwise_task.py +++ b/dacapo/blockwise/blockwise_task.py @@ -4,13 +4,14 @@ from typing import Callable, Optional from daisy import Task, Roi from dacapo.compute_context import ComputeContext, LocalTorch, Bsub +import dacapo.compute_context class DaCapoBlockwiseTask(Task): def __init__( self, worker_file: str or Path, - compute_context: ComputeContext, + compute_context: ComputeContext or str, total_roi: Roi, read_roi: Roi, write_roi: Roi, @@ -21,6 +22,9 @@ def __init__( *args, **kwargs ): + if isinstance(compute_context, str): + compute_context = getattr(dacapo.compute_context, compute_context)() + # Make the task_id unique timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") task_id = worker_file + timestamp diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 321c54d00..39e29e12e 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -1,14 +1,19 @@ +from pathlib import Path import subprocess from typing import Optional import dacapo +from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.experiments.model import Model +from dacapo.gp.dacapo_array_source import DaCapoArraySource from dacapo.store.array_store import LocalArrayIdentifier from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.experiments import Run from dacapo.compute_context import ComputeContext, LocalTorch, Bsub +import gunpowder as gp +import gunpowder.torch as gp_torch import daisy -from daisy import Roi +from daisy import Roi, Coordinate from funlib.persistence import open_ds, Array from skimage.transform import rescale # TODO @@ -38,190 +43,144 @@ def cli(log_level): @cli.command() -@click.option("-n", "--name", type=str) -@click.option("-c", "--criterion", type=str) -@click.option("-cs", "--channels", type=str) -@click.option("-oc", "--out_container", type=click.Path(exists=True, file_okay=False)) -@click.option("-od", "--out_dataset", type=str) -@click.option("-ic", "--in_container", type=click.Path(exists=True, file_okay=False)) -@click.option("-id", "--in_dataset", type=str) -@click.option("--min-raw", type=float, default=0) -@click.option("--max-raw", type=float, default=255) @click.option( - "-mc", - "--mask-container", - type=click.Path(file_okay=False), - multiple=True, - default=list(), + "-r", "--run-name", required=True, type=str, help="The name of the run to apply." ) @click.option( - "-md", - "--mask-dataset", - type=click.Path(file_okay=False), - multiple=True, - default=list(), + "-i", + "--iteration", + required=True, + type=int, + help="The training iteration of the model to use for prediction.", ) @click.option( - "--instance", - type=bool, - default=False, + "-ic", + "--input_container", + required=True, + type=click.Path(exists=True, file_okay=False), ) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option( + "-oc", "--output_container", required=True, type=click.Path(file_okay=False) +) +@click.option("-od", "--output_dataset", required=True, type=str) +@click.option("-d", "--device", type=str, default="cuda") def start_worker( - name, - criterion, - channels, - out_container, - out_dataset, - in_container, - in_dataset, - min_raw, - max_raw, - mask_container, - mask_dataset, - instance, + run_name: str, + iteration: int, + input_container: Path or str, + input_dataset: str, + output_container: Path or str, + output_dataset: str, + device: str = "cuda", ): - shift = min_raw - scale = max_raw - min_raw - parsed_channels = [channel.split(":") for channel in channels.split(",")] - - device = torch.device("cuda") - - client = daisy.Client() - + # retrieving run config_store = create_config_store() - weights_store = create_weights_store() - - run_config = config_store.retrieve_run_config(name) + run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - model = run.model - try: - weights_store._load_best(run, criterion) - except FileNotFoundError: - iteration = int(criterion) - weights = weights_store.retrieve_weights(run, iteration) - model.load_state_dict(weights.model) - - model = run.model.to(device) - - raw_dataset = open_ds(in_container, in_dataset) - mask_datasets = [open_ds(mc, md) for mc, md in zip(mask_container, mask_dataset)] - - voxel_size = raw_dataset.voxel_size - output_voxel_size = model.scale(voxel_size) + # create weights store + weights_store = create_weights_store() - if not instance: - out_datasets = [ - open_ds( - out_container, - f"{out_dataset}/{channel}", - mode="r+", + # load weights + weights_store.retrieve_weights(run_name, iteration) + + # get arrays + raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) + + output_array_identifier = LocalArrayIdentifier( + Path(output_container), output_dataset + ) + output_array = ZarrArray.open_from_array_identifier(output_array_identifier) + + # get the model's input and output size + model = run.model.eval() + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + logger.info( + "Predicting with input size %s, output size %s", input_size, output_size + ) + # create gunpowder keys + + raw = gp.ArrayKey("RAW") + prediction = gp.ArrayKey("PREDICTION") + + # assemble prediction pipeline + + # prepare data source + pipeline = DaCapoArraySource(raw_array, raw) + # raw: (c, d, h, w) + pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) + # raw: (c, d, h, w) + pipeline += gp.Unsqueeze([raw]) + # raw: (1, c, d, h, w) + + # predict + pipeline += gp_torch.Predict( + model=model, + inputs={"x": raw}, + outputs={0: prediction}, + array_specs={ + prediction: gp.ArraySpec( + voxel_size=output_voxel_size, + dtype=np.float32, # assumes network output is float32 ) - for _, channel in parsed_channels - ] - else: - num_channels = model.num_out_channels - assert len(parsed_channels) == 1 - indexes, channel = parsed_channels[0] - out_datasets = [ - open_ds(out_container, f"{out_dataset}/{channel}__{i}", mode="r+") - for i in range(0, num_channels, 3) - ] + }, + spawn_subprocess=False, + device=device, # type: ignore + ) + # raw: (1, c, d, h, w) + # prediction: (1, [c,] d, h, w) + + # prepare writing + pipeline += gp.Squeeze([raw, prediction]) + # raw: (c, d, h, w) + # prediction: (c, d, h, w) + + # convert to uint8 if necessary: + if output_array.dtype == np.uint8: + pipeline += gp.IntensityScaleShift( + prediction, scale=255.0, shift=0.0 + ) # assumes float32 is [0,1] + pipeline += gp.AsType(prediction, output_array.dtype) + + # wait for blocks to run pipeline + client = daisy.Client() while True: + print("getting block") with client.acquire_block() as block: if block is None: break - if len(mask_datasets) > 0: - mask_data = any( - [ - np.any( - mask_dataset.to_ndarray( - roi=block.read_roi.snap_to_grid( - mask_dataset.voxel_size - ), - fill_value=0, - ) - ) - for mask_dataset in mask_datasets - ] - ) - else: - mask_data = 1 - if not np.any(mask_data): - # avoid predicting if mask is empty - continue - - if not instance: - raw_input = ( - 2.0 - * ( - raw_dataset.to_ndarray( - roi=block.read_roi, fill_value=shift + scale - ).astype(np.float32) - - shift - ) - / scale - ) - 1.0 - else: - raw_input = ( - raw_dataset.to_ndarray( - roi=block.read_roi, fill_value=shift + scale - ).astype(np.float32) - - shift - ) / scale - raw_input = np.expand_dims(raw_input, (0, 1)) - write_roi = block.write_roi.intersect(out_datasets[0].roi) - - if out_datasets[0].to_ndarray(write_roi).any(): - # block has already been processed - continue - - with torch.no_grad(): - predictions = Array( - model.forward(torch.from_numpy(raw_input).float().to(device)) - .detach() - .cpu() - .numpy()[0], - block.write_roi, - output_voxel_size, - ) - - write_data = predictions.to_ndarray(write_roi).clip(-1, 1) - if not instance: - write_data = (write_data + 1) * 255.0 / 2.0 - for (i, _), out_dataset in zip(parsed_channels, out_datasets): - indexes = [] - if "-" in i: - indexes = [int(j) for j in i.split("-")] - else: - indexes = [int(i)] - if len(indexes) > 1: - out_dataset[write_roi] = np.stack( - [write_data[j] for j in indexes], axis=0 - ).astype(np.uint8) - else: - out_dataset[write_roi] = write_data[indexes[0]].astype( - np.uint8 - ) - else: - for i, out_dataset in zip(range(0, num_channels, 3), out_datasets): - out_dataset[write_roi] = write_data[i : i + 3].astype( - np.float32 - ) - - block.status = daisy.BlockStatus.SUCCESS + + ref_request = gp.BatchRequest() + ref_request[raw] = gp.ArraySpec( + roi=block.read_roi, voxel_size=input_voxel_size, dtype=raw_array.dtype + ) + ref_request[prediction] = gp.ArraySpec( + roi=block.write_roi, + voxel_size=output_voxel_size, + dtype=output_array.dtype, + ) + + with gp.build(pipeline): + batch = pipeline.request_batch(ref_request) + + # write to output array + output_array[block.write_roi] = batch.arrays[prediction].data def spawn_worker( - model: Model, - raw_array: Array, + run_name: str, + iteration: int, + raw_array_identifier: LocalArrayIdentifier, prediction_array_identifier: LocalArrayIdentifier, - num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), - output_roi: Optional[Roi] = None, - output_dtype: np.dtype = np.float32, # type: ignore - overwrite: bool = False, ): """Spawn a worker to predict on a given dataset. @@ -229,30 +188,27 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - num_cpu_workers (int, optional): The number of CPU workers to use. Defaults to 4. compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). - output_roi (Optional[Roi], optional): The ROI to predict on. Defaults to None. - output_dtype (np.dtype, optional): The dtype of the output. Defaults to np.float32. - overwrite (bool, optional): Whether to overwrite the prediction array. Defaults to False. """ # Make the command for the worker to run command = [ # TODO - # "dacapo", - # "predict", - # "--name", - # model.name, - # "--criterion", - # "best", - # "--channels", - # "0", - # "--out_container", - # prediction_array_identifier.container, - # "--out_dataset", - # prediction_array_identifier.dataset, - # "--in_container", - # raw_array.container, - # "--in_dataset", - # raw_array.dataset, + "python", + __file__, + "start-worker", + "--run-name", + run_name, + "--iteration", + iteration, + "--input_container", + raw_array_identifier.container, + "--input_dataset", + raw_array_identifier.dataset, + "--output_container", + prediction_array_identifier.container, + "--output_dataset", + prediction_array_identifier.dataset, + "--device", + str(compute_context.device), ] def run_worker(): diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index aaa051d6b..7ebebdf2a 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -6,9 +6,9 @@ from dacapo.blockwise import DaCapoBlockwiseTask -def blockwise( +def run_blockwise( worker_file: str or Path, - compute_context: ComputeContext, + compute_context: ComputeContext | str, total_roi: Roi, read_roi: Roi, write_roi: Roi, diff --git a/dacapo/cli.py b/dacapo/cli.py index d8b0f3074..bab2f6963 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -137,12 +137,12 @@ def apply( @click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) @click.option( "-roi", - "--roi", + "--output_roi", type=str, required=False, help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]", ) -@click.option("-w", "--num_cpu_workers", type=int, default=30) +@click.option("-w", "--num_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") @click.option( "-cc", @@ -158,66 +158,21 @@ def predict( input_container: Path or str, input_dataset: str, output_path: Path or str, - roi: Optional[str | Roi] = None, - num_cpu_workers: int = 30, - output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore - compute_context: Optional[ComputeContext | str] = LocalTorch(), + output_roi: Optional[str | Roi] = None, + num_workers: int = 30, + output_dtype: np.dtype | str = np.uint8, # type: ignore + compute_context: ComputeContext | str = LocalTorch(), overwrite: bool = True, ): - # retrieving run - config_store = create_config_store() - run_config = config_store.retrieve_run_config(run_name) - run = Run(run_config) - - # create weights store - weights_store = create_weights_store() - - # load weights - weights_store.retrieve_weights(run_name, iteration) - - # get arrays - input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) - input_array = ZarrArray.open_from_array_identifier(input_array_identifier) - output_container = Path( - output_path, - "".join(Path(input_container).name.split(".")[:-1]) + f".zarr", - ) # TODO: zarr hardcoded - prediction_array_identifier = LocalArrayIdentifier( - output_container, f"prediction_{run_name}_{iteration}" - ) - - if isinstance(roi, str): - start, end = zip( - *[ - tuple(int(coord) for coord in axis.split(":")) - for axis in roi.strip("[]").split(",") - ] - ) - roi = Roi( - Coordinate(start), - Coordinate(end) - Coordinate(start), - ) - - if roi is None: - roi = input_array.roi - else: - roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( - input_array.roi - ) - - if isinstance(output_dtype, str): - output_dtype = np.dtype(output_dtype) - - if isinstance(compute_context, str): - compute_context = getattr(compute_context, compute_context)() - dacapo.predict( - run.model, - input_array, - prediction_array_identifier, - output_roi=roi, - num_cpu_workers=num_cpu_workers, - output_dtype=output_dtype, - compute_context=compute_context, # type: ignore - overwrite=overwrite, + run_name, + iteration, + input_container, + input_dataset, + output_path, + output_roi, + num_workers, + output_dtype, + compute_context, + overwrite, ) diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index a3b46f7cc..af2befa80 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -27,7 +27,10 @@ class Bsub(ComputeContext): # TODO: Load defaults from dacapo.yaml @property def device(self): - return None + if self.num_gpus > 0: + return "cuda" + else: + return "cpu" def wrap_command(self, command): return ( diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index b0c38de00..19e2ad895 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -8,14 +8,6 @@ class ComputeContext(ABC): def device(self): pass - def train(self, run_name): - # A helper method to run train in some other context. - # This can be on a cluster, in a cloud, through bsub, - # etc. - # If training should be done locally, return False, - # else return True. - return False - def wrap_command(self, command): # A helper method to wrap a command in the context # specific command. diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 25f2c224e..dc24230d6 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -116,6 +116,7 @@ def create_from_array_identifier( dtype, write_size=None, name=None, + overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that @@ -145,6 +146,7 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, + delete=overwrite, ) zarr_dataset = zarr_container[array_identifier.dataset] zarr_dataset.attrs["offset"] = ( diff --git a/dacapo/predict.py b/dacapo/predict.py index a98ab1675..958a606f8 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,9 +1,15 @@ +from pathlib import Path + +import click +from dacapo.blockwise import run_blockwise +from dacapo.experiments import Run from dacapo.gp import DaCapoArraySource -from dacapo.experiments.model import Model -from dacapo.experiments.datasplits.datasets.arrays import Array +from dacapo.experiments import Model +from dacapo.store import create_config_store +from dacapo.store import create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.compute_context import LocalTorch, ComputeContext -from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray, Array from funlib.geometry import Coordinate, Roi import gunpowder as gp @@ -17,16 +23,94 @@ logger = logging.getLogger(__name__) -def predict( # TODO: MAKE THIS CLI ACCESSIBLE - model: Model, - raw_array: Array, - prediction_array_identifier: LocalArrayIdentifier, - num_cpu_workers: int = 4, - compute_context: ComputeContext = LocalTorch(), - output_roi: Optional[Roi] = None, - output_dtype: Optional[np.dtype] = np.uint8, # type: ignore - overwrite: bool = False, +@cli.command() +@click.option( + "-r", "--run-name", required=True, type=str, help="The name of the run to apply." +) +@click.option( + "-i", + "--iteration", + required=True, + type=int, + help="The training iteration of the model to use for prediction.", +) +@click.option( + "-ic", + "--input_container", + required=True, + type=click.Path(exists=True, file_okay=False), +) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) +@click.option( + "-roi", + "--output_roi", + type=str, + required=False, + help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]", +) +@click.option("-w", "--num_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option( + "-cc", + "--compute_context", + type=str, + default="LocalTorch", + help="The compute context to use for prediction. Must be the name of a subclass of ComputeContext.", +) +@click.option("-ow", "--overwrite", is_flag=True) +def predict( + run_name: str, + iteration: int, + input_container: Path or str, + input_dataset: str, + output_path: Path or str, + output_roi: Optional[str | Roi] = None, + num_workers: int = 30, + output_dtype: np.dtype | str = np.uint8, # type: ignore + compute_context: ComputeContext | str = LocalTorch(), + overwrite: bool = True, ): + # retrieving run + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + # get arrays + raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) + output_container = Path( + output_path, + "".join(Path(input_container).name.split(".")[:-1]) + f".zarr", + ) # TODO: zarr hardcoded + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}" + ) + + if isinstance(output_roi, str): + start, end = zip( + *[ + tuple(int(coord) for coord in axis.split(":")) + for axis in output_roi.strip("[]").split(",") + ] + ) + output_roi = Roi( + Coordinate(start), + Coordinate(end) - Coordinate(start), + ) + + if output_roi is None: + output_roi = raw_array.roi + else: + output_roi = output_roi.snap_to_grid( + raw_array.voxel_size, mode="grow" + ).intersect(raw_array.roi) + + if isinstance(output_dtype, str): + output_dtype = np.dtype(output_dtype) + + model = run.model.eval() + # get the model's input and output size input_voxel_size = Coordinate(raw_array.voxel_size) @@ -59,77 +143,26 @@ def predict( # TODO: MAKE THIS CLI ACCESSIBLE model.num_out_channels, output_voxel_size, output_dtype, + overwrite=overwrite, ) - # create gunpowder keys - - raw = gp.ArrayKey("RAW") - prediction = gp.ArrayKey("PREDICTION") - - # assemble prediction pipeline - - # prepare data source - pipeline = DaCapoArraySource(raw_array, raw) - # raw: (c, d, h, w) - pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) - # raw: (c, d, h, w) - pipeline += gp.Unsqueeze([raw]) - # raw: (1, c, d, h, w) - - gt_padding = (output_size - output_roi.shape) % output_size - prediction_roi = output_roi.grow(gt_padding) # TODO: are we sure this makes sense? - # TODO: Add cache node? - # predict - pipeline += gp_torch.Predict( - model=model, - inputs={"x": raw}, - outputs={0: prediction}, - array_specs={ - prediction: gp.ArraySpec( - roi=prediction_roi, - voxel_size=output_voxel_size, - dtype=np.float32, # assumes network output is float32 - ) - }, - spawn_subprocess=False, - device=str(compute_context.device), - ) - # raw: (1, c, d, h, w) - # prediction: (1, [c,] d, h, w) - - # prepare writing - pipeline += gp.Squeeze([raw, prediction]) - # raw: (c, d, h, w) - # prediction: (c, d, h, w) - - # convert to uint8 if necessary: - if output_dtype == np.uint8: - pipeline += gp.IntensityScaleShift( - prediction, scale=255.0, shift=0.0 - ) # assumes float32 is [0,1] - pipeline += gp.AsType(prediction, output_dtype) - - # write to zarr - pipeline += gp.ZarrWrite( - {prediction: prediction_array_identifier.dataset}, - str(prediction_array_identifier.container.parent), - prediction_array_identifier.container.name, - dataset_dtypes={prediction: output_dtype}, + # run blockwise prediction + run_blockwise( + worker_file=str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")), + compute_context=compute_context, + total_roi=output_roi, + read_roi=input_roi, + write_roi=output_roi, + num_workers=num_workers, + max_retries=2, # TODO: make this an option + timeout=None, # TODO: make this an option + ###### + run_name=run_name, + iteration=iteration, + raw_array_identifier=raw_array_identifier, + prediction_array_identifier=prediction_array_identifier, ) - # create reference batch request - ref_request = gp.BatchRequest() - ref_request.add(raw, input_size) - ref_request.add(prediction, output_size) - pipeline += gp.Scan( - ref_request - ) # TODO: This is a slow implementation for rendering - - # build pipeline and predict in complete output ROI - - with gp.build(pipeline): - pipeline.request_batch(gp.BatchRequest()) - container = zarr.open(str(prediction_array_identifier.container)) dataset = container[prediction_array_identifier.dataset] dataset.attrs["axes"] = ( # type: ignore