From 6820aa8aa3686525b5501d6836152b8380d4438d Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Tue, 13 Feb 2024 23:35:01 -0500
Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=92=A5=20Add=20blockwise=20predic?=
 =?UTF-8?q?tion,=20callable=20from=20cli.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Added the ability to call dacapo.predict from the command line. Prediction now
runs as a blockwise-parallelized operation, with the specific execution
strategy determined by the ComputeContext. This change also lays the
groundwork for the blockwise processing infrastructure within dacapo.

The function signature of dacapo.predict has changed significantly: apply.py
and any other functions or scripts that call it will no longer work and must
be converted to the new call structure (see the illustrative migration sketch
after the diff).
---
 dacapo/apply.py                               |   2 +-
 dacapo/blockwise/__init__.py                  |   1 +
 dacapo/blockwise/blockwise_task.py            |   6 +-
 dacapo/blockwise/predict_worker.py            | 320 ++++++++----------
 dacapo/blockwise/scheduler.py                 |   4 +-
 dacapo/cli.py                                 |  77 +----
 dacapo/compute_context/bsub.py                |   5 +-
 dacapo/compute_context/compute_context.py     |   8 -
 .../datasplits/datasets/arrays/zarr_array.py  |   2 +
 dacapo/predict.py                             | 191 ++++++-----
 10 files changed, 281 insertions(+), 335 deletions(-)

diff --git a/dacapo/apply.py b/dacapo/apply.py
index 696832cd6..a701d9272 100644
--- a/dacapo/apply.py
+++ b/dacapo/apply.py
@@ -192,7 +192,7 @@ def apply_run(
         input_array,
         prediction_array_identifier,
         output_roi=roi,
-        num_cpu_workers=num_cpu_workers,
+        num_workers=num_cpu_workers,
         output_dtype=output_dtype,
         compute_context=compute_context,
         overwrite=overwrite,
diff --git a/dacapo/blockwise/__init__.py b/dacapo/blockwise/__init__.py
index 4c7bc0c13..876db03d0 100644
--- a/dacapo/blockwise/__init__.py
+++ b/dacapo/blockwise/__init__.py
@@ -1 +1,2 @@
 from .blockwise_task import DaCapoBlockwiseTask
+from .scheduler import run_blockwise
diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py
index 01d1a0c15..a30eecec8 100644
--- a/dacapo/blockwise/blockwise_task.py
+++ b/dacapo/blockwise/blockwise_task.py
@@ -4,13 +4,14 @@ from typing import Callable, Optional
 from daisy import Task, Roi
 from dacapo.compute_context import ComputeContext, LocalTorch, Bsub
+import dacapo.compute_context


 class DaCapoBlockwiseTask(Task):
     def __init__(
         self,
         worker_file: str or Path,
-        compute_context: ComputeContext,
+        compute_context: ComputeContext or str,
         total_roi: Roi,
         read_roi: Roi,
         write_roi: Roi,
@@ -21,6 +22,9 @@ def __init__(
         *args,
         **kwargs
     ):
+        if isinstance(compute_context, str):
+            compute_context = getattr(dacapo.compute_context, compute_context)()
+
         # Make the task_id unique
         timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
         task_id = worker_file + timestamp
diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index 321c54d00..39e29e12e 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -1,14 +1,19 @@
+from pathlib import Path
 import subprocess
 from typing import Optional

 import dacapo
+from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
 from dacapo.experiments.model import Model
+from dacapo.gp.dacapo_array_source import DaCapoArraySource
 from dacapo.store.array_store import LocalArrayIdentifier
 from dacapo.store.create_store import create_config_store, create_weights_store
 from dacapo.experiments import Run
 from dacapo.compute_context import ComputeContext, LocalTorch, Bsub
+import gunpowder as gp
+import
gunpowder.torch as gp_torch import daisy -from daisy import Roi +from daisy import Roi, Coordinate from funlib.persistence import open_ds, Array from skimage.transform import rescale # TODO @@ -38,190 +43,144 @@ def cli(log_level): @cli.command() -@click.option("-n", "--name", type=str) -@click.option("-c", "--criterion", type=str) -@click.option("-cs", "--channels", type=str) -@click.option("-oc", "--out_container", type=click.Path(exists=True, file_okay=False)) -@click.option("-od", "--out_dataset", type=str) -@click.option("-ic", "--in_container", type=click.Path(exists=True, file_okay=False)) -@click.option("-id", "--in_dataset", type=str) -@click.option("--min-raw", type=float, default=0) -@click.option("--max-raw", type=float, default=255) @click.option( - "-mc", - "--mask-container", - type=click.Path(file_okay=False), - multiple=True, - default=list(), + "-r", "--run-name", required=True, type=str, help="The name of the run to apply." ) @click.option( - "-md", - "--mask-dataset", - type=click.Path(file_okay=False), - multiple=True, - default=list(), + "-i", + "--iteration", + required=True, + type=int, + help="The training iteration of the model to use for prediction.", ) @click.option( - "--instance", - type=bool, - default=False, + "-ic", + "--input_container", + required=True, + type=click.Path(exists=True, file_okay=False), ) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option( + "-oc", "--output_container", required=True, type=click.Path(file_okay=False) +) +@click.option("-od", "--output_dataset", required=True, type=str) +@click.option("-d", "--device", type=str, default="cuda") def start_worker( - name, - criterion, - channels, - out_container, - out_dataset, - in_container, - in_dataset, - min_raw, - max_raw, - mask_container, - mask_dataset, - instance, + run_name: str, + iteration: int, + input_container: Path or str, + input_dataset: str, + output_container: Path or str, + output_dataset: str, + device: str = "cuda", ): - shift = min_raw - scale = max_raw - min_raw - parsed_channels = [channel.split(":") for channel in channels.split(",")] - - device = torch.device("cuda") - - client = daisy.Client() - + # retrieving run config_store = create_config_store() - weights_store = create_weights_store() - - run_config = config_store.retrieve_run_config(name) + run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - model = run.model - try: - weights_store._load_best(run, criterion) - except FileNotFoundError: - iteration = int(criterion) - weights = weights_store.retrieve_weights(run, iteration) - model.load_state_dict(weights.model) - - model = run.model.to(device) - - raw_dataset = open_ds(in_container, in_dataset) - mask_datasets = [open_ds(mc, md) for mc, md in zip(mask_container, mask_dataset)] - - voxel_size = raw_dataset.voxel_size - output_voxel_size = model.scale(voxel_size) + # create weights store + weights_store = create_weights_store() - if not instance: - out_datasets = [ - open_ds( - out_container, - f"{out_dataset}/{channel}", - mode="r+", + # load weights + weights_store.retrieve_weights(run_name, iteration) + + # get arrays + raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) + + output_array_identifier = LocalArrayIdentifier( + Path(output_container), output_dataset + ) + output_array = ZarrArray.open_from_array_identifier(output_array_identifier) + + # get the model's input and output size + model = 
run.model.eval() + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + logger.info( + "Predicting with input size %s, output size %s", input_size, output_size + ) + # create gunpowder keys + + raw = gp.ArrayKey("RAW") + prediction = gp.ArrayKey("PREDICTION") + + # assemble prediction pipeline + + # prepare data source + pipeline = DaCapoArraySource(raw_array, raw) + # raw: (c, d, h, w) + pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) + # raw: (c, d, h, w) + pipeline += gp.Unsqueeze([raw]) + # raw: (1, c, d, h, w) + + # predict + pipeline += gp_torch.Predict( + model=model, + inputs={"x": raw}, + outputs={0: prediction}, + array_specs={ + prediction: gp.ArraySpec( + voxel_size=output_voxel_size, + dtype=np.float32, # assumes network output is float32 ) - for _, channel in parsed_channels - ] - else: - num_channels = model.num_out_channels - assert len(parsed_channels) == 1 - indexes, channel = parsed_channels[0] - out_datasets = [ - open_ds(out_container, f"{out_dataset}/{channel}__{i}", mode="r+") - for i in range(0, num_channels, 3) - ] + }, + spawn_subprocess=False, + device=device, # type: ignore + ) + # raw: (1, c, d, h, w) + # prediction: (1, [c,] d, h, w) + + # prepare writing + pipeline += gp.Squeeze([raw, prediction]) + # raw: (c, d, h, w) + # prediction: (c, d, h, w) + + # convert to uint8 if necessary: + if output_array.dtype == np.uint8: + pipeline += gp.IntensityScaleShift( + prediction, scale=255.0, shift=0.0 + ) # assumes float32 is [0,1] + pipeline += gp.AsType(prediction, output_array.dtype) + + # wait for blocks to run pipeline + client = daisy.Client() while True: + print("getting block") with client.acquire_block() as block: if block is None: break - if len(mask_datasets) > 0: - mask_data = any( - [ - np.any( - mask_dataset.to_ndarray( - roi=block.read_roi.snap_to_grid( - mask_dataset.voxel_size - ), - fill_value=0, - ) - ) - for mask_dataset in mask_datasets - ] - ) - else: - mask_data = 1 - if not np.any(mask_data): - # avoid predicting if mask is empty - continue - - if not instance: - raw_input = ( - 2.0 - * ( - raw_dataset.to_ndarray( - roi=block.read_roi, fill_value=shift + scale - ).astype(np.float32) - - shift - ) - / scale - ) - 1.0 - else: - raw_input = ( - raw_dataset.to_ndarray( - roi=block.read_roi, fill_value=shift + scale - ).astype(np.float32) - - shift - ) / scale - raw_input = np.expand_dims(raw_input, (0, 1)) - write_roi = block.write_roi.intersect(out_datasets[0].roi) - - if out_datasets[0].to_ndarray(write_roi).any(): - # block has already been processed - continue - - with torch.no_grad(): - predictions = Array( - model.forward(torch.from_numpy(raw_input).float().to(device)) - .detach() - .cpu() - .numpy()[0], - block.write_roi, - output_voxel_size, - ) - - write_data = predictions.to_ndarray(write_roi).clip(-1, 1) - if not instance: - write_data = (write_data + 1) * 255.0 / 2.0 - for (i, _), out_dataset in zip(parsed_channels, out_datasets): - indexes = [] - if "-" in i: - indexes = [int(j) for j in i.split("-")] - else: - indexes = [int(i)] - if len(indexes) > 1: - out_dataset[write_roi] = np.stack( - [write_data[j] for j in indexes], axis=0 - ).astype(np.uint8) - else: - out_dataset[write_roi] = write_data[indexes[0]].astype( - np.uint8 - ) - else: - for i, out_dataset in zip(range(0, 
num_channels, 3), out_datasets): - out_dataset[write_roi] = write_data[i : i + 3].astype( - np.float32 - ) - - block.status = daisy.BlockStatus.SUCCESS + + ref_request = gp.BatchRequest() + ref_request[raw] = gp.ArraySpec( + roi=block.read_roi, voxel_size=input_voxel_size, dtype=raw_array.dtype + ) + ref_request[prediction] = gp.ArraySpec( + roi=block.write_roi, + voxel_size=output_voxel_size, + dtype=output_array.dtype, + ) + + with gp.build(pipeline): + batch = pipeline.request_batch(ref_request) + + # write to output array + output_array[block.write_roi] = batch.arrays[prediction].data def spawn_worker( - model: Model, - raw_array: Array, + run_name: str, + iteration: int, + raw_array_identifier: LocalArrayIdentifier, prediction_array_identifier: LocalArrayIdentifier, - num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), - output_roi: Optional[Roi] = None, - output_dtype: np.dtype = np.float32, # type: ignore - overwrite: bool = False, ): """Spawn a worker to predict on a given dataset. @@ -229,30 +188,27 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - num_cpu_workers (int, optional): The number of CPU workers to use. Defaults to 4. compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). - output_roi (Optional[Roi], optional): The ROI to predict on. Defaults to None. - output_dtype (np.dtype, optional): The dtype of the output. Defaults to np.float32. - overwrite (bool, optional): Whether to overwrite the prediction array. Defaults to False. """ # Make the command for the worker to run command = [ # TODO - # "dacapo", - # "predict", - # "--name", - # model.name, - # "--criterion", - # "best", - # "--channels", - # "0", - # "--out_container", - # prediction_array_identifier.container, - # "--out_dataset", - # prediction_array_identifier.dataset, - # "--in_container", - # raw_array.container, - # "--in_dataset", - # raw_array.dataset, + "python", + __file__, + "start-worker", + "--run-name", + run_name, + "--iteration", + iteration, + "--input_container", + raw_array_identifier.container, + "--input_dataset", + raw_array_identifier.dataset, + "--output_container", + prediction_array_identifier.container, + "--output_dataset", + prediction_array_identifier.dataset, + "--device", + str(compute_context.device), ] def run_worker(): diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index aaa051d6b..7ebebdf2a 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -6,9 +6,9 @@ from dacapo.blockwise import DaCapoBlockwiseTask -def blockwise( +def run_blockwise( worker_file: str or Path, - compute_context: ComputeContext, + compute_context: ComputeContext | str, total_roi: Roi, read_roi: Roi, write_roi: Roi, diff --git a/dacapo/cli.py b/dacapo/cli.py index d8b0f3074..bab2f6963 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -137,12 +137,12 @@ def apply( @click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) @click.option( "-roi", - "--roi", + "--output_roi", type=str, required=False, help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... 
]", ) -@click.option("-w", "--num_cpu_workers", type=int, default=30) +@click.option("-w", "--num_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") @click.option( "-cc", @@ -158,66 +158,21 @@ def predict( input_container: Path or str, input_dataset: str, output_path: Path or str, - roi: Optional[str | Roi] = None, - num_cpu_workers: int = 30, - output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore - compute_context: Optional[ComputeContext | str] = LocalTorch(), + output_roi: Optional[str | Roi] = None, + num_workers: int = 30, + output_dtype: np.dtype | str = np.uint8, # type: ignore + compute_context: ComputeContext | str = LocalTorch(), overwrite: bool = True, ): - # retrieving run - config_store = create_config_store() - run_config = config_store.retrieve_run_config(run_name) - run = Run(run_config) - - # create weights store - weights_store = create_weights_store() - - # load weights - weights_store.retrieve_weights(run_name, iteration) - - # get arrays - input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) - input_array = ZarrArray.open_from_array_identifier(input_array_identifier) - output_container = Path( - output_path, - "".join(Path(input_container).name.split(".")[:-1]) + f".zarr", - ) # TODO: zarr hardcoded - prediction_array_identifier = LocalArrayIdentifier( - output_container, f"prediction_{run_name}_{iteration}" - ) - - if isinstance(roi, str): - start, end = zip( - *[ - tuple(int(coord) for coord in axis.split(":")) - for axis in roi.strip("[]").split(",") - ] - ) - roi = Roi( - Coordinate(start), - Coordinate(end) - Coordinate(start), - ) - - if roi is None: - roi = input_array.roi - else: - roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( - input_array.roi - ) - - if isinstance(output_dtype, str): - output_dtype = np.dtype(output_dtype) - - if isinstance(compute_context, str): - compute_context = getattr(compute_context, compute_context)() - dacapo.predict( - run.model, - input_array, - prediction_array_identifier, - output_roi=roi, - num_cpu_workers=num_cpu_workers, - output_dtype=output_dtype, - compute_context=compute_context, # type: ignore - overwrite=overwrite, + run_name, + iteration, + input_container, + input_dataset, + output_path, + output_roi, + num_workers, + output_dtype, + compute_context, + overwrite, ) diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index a3b46f7cc..af2befa80 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -27,7 +27,10 @@ class Bsub(ComputeContext): # TODO: Load defaults from dacapo.yaml @property def device(self): - return None + if self.num_gpus > 0: + return "cuda" + else: + return "cpu" def wrap_command(self, command): return ( diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index b0c38de00..19e2ad895 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -8,14 +8,6 @@ class ComputeContext(ABC): def device(self): pass - def train(self, run_name): - # A helper method to run train in some other context. - # This can be on a cluster, in a cloud, through bsub, - # etc. - # If training should be done locally, return False, - # else return True. - return False - def wrap_command(self, command): # A helper method to wrap a command in the context # specific command. 
diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 25f2c224e..dc24230d6 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -116,6 +116,7 @@ def create_from_array_identifier( dtype, write_size=None, name=None, + overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that @@ -145,6 +146,7 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, + delete=overwrite, ) zarr_dataset = zarr_container[array_identifier.dataset] zarr_dataset.attrs["offset"] = ( diff --git a/dacapo/predict.py b/dacapo/predict.py index a98ab1675..958a606f8 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,9 +1,15 @@ +from pathlib import Path + +import click +from dacapo.blockwise import run_blockwise +from dacapo.experiments import Run from dacapo.gp import DaCapoArraySource -from dacapo.experiments.model import Model -from dacapo.experiments.datasplits.datasets.arrays import Array +from dacapo.experiments import Model +from dacapo.store import create_config_store +from dacapo.store import create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.compute_context import LocalTorch, ComputeContext -from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray, Array from funlib.geometry import Coordinate, Roi import gunpowder as gp @@ -17,16 +23,94 @@ logger = logging.getLogger(__name__) -def predict( # TODO: MAKE THIS CLI ACCESSIBLE - model: Model, - raw_array: Array, - prediction_array_identifier: LocalArrayIdentifier, - num_cpu_workers: int = 4, - compute_context: ComputeContext = LocalTorch(), - output_roi: Optional[Roi] = None, - output_dtype: Optional[np.dtype] = np.uint8, # type: ignore - overwrite: bool = False, +@cli.command() +@click.option( + "-r", "--run-name", required=True, type=str, help="The name of the run to apply." +) +@click.option( + "-i", + "--iteration", + required=True, + type=int, + help="The training iteration of the model to use for prediction.", +) +@click.option( + "-ic", + "--input_container", + required=True, + type=click.Path(exists=True, file_okay=False), +) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) +@click.option( + "-roi", + "--output_roi", + type=str, + required=False, + help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]", +) +@click.option("-w", "--num_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option( + "-cc", + "--compute_context", + type=str, + default="LocalTorch", + help="The compute context to use for prediction. 
Must be the name of a subclass of ComputeContext.", +) +@click.option("-ow", "--overwrite", is_flag=True) +def predict( + run_name: str, + iteration: int, + input_container: Path or str, + input_dataset: str, + output_path: Path or str, + output_roi: Optional[str | Roi] = None, + num_workers: int = 30, + output_dtype: np.dtype | str = np.uint8, # type: ignore + compute_context: ComputeContext | str = LocalTorch(), + overwrite: bool = True, ): + # retrieving run + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + # get arrays + raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) + output_container = Path( + output_path, + "".join(Path(input_container).name.split(".")[:-1]) + f".zarr", + ) # TODO: zarr hardcoded + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}" + ) + + if isinstance(output_roi, str): + start, end = zip( + *[ + tuple(int(coord) for coord in axis.split(":")) + for axis in output_roi.strip("[]").split(",") + ] + ) + output_roi = Roi( + Coordinate(start), + Coordinate(end) - Coordinate(start), + ) + + if output_roi is None: + output_roi = raw_array.roi + else: + output_roi = output_roi.snap_to_grid( + raw_array.voxel_size, mode="grow" + ).intersect(raw_array.roi) + + if isinstance(output_dtype, str): + output_dtype = np.dtype(output_dtype) + + model = run.model.eval() + # get the model's input and output size input_voxel_size = Coordinate(raw_array.voxel_size) @@ -59,77 +143,26 @@ def predict( # TODO: MAKE THIS CLI ACCESSIBLE model.num_out_channels, output_voxel_size, output_dtype, + overwrite=overwrite, ) - # create gunpowder keys - - raw = gp.ArrayKey("RAW") - prediction = gp.ArrayKey("PREDICTION") - - # assemble prediction pipeline - - # prepare data source - pipeline = DaCapoArraySource(raw_array, raw) - # raw: (c, d, h, w) - pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) - # raw: (c, d, h, w) - pipeline += gp.Unsqueeze([raw]) - # raw: (1, c, d, h, w) - - gt_padding = (output_size - output_roi.shape) % output_size - prediction_roi = output_roi.grow(gt_padding) # TODO: are we sure this makes sense? - # TODO: Add cache node? 
- # predict - pipeline += gp_torch.Predict( - model=model, - inputs={"x": raw}, - outputs={0: prediction}, - array_specs={ - prediction: gp.ArraySpec( - roi=prediction_roi, - voxel_size=output_voxel_size, - dtype=np.float32, # assumes network output is float32 - ) - }, - spawn_subprocess=False, - device=str(compute_context.device), - ) - # raw: (1, c, d, h, w) - # prediction: (1, [c,] d, h, w) - - # prepare writing - pipeline += gp.Squeeze([raw, prediction]) - # raw: (c, d, h, w) - # prediction: (c, d, h, w) - - # convert to uint8 if necessary: - if output_dtype == np.uint8: - pipeline += gp.IntensityScaleShift( - prediction, scale=255.0, shift=0.0 - ) # assumes float32 is [0,1] - pipeline += gp.AsType(prediction, output_dtype) - - # write to zarr - pipeline += gp.ZarrWrite( - {prediction: prediction_array_identifier.dataset}, - str(prediction_array_identifier.container.parent), - prediction_array_identifier.container.name, - dataset_dtypes={prediction: output_dtype}, + # run blockwise prediction + run_blockwise( + worker_file=str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")), + compute_context=compute_context, + total_roi=output_roi, + read_roi=input_roi, + write_roi=output_roi, + num_workers=num_workers, + max_retries=2, # TODO: make this an option + timeout=None, # TODO: make this an option + ###### + run_name=run_name, + iteration=iteration, + raw_array_identifier=raw_array_identifier, + prediction_array_identifier=prediction_array_identifier, ) - # create reference batch request - ref_request = gp.BatchRequest() - ref_request.add(raw, input_size) - ref_request.add(prediction, output_size) - pipeline += gp.Scan( - ref_request - ) # TODO: This is a slow implementation for rendering - - # build pipeline and predict in complete output ROI - - with gp.build(pipeline): - pipeline.request_batch(gp.BatchRequest()) - container = zarr.open(str(prediction_array_identifier.container)) dataset = container[prediction_array_identifier.dataset] dataset.attrs["axes"] = ( # type: ignore
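
Illustrative migration sketch (referenced in the commit message): the run name, iteration, and paths below are placeholders, and the call mirrors the one dacapo/cli.py now makes; the commented-out lines show the old, no-longer-supported signature. This is a sketch of the new call structure, not a verified snippet.

import dacapo

# Old (pre-patch) call structure -- no longer works:
# dacapo.predict(run.model, input_array, prediction_array_identifier,
#                output_roi=roi, num_cpu_workers=4, output_dtype=np.uint8)

# New call structure: the run and iteration are referenced by name and number,
# the arrays by container path and dataset name; prediction is scheduled
# blockwise via dacapo.blockwise.run_blockwise under the hood.
dacapo.predict(
    "my_run",                          # run_name
    50000,                             # iteration
    "/path/to/input.zarr",             # input_container
    "volumes/raw",                     # input_dataset
    "/path/to/output_dir",             # output_path
    output_roi="[0:100,0:100,0:100]",  # optional; a Roi or "[lower:upper, ...]" string
    num_workers=30,
    output_dtype="uint8",
    compute_context="LocalTorch",      # or a ComputeContext instance
    overwrite=True,
)

# The same prediction from the command line, using the options defined in the patch:
#   dacapo predict -r my_run -i 50000 -ic /path/to/input.zarr -id volumes/raw \
#       -op /path/to/output_dir -w 30 -dt uint8 -cc LocalTorch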