diff --git a/dacapo/apply.py b/dacapo/apply.py index bfdb2c182..9e6006c9b 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -34,7 +34,7 @@ def apply( parameters: Optional[PostProcessorParameters | str] = None, roi: Optional[Roi | str] = None, num_workers: int = 12, - output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore + output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, file_format: str = "zarr", ): @@ -92,7 +92,7 @@ def apply( logger.info( "Finding best parameters for validation dataset %s", _validation_dataset ) - parameters = run.task.evaluator.get_overall_best_parameters( # TODO + parameters = run.task.evaluator.get_overall_best_parameters( _validation_dataset, criterion ) assert ( @@ -102,10 +102,10 @@ def apply( elif isinstance(parameters, str): try: post_processor_name = parameters.split("(")[0] - post_processor_kwargs = parameters.split("(")[1].strip(")").split(",") + _post_processor_kwargs = parameters.split("(")[1].strip(")").split(",") post_processor_kwargs = { key.strip(): value.strip() - for key, value in [arg.split("=") for arg in post_processor_kwargs] + for key, value in [arg.split("=") for arg in _post_processor_kwargs] } for key, value in post_processor_kwargs.items(): if value.isdigit(): @@ -132,12 +132,12 @@ def apply( ), "Parameters must be parsable to a PostProcessorParameters object." 
# make array identifiers for input, predictions and outputs - input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) + input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) input_array = ZarrArray.open_from_array_identifier(input_array_identifier) if roi is None: - roi = input_array.roi + _roi = input_array.roi else: - roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( + _roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( input_array.roi ) output_container = Path( @@ -164,7 +164,7 @@ def apply( input_array_identifier, prediction_array_identifier, output_array_identifier, - roi, + _roi, num_workers, output_dtype, overwrite, diff --git a/dacapo/blockwise/segment_worker.py b/dacapo/blockwise/segment_worker.py index 32c86cacb..ecea9dee0 100644 --- a/dacapo/blockwise/segment_worker.py +++ b/dacapo/blockwise/segment_worker.py @@ -84,7 +84,9 @@ def start_worker( segmentation = segment_function(input_array, block, **parameters) - assert segmentation.dtype == np.uint64 + assert ( + segmentation.dtype == np.uint64 + ), "Instance segmentations is expected to be uint64" id_bump = block.block_id[1] * num_voxels_in_block segmentation += id_bump @@ -139,8 +141,8 @@ def start_worker( ) unique_pairs = np.concatenate(unique_pairs) - zero_u = unique_pairs[:, 0] == 0 - zero_v = unique_pairs[:, 1] == 0 + zero_u = unique_pairs[:, 0] == 0 # type: ignore + zero_v = unique_pairs[:, 1] == 0 # type: ignore non_zero_filter = np.logical_not(np.logical_or(zero_u, zero_v)) edges = unique_pairs[non_zero_filter] diff --git a/dacapo/cli.py b/dacapo/cli.py index cda08e40f..f459dced0 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -6,11 +6,16 @@ import dacapo import click import logging -from daisy import Roi +from funlib.geometry import Roi, Coordinate +from funlib.persistence import open_ds from dacapo.experiments.datasplits.datasets.dataset import Dataset from 
dacapo.experiments.tasks.post_processors.post_processor_parameters import ( PostProcessorParameters, ) +from dacapo.blockwise import ( + run_blockwise as _run_blockwise, + segment_blockwise as _segment_blockwise, +) @click.group() @@ -88,7 +93,7 @@ def apply( parameters: Optional[PostProcessorParameters | str] = None, roi: Optional[Roi | str] = None, num_workers: int = 30, - output_dtype: Optional[np.dtype | str] = "uint8", + output_dtype: np.dtype | str = "uint8", overwrite: bool = True, ): dacapo.apply( @@ -158,3 +163,253 @@ def predict( output_dtype, overwrite, ) + + +@cli.command( + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + ) +) +@click.option( + "-ic", + "--input_container", + required=True, + type=click.Path(exists=True, file_okay=False), +) +@click.option("-id", "--input_dataset", required=True, type=str) +@click.option( + "-oc", "--output_container", required=True, type=click.Path(file_okay=False) +) +@click.option("-od", "--output_dataset", required=True, type=str) +@click.option( + "-w", "--worker_file", required=True, type=str, help="The path to the worker file." +) +@click.option( + "-tr", + "--total_roi", + required=True, + type=str, + help="The total roi to be processed. Format is [start:end, start:end, ... 
]",
+    default=None,
+)
+@click.option(
+    "-rr",
+    "--read_roi_size",
+    required=True,
+    type=str,
+    help="The size of the roi to be read for each block.",
+)
+@click.option(
+    "-wr",
+    "--write_roi_size",
+    required=True,
+    type=str,
+    help="The size of the roi to be written for each block.",
+)
+@click.option("-nw", "--num_workers", type=int, default=16)
+@click.option("-mr", "--max_retries", type=int, default=2)
+@click.option("-t", "--timeout", type=int, default=None)
+@click.pass_context
+def run_blockwise(
+    ctx,
+    input_container: Path | str,
+    input_dataset: str,
+    output_container: Path | str,
+    output_dataset: str,
+    worker_file: str | Path,
+    total_roi: str | None,
+    read_roi_size: str,
+    write_roi_size: str,
+    num_workers: int = 16,
+    max_retries: int = 2,
+    timeout=None,
+    *args,
+    **kwargs,
+):
+    # get arbitrary args and kwargs
+    kwargs = unpack_ctx(ctx)
+
+    if total_roi is not None:
+        # parse the string into a Roi
+        start, end = zip(
+            *[
+                tuple(int(coord) for coord in axis.split(":"))
+                for axis in total_roi.strip("[]").split(",")
+            ]
+        )
+        _total_roi = Roi(
+            Coordinate(start),
+            Coordinate(end) - Coordinate(start),
+        )
+    else:
+        input_ds = open_ds(str(input_container), input_dataset)
+        _total_roi = input_ds.roi
+
+    read_roi = Roi([0, 0, 0], [int(coord) for coord in read_roi_size.split(",")])
+    # Find difference between read and write roi
+    context = (np.array(write_roi_size) - np.array(read_roi_size)) // 2
+    write_roi = Roi(
+        context,
+        [int(coord) for coord in write_roi_size.split(",")],
+    )
+
+    _run_blockwise(
+        input_container=input_container,
+        input_dataset=input_dataset,
+        output_container=output_container,
+        output_dataset=output_dataset,
+        worker_file=worker_file,
+        total_roi=_total_roi,
+        read_roi=read_roi,
+        write_roi=write_roi,
+        num_workers=num_workers,
+        max_retries=max_retries,
+        timeout=timeout,
+        *args,
+        **kwargs,
+    )
+
+
+@cli.command(
+    context_settings=dict(
+        ignore_unknown_options=True,
+        allow_extra_args=True,
+    )
+)
+@click.option(
+    "-ic",
+    "--input_container",
+    required=True,
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option("-id", "--input_dataset", required=True, type=str)
+@click.option(
+    "-oc", "--output_container", required=True, type=click.Path(file_okay=False)
+)
+@click.option("-od", "--output_dataset", required=True, type=str)
+@click.option("-sf", "--segment_function_file", required=True, type=click.Path())
+@click.option(
+    "-c",
+    "--context",
+    type=str,
+    help="The context to be used, in the format of [x,y,z]. Defaults to the difference between the read and write rois.",
+    default=None,
+)
+@click.option(
+    "-tr",
+    "--total_roi",
+    type=str,
+    help="The total roi to be processed. Format is [start:end,start:end,...] Defaults to the roi of the input dataset. Do not use spaces in CLI argument.",
+    default=None,
+)
+@click.option(
+    "-rr",
+    "--read_roi_size",
+    required=True,
+    type=str,
+    help="The size of the roi to be read for each block, in the format of [x,y,z].",
+)
+@click.option(
+    "-wr",
+    "--write_roi_size",
+    required=True,
+    type=str,
+    help="The size of the roi to be written for each block, in the format of [x,y,z].",
+)
+@click.option("-nw", "--num_workers", type=int, default=16)
+@click.option("-mr", "--max_retries", type=int, default=2)
+@click.option("-t", "--timeout", type=int, default=None)
+@click.option("-tp", "--tmp_prefix", type=str, default="tmp")
+@click.pass_context
+def segment_blockwise(
+    ctx,
+    input_container: Path | str,
+    input_dataset: str,
+    output_container: Path | str,
+    output_dataset: str,
+    segment_function_file: Path | str,
+    context: str | None,
+    total_roi: str,
+    read_roi_size: str,
+    write_roi_size: str,
+    num_workers: int = 16,
+    max_retries: int = 2,
+    timeout=None,
+    tmp_prefix: str = "tmp",
+    *args,
+    **kwargs,
+):
+    # get arbitrary args and kwargs
+    kwargs = unpack_ctx(ctx)
+
+    if total_roi is not None:
+        # parse the string into a Roi
+        start, end = zip(
+            *[
+                tuple(int(coord) for coord in
axis.split(":"))
+                for axis in total_roi.strip("[]").split(",")
+            ]
+        )
+        _total_roi = Roi(
+            Coordinate(start),
+            Coordinate(end) - Coordinate(start),
+        )
+    else:
+        input_ds = open_ds(str(input_container), input_dataset)
+        _total_roi = input_ds.roi
+
+    read_roi = Roi([0, 0, 0], [int(coord) for coord in read_roi_size.split(",")])
+    # Find difference between read and write roi
+    _context = (np.array(write_roi_size) - np.array(read_roi_size)) // 2
+    write_roi = Roi(
+        _context,
+        [int(coord) for coord in write_roi_size.split(",")],
+    )
+
+    if context is None:
+        context = _context
+
+    _segment_blockwise(
+        input_container=input_container,
+        input_dataset=input_dataset,
+        output_container=output_container,
+        output_dataset=output_dataset,
+        segment_function_file=segment_function_file,
+        context=_context,
+        total_roi=_total_roi,
+        read_roi=read_roi,
+        write_roi=write_roi,
+        num_workers=num_workers,
+        max_retries=max_retries,
+        timeout=timeout,
+        tmp_prefix=tmp_prefix,
+        *args,
+        **kwargs,
+    )
+
+
+def unpack_ctx(ctx):
+    kwargs = {
+        ctx.args[i].lstrip("-"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)
+    }
+    for k, v in kwargs.items():
+        print(k, v)
+        if v.isnumeric():
+            if "."
in v: + kwargs[k] = float(v) + else: + kwargs[k] = int(v) + return kwargs + + +@cli.command( + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + ) +) +@click.pass_context +def test(ctx): + print(ctx.args) + print(unpack_ctx(ctx)) diff --git a/dacapo/predict.py b/dacapo/predict.py index ee0dcaa2b..d0db9149f 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -22,7 +22,7 @@ def predict( iteration: int, input_container: Path | str, input_dataset: str, - output_path: LocalArrayIdentifier | str, + output_path: LocalArrayIdentifier | Path | str, output_roi: Optional[Roi | str] = None, num_workers: int = 12, output_dtype: np.dtype | str = np.uint8, # type: ignore @@ -101,7 +101,7 @@ def predict( output_roi = output_roi.snap_to_grid( raw_array.voxel_size, mode="grow" ).intersect(raw_array.roi.grow(-context, -context)) - _input_roi = output_roi.grow(context, context) + _input_roi = output_roi.grow(context, context) # type: ignore if isinstance(output_dtype, str): output_dtype = np.dtype(output_dtype) diff --git a/pyproject.toml b/pyproject.toml index a3f0b0015..4752e6d1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,10 @@ examples = [ "ipykernel", "jupyter", ] -all = ["dacapo-ml[test,dev,docs,examples]"] +pretrained = [ + "empanada-napari", + ] +all = ["dacapo-ml[test,dev,docs,examples,pretrained]"] [project.urls] homepage = "https://github.io/janelia-cellmap/dacapo" @@ -151,8 +154,7 @@ disallow_subclassing_any = false show_error_codes = true pretty = true exclude = [ - "dacapo/apply.py", - "dacapo/cli.py" + "scratch/*", ] @@ -183,10 +185,12 @@ module = [ "neuclease.dvid.*", "mwatershed.*", "numpy_indexed.*", + "empanada_napari.*", + "napari.*", + "empanada.*", ] ignore_missing_imports = true - # https://coverage.readthedocs.io/en/6.4/config.html [tool.coverage.report] exclude_lines = [ diff --git a/tests/components/test_options.py b/tests/components/test_options.py index 7ac7e1488..1a228791c 100644 --- 
a/tests/components/test_options.py +++ b/tests/components/test_options.py @@ -1,3 +1,4 @@ +import os from dacapo import Options @@ -11,6 +12,11 @@ def test_no_config(): if config_file.exists(): config_file.unlink() + # Remove the environment variable + env_dict = dict(os.environ) + if "OPTIONS_FILE" in env_dict: + del env_dict["OPTIONS_FILE"] + # Parse the options options = Options.instance()