Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/cli #141

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions dacapo/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def apply(
parameters: Optional[PostProcessorParameters | str] = None,
roi: Optional[Roi | str] = None,
num_workers: int = 12,
output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore
output_dtype: np.dtype | str = np.uint8, # type: ignore
overwrite: bool = True,
file_format: str = "zarr",
):
Expand Down Expand Up @@ -92,7 +92,7 @@ def apply(
logger.info(
"Finding best parameters for validation dataset %s", _validation_dataset
)
parameters = run.task.evaluator.get_overall_best_parameters( # TODO
parameters = run.task.evaluator.get_overall_best_parameters(
_validation_dataset, criterion
)
assert (
Expand All @@ -102,10 +102,10 @@ def apply(
elif isinstance(parameters, str):
try:
post_processor_name = parameters.split("(")[0]
post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
_post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
post_processor_kwargs = {
key.strip(): value.strip()
for key, value in [arg.split("=") for arg in post_processor_kwargs]
for key, value in [arg.split("=") for arg in _post_processor_kwargs]
}
for key, value in post_processor_kwargs.items():
if value.isdigit():
Expand All @@ -132,12 +132,12 @@ def apply(
), "Parameters must be parsable to a PostProcessorParameters object."

# make array identifiers for input, predictions and outputs
input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
if roi is None:
roi = input_array.roi
_roi = input_array.roi
else:
roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
_roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
input_array.roi
)
output_container = Path(
Expand All @@ -164,7 +164,7 @@ def apply(
input_array_identifier,
prediction_array_identifier,
output_array_identifier,
roi,
_roi,
num_workers,
output_dtype,
overwrite,
Expand Down
8 changes: 5 additions & 3 deletions dacapo/blockwise/segment_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def start_worker(

segmentation = segment_function(input_array, block, **parameters)

assert segmentation.dtype == np.uint64
assert (
segmentation.dtype == np.uint64
), "Instance segmentations is expected to be uint64"

id_bump = block.block_id[1] * num_voxels_in_block
segmentation += id_bump
Expand Down Expand Up @@ -139,8 +141,8 @@ def start_worker(
)

unique_pairs = np.concatenate(unique_pairs)
zero_u = unique_pairs[:, 0] == 0
zero_v = unique_pairs[:, 1] == 0
zero_u = unique_pairs[:, 0] == 0 # type: ignore
zero_v = unique_pairs[:, 1] == 0 # type: ignore
non_zero_filter = np.logical_not(np.logical_or(zero_u, zero_v))

edges = unique_pairs[non_zero_filter]
Expand Down
259 changes: 257 additions & 2 deletions dacapo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import dacapo
import click
import logging
from daisy import Roi
from funlib.geometry import Roi, Coordinate
from funlib.persistence import open_ds
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
PostProcessorParameters,
)
from dacapo.blockwise import (
run_blockwise as _run_blockwise,
segment_blockwise as _segment_blockwise,
)


@click.group()
Expand Down Expand Up @@ -88,7 +93,7 @@ def apply(
parameters: Optional[PostProcessorParameters | str] = None,
roi: Optional[Roi | str] = None,
num_workers: int = 30,
output_dtype: Optional[np.dtype | str] = "uint8",
output_dtype: np.dtype | str = "uint8",
overwrite: bool = True,
):
dacapo.apply(
Expand Down Expand Up @@ -158,3 +163,253 @@ def predict(
output_dtype,
overwrite,
)


def _parse_size_string(size_spec: str) -> Coordinate:
    """Parse a per-block size given as ``x,y,z`` or ``[x,y,z]``.

    Surrounding brackets are optional so that the format advertised in the
    option help texts is accepted as well as the bare comma-separated form.

    Raises:
        click.BadParameter: if any entry is not an integer.
    """
    try:
        return Coordinate(
            [int(extent) for extent in size_spec.strip().strip("[]").split(",")]
        )
    except ValueError as err:
        raise click.BadParameter(
            f"expected comma-separated integers, got {size_spec!r}"
        ) from err


def _parse_roi_string(roi_spec: str) -> Roi:
    """Parse a roi given as ``[start:end,start:end,...]`` (one pair per axis).

    Raises:
        click.BadParameter: if an axis is not exactly ``start:end`` with
            integer bounds.
    """
    begins = []
    ends = []
    for axis_spec in roi_spec.strip().strip("[]").split(","):
        try:
            # Unpacking also rejects axes with too few / too many bounds.
            begin, end = (int(bound) for bound in axis_spec.split(":"))
        except ValueError as err:
            raise click.BadParameter(
                f"expected 'start:end' for every axis, got {axis_spec!r}"
            ) from err
        begins.append(begin)
        ends.append(end)
    # Roi takes (offset, shape), so the end coordinates become an extent.
    return Roi(Coordinate(begins), Coordinate(ends) - Coordinate(begins))


def _block_rois(
    read_roi_size: str, write_roi_size: str
) -> tuple[Roi, Roi, Coordinate]:
    """Build the per-block read and write rois from their size strings.

    The read roi starts at the origin and the write roi is centred inside it,
    i.e. offset by ``context = (read_size - write_size) // 2``.

    Returns:
        ``(read_roi, write_roi, context)``.

    Raises:
        click.BadParameter: if the sizes are malformed, differ in
            dimensionality, or the write roi is larger than the read roi.
    """
    read_shape = _parse_size_string(read_roi_size)
    write_shape = _parse_size_string(write_roi_size)
    if len(read_shape) != len(write_shape):
        raise click.BadParameter(
            "read_roi_size and write_roi_size must have the same number of "
            f"dimensions, got {read_roi_size!r} and {write_roi_size!r}"
        )
    if any(w > r for r, w in zip(read_shape, write_shape)):
        raise click.BadParameter(
            "write_roi_size must not exceed read_roi_size on any axis, got "
            f"{write_roi_size!r} and {read_roi_size!r}"
        )
    # Computed on parsed integers; the raw CLI strings cannot be subtracted.
    context = Coordinate([(r - w) // 2 for r, w in zip(read_shape, write_shape)])
    read_roi = Roi(Coordinate([0] * len(read_shape)), read_shape)
    write_roi = Roi(context, write_shape)
    return read_roi, write_roi, context


@cli.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option(
    "-ic",
    "--input_container",
    required=True,
    type=click.Path(exists=True, file_okay=False),
)
@click.option("-id", "--input_dataset", required=True, type=str)
@click.option(
    "-oc", "--output_container", required=True, type=click.Path(file_okay=False)
)
@click.option("-od", "--output_dataset", required=True, type=str)
@click.option(
    "-w", "--worker_file", required=True, type=str, help="The path to the worker file."
)
@click.option(
    "-tr",
    "--total_roi",
    type=str,
    help="The total roi to be processed. Format is [start:end,start:end,...] Defaults to the roi of the input dataset. Do not use spaces in CLI argument.",
    default=None,
)
@click.option(
    "-rr",
    "--read_roi_size",
    required=True,
    type=str,
    help="The size of the roi to be read for each block.",
)
@click.option(
    "-wr",
    "--write_roi_size",
    required=True,
    type=str,
    help="The size of the roi to be written for each block.",
)
@click.option("-nw", "--num_workers", type=int, default=16)
@click.option("-mr", "--max_retries", type=int, default=2)
@click.option("-t", "--timeout", type=int, default=None)
@click.pass_context
def run_blockwise(
    ctx,
    input_container: Path | str,
    input_dataset: str,
    output_container: Path | str,
    output_dataset: str,
    worker_file: str | Path,
    total_roi: str | None,
    read_roi_size: str,
    write_roi_size: str,
    num_workers: int = 16,
    max_retries: int = 2,
    timeout=None,
    *args,
    **kwargs,
):
    """Run a worker file blockwise over a dataset.

    Any additional "--key value" pairs are forwarded to the worker as
    keyword arguments.
    """
    # Unknown options are collected by click in ctx.args; they replace the
    # (always empty) **kwargs placeholder of the signature.
    extra_kwargs = unpack_ctx(ctx)

    if total_roi is not None:
        _total_roi = _parse_roi_string(total_roi)
    else:
        # No roi given: process everything the input dataset covers.
        _total_roi = open_ds(str(input_container), input_dataset).roi

    read_roi, write_roi, _ = _block_rois(read_roi_size, write_roi_size)

    _run_blockwise(
        *args,
        input_container=input_container,
        input_dataset=input_dataset,
        output_container=output_container,
        output_dataset=output_dataset,
        worker_file=worker_file,
        total_roi=_total_roi,
        read_roi=read_roi,
        write_roi=write_roi,
        num_workers=num_workers,
        max_retries=max_retries,
        timeout=timeout,
        **extra_kwargs,
    )


@cli.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option(
    "-ic",
    "--input_container",
    required=True,
    type=click.Path(exists=True, file_okay=False),
)
@click.option("-id", "--input_dataset", required=True, type=str)
@click.option(
    "-oc", "--output_container", required=True, type=click.Path(file_okay=False)
)
@click.option("-od", "--output_dataset", required=True, type=str)
@click.option("-sf", "--segment_function_file", required=True, type=click.Path())
@click.option(
    "-c",
    "--context",
    type=str,
    help="The context to be used, in the format of [x,y,z]. Defaults to the difference between the read and write rois.",
    default=None,
)
@click.option(
    "-tr",
    "--total_roi",
    type=str,
    help="The total roi to be processed. Format is [start:end,start:end,...] Defaults to the roi of the input dataset. Do not use spaces in CLI argument.",
    default=None,
)
@click.option(
    "-rr",
    "--read_roi_size",
    required=True,
    type=str,
    help="The size of the roi to be read for each block, in the format of [x,y,z].",
)
@click.option(
    "-wr",
    "--write_roi_size",
    required=True,
    type=str,
    help="The size of the roi to be written for each block, in the format of [x,y,z].",
)
@click.option("-nw", "--num_workers", type=int, default=16)
@click.option("-mr", "--max_retries", type=int, default=2)
@click.option("-t", "--timeout", type=int, default=None)
@click.option("-tp", "--tmp_prefix", type=str, default="tmp")
@click.pass_context
def segment_blockwise(
    ctx,
    input_container: Path | str,
    input_dataset: str,
    output_container: Path | str,
    output_dataset: str,
    segment_function_file: Path | str,
    context: str | None,
    total_roi: str | None,
    read_roi_size: str,
    write_roi_size: str,
    num_workers: int = 16,
    max_retries: int = 2,
    timeout=None,
    tmp_prefix: str = "tmp",
    *args,
    **kwargs,
):
    """Segment a dataset blockwise with a user supplied segment function.

    Any additional "--key value" pairs are forwarded to the segment
    function as keyword arguments.
    """
    # Unknown options are collected by click in ctx.args; they replace the
    # (always empty) **kwargs placeholder of the signature.
    extra_kwargs = unpack_ctx(ctx)

    if total_roi is not None:
        _total_roi = _parse_roi_string(total_roi)
    else:
        # No roi given: process everything the input dataset covers.
        _total_roi = open_ds(str(input_container), input_dataset).roi

    read_roi, write_roi, derived_context = _block_rois(
        read_roi_size, write_roi_size
    )

    # Honour an explicit --context; otherwise fall back to the margin implied
    # by the read/write sizes.
    if context is None:
        _context = derived_context
    else:
        _context = _parse_size_string(context)

    _segment_blockwise(
        *args,
        input_container=input_container,
        input_dataset=input_dataset,
        output_container=output_container,
        output_dataset=output_dataset,
        segment_function_file=segment_function_file,
        context=_context,
        total_roi=_total_roi,
        read_roi=read_roi,
        write_roi=write_roi,
        num_workers=num_workers,
        max_retries=max_retries,
        timeout=timeout,
        tmp_prefix=tmp_prefix,
        **extra_kwargs,
    )


def _coerce_cli_value(raw):
    """Convert a CLI string to ``int`` or ``float`` when it parses as one.

    Anything that is not a valid number is returned unchanged as a string.
    Note that ``float`` also accepts scientific notation and the special
    values ``nan`` / ``inf``.
    """
    try:
        return int(raw)
    except ValueError:
        pass
    try:
        return float(raw)
    except ValueError:
        return raw


def unpack_ctx(ctx):
    """Turn the extra ``--key value`` arguments captured by click into kwargs.

    Args:
        ctx: the click context; only ``ctx.args`` (the list of unparsed
            command line tokens) is read.

    Returns:
        A dict mapping each key (leading dashes removed) to its value.
        Integer-like values become ``int``, other numeric values become
        ``float`` and everything else stays a ``str``.

    Raises:
        ValueError: if the extra arguments do not come in key/value pairs.
    """
    extra_args = list(ctx.args)
    if len(extra_args) % 2:
        # A dangling flag would otherwise surface as an opaque IndexError.
        raise ValueError(
            "Extra arguments must be given as '--key value' pairs, got: "
            f"{extra_args}"
        )
    return {
        flag.lstrip("-"): _coerce_cli_value(raw)
        for flag, raw in zip(extra_args[::2], extra_args[1::2])
    }


@cli.command(
    context_settings={
        "ignore_unknown_options": True,
        "allow_extra_args": True,
    }
)
@click.pass_context
def test(ctx):
    # Debugging aid: show the raw tokens click did not recognise, followed by
    # the keyword arguments unpack_ctx derives from them.
    leftover_args = ctx.args
    print(leftover_args)
    parsed_kwargs = unpack_ctx(ctx)
    print(parsed_kwargs)
4 changes: 2 additions & 2 deletions dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def predict(
iteration: int,
input_container: Path | str,
input_dataset: str,
output_path: LocalArrayIdentifier | str,
output_path: LocalArrayIdentifier | Path | str,
output_roi: Optional[Roi | str] = None,
num_workers: int = 12,
output_dtype: np.dtype | str = np.uint8, # type: ignore
Expand Down Expand Up @@ -101,7 +101,7 @@ def predict(
output_roi = output_roi.snap_to_grid(
raw_array.voxel_size, mode="grow"
).intersect(raw_array.roi.grow(-context, -context))
_input_roi = output_roi.grow(context, context)
_input_roi = output_roi.grow(context, context) # type: ignore

if isinstance(output_dtype, str):
output_dtype = np.dtype(output_dtype)
Expand Down
Loading
Loading