Skip to content

Commit

Permalink
feat: ✨ Add predict functionality to dacapo CLI module
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Feb 13, 2024
1 parent 044611b commit 48de078
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 24 deletions.
1 change: 1 addition & 0 deletions dacapo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .apply import apply # noqa
from .train import train # noqa
from .validate import validate # noqa
from .predict import predict # noqa
2 changes: 1 addition & 1 deletion dacapo/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def apply(
parameters: Optional[PostProcessorParameters or str] = None,
roi: Optional[Roi or str] = None,
num_cpu_workers: int = 30,
output_dtype: Optional[np.dtype or str] = np.uint8, # type: ignore
output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
overwrite: bool = True,
file_format: str = "zarr",
Expand Down
32 changes: 16 additions & 16 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,22 +237,22 @@ def spawn_worker(
"""
# Make the command for the worker to run
command = [ # TODO
"dacapo",
"predict",
"--name",
model.name,
"--criterion",
"best",
"--channels",
"0",
"--out_container",
prediction_array_identifier.container,
"--out_dataset",
prediction_array_identifier.dataset,
"--in_container",
raw_array.container,
"--in_dataset",
raw_array.dataset,
# "dacapo",
# "predict",
# "--name",
# model.name,
# "--criterion",
# "best",
# "--channels",
# "0",
# "--out_container",
# prediction_array_identifier.container,
# "--out_dataset",
# prediction_array_identifier.dataset,
# "--in_container",
# raw_array.container,
# "--in_dataset",
# raw_array.dataset,
]

def run_worker():
Expand Down
142 changes: 136 additions & 6 deletions dacapo/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from pathlib import Path
from typing import Optional

import numpy as np

import dacapo
import click
import logging
from daisy import Roi, Coordinate
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.run import Run
from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
PostProcessorParameters,
)
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo import compute_context
from dacapo.compute_context import ComputeContext, LocalTorch


@click.group()
Expand Down Expand Up @@ -65,19 +79,26 @@ def validate(run_name, iteration):
)
@click.option("-w", "--num_cpu_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
@click.option("-ow", "--overwrite", is_flag=True)
@click.option("-cc", "--compute_context", type=str, default="LocalTorch")
def apply(
run_name: str,
input_container: str,
input_container: Path or str,
input_dataset: str,
output_path: str,
validation_dataset: Optional[str] = None,
output_path: Path or str,
validation_dataset: Optional[Dataset or str] = None,
criterion: Optional[str] = "voi",
iteration: Optional[int] = None,
parameters: Optional[str] = None,
roi: Optional[str] = None,
parameters: Optional[PostProcessorParameters or str] = None,
roi: Optional[Roi or str] = None,
num_cpu_workers: int = 30,
output_dtype: Optional[str] = "uint8",
output_dtype: Optional[np.dtype | str] = "uint8",
overwrite: bool = True,
compute_context: Optional[ComputeContext | str] = LocalTorch(),
):
if isinstance(compute_context, str):
compute_context = getattr(compute_context, compute_context)()

dacapo.apply(
run_name,
input_container,
Expand All @@ -90,4 +111,113 @@ def apply(
roi,
num_cpu_workers,
output_dtype,
overwrite=overwrite,
compute_context=compute_context, # type: ignore
)


@cli.command()
@click.option(
"-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
"-i",
"--iteration",
required=True,
type=int,
help="The training iteration of the model to use for prediction.",
)
@click.option(
"-ic",
"--input_container",
required=True,
type=click.Path(exists=True, file_okay=False),
)
@click.option("-id", "--input_dataset", required=True, type=str)
@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
@click.option(
"-roi",
"--roi",
type=str,
required=False,
help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
)
@click.option("-w", "--num_cpu_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
@click.option(
"-cc",
"--compute_context",
type=str,
default="LocalTorch",
help="The compute context to use for prediction. Must be the name of a subclass of ComputeContext.",
)
@click.option("-ow", "--overwrite", is_flag=True)
def predict(
run_name: str,
iteration: int,
input_container: Path or str,
input_dataset: str,
output_path: Path or str,
roi: Optional[str | Roi] = None,
num_cpu_workers: int = 30,
output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore
compute_context: Optional[ComputeContext | str] = LocalTorch(),
overwrite: bool = True,
):
# retrieving run
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

# create weights store
weights_store = create_weights_store()

# load weights
weights_store.retrieve_weights(run_name, iteration)

# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
output_container = Path(
output_path,
"".join(Path(input_container).name.split(".")[:-1]) + f".zarr",
) # TODO: zarr hardcoded
prediction_array_identifier = LocalArrayIdentifier(
output_container, f"prediction_{run_name}_{iteration}"
)

if isinstance(roi, str):
start, end = zip(
*[
tuple(int(coord) for coord in axis.split(":"))
for axis in roi.strip("[]").split(",")
]
)
roi = Roi(
Coordinate(start),
Coordinate(end) - Coordinate(start),
)

if roi is None:
roi = input_array.roi
else:
roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
input_array.roi
)

if isinstance(output_dtype, str):
output_dtype = np.dtype(output_dtype)

if isinstance(compute_context, str):
compute_context = getattr(compute_context, compute_context)()

dacapo.predict(
run.model,
input_array,
prediction_array_identifier,
output_roi=roi,
num_cpu_workers=num_cpu_workers,
output_dtype=output_dtype,
compute_context=compute_context, # type: ignore
overwrite=overwrite,
)
2 changes: 1 addition & 1 deletion dacapo/compute_context/bsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@attr.s
class Bsub(ComputeContext):
class Bsub(ComputeContext): # TODO: Load defaults from dacapo.yaml
queue: str = attr.ib(default="local", metadata={"help_text": "The queue to run on"})
num_gpus: int = attr.ib(
default=1,
Expand Down

0 comments on commit 48de078

Please sign in to comment.