Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Empanada #142

Merged
merged 7 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions dacapo/blockwise/empanada_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
import numpy as np

import os

try:
from empanada_napari.inference import Engine3d
from empanada_napari.multigpu import MultiGPUEngine3d
from empanada_napari.utils import get_configs
from empanada.config_loaders import read_yaml
from empanada_napari.inference import (
InstanceTracker,
get_axis_trackers_by_class,
instance_relabel,
filters,
fill_volume,
create_instance_consensus,
create_semantic_consensus,
)
except ImportError:
raise ImportError("Please install empanada-napari to use this CLI")


# Default keyword arguments for `empanada_segmenter`.  The keys mirror that
# function's signature so a caller can pass this dict (possibly with
# overrides) as **parameters.
default_parameters = {
    "model_config": "MitoNet_v1",  # key into the installed empanada model configs
    "use_gpu": True,
    "use_quantized": False,  # quantized weights (CPU inference)
    "multigpu": False,  # use MultiGPUEngine3d instead of Engine3d
    "downsampling": 1,  # inference scale factor
    "confidence_thr": 0.5,
    "center_confidence_thr": 0.1,
    "min_distance_object_centers": 21,  # NMS kernel size
    "fine_boundaries": True,
    "semantic_only": False,
    "median_slices": 11,  # median filter kernel along the stacking axis
    "min_size": 10000,  # minimum object size (voxels) kept by filtering
    "min_extent": 50,  # minimum object span kept by filtering
    "maximum_objects_per_class": 1000000,  # label divisor per class
    "inference_plane": "xy",  # plane used when orthoplane is False
    "orthoplane": True,  # run xy/xz/yz inference and take a consensus
    "return_panoptic": False,
    "pixel_vote_thr": 1,  # votes (of 3 planes) needed to keep a pixel
    "allow_one_view": False,
}


def segment_function(input_array, block, **parameters):
    """Segment the data under ``block.read_roi`` with empanada.

    Runs ``empanada_segmenter`` on the cropped volume and stacks the
    per-class label volumes along a new leading axis, returning the
    squeezed ``uint64`` result (so a single class yields a 3D array).
    """
    labeled_volumes = []
    seen_classes = []
    segmenter = empanada_segmenter(input_array[block.read_roi], **parameters)
    for volume, class_name, _tracker in segmenter:
        # add a leading class axis so the volumes can be concatenated
        labeled_volumes.append(volume[np.newaxis].astype(np.uint64))
        seen_classes.append(class_name)
    stacked = np.concatenate(labeled_volumes, axis=0, dtype=np.uint64)
    return stacked.squeeze()


# THESE ARE NON-THREAD WORKER VERSIONS OF THE FUNCTIONS, adapted from empanada-napari by Jeff Rhoades (HHMI Janelia), February 2024

# Mapping of model name -> config path, discovered from the installed
# empanada-napari package; consumed by `empanada_segmenter` via read_yaml.
model_configs = get_configs()


def stack_inference(engine, volume, axis_name):
    """Run inference along a single axis of ``volume``.

    Returns ``(stack, axis_name, trackers_dict)`` where ``trackers_dict``
    maps the axis name to the trackers produced for it.
    """
    inferred_stack, axis_trackers = engine.infer_on_axis(volume, axis_name)
    return inferred_stack, axis_name, {axis_name: axis_trackers}


def orthoplane_inference(engine, volume):
    """Run inference on all three orthogonal planes of ``volume``.

    Returns a dict mapping each axis name ("xy", "xz", "yz") to the
    trackers produced for that plane; the inferred stacks themselves are
    discarded.  Prints a per-class instance count for each plane.
    """
    trackers_by_axis = {}
    for plane in ("xy", "xz", "yz"):
        _, plane_trackers = engine.infer_on_axis(volume, plane)
        trackers_by_axis[plane] = plane_trackers

        # report instances per class
        for tracker in plane_trackers:
            count = len(tracker.instances.keys())
            print(f"Class {tracker.class_id}, axis {plane}, has {count} instances")

    return trackers_by_axis


def empanada_segmenter(
    image,
    model_config="MitoNet_v1",
    use_gpu=True,
    use_quantized=False,
    multigpu=False,
    downsampling=1,
    confidence_thr=0.5,
    center_confidence_thr=0.1,
    min_distance_object_centers=21,
    fine_boundaries=True,
    semantic_only=False,
    median_slices=11,
    min_size=10000,
    min_extent=50,
    maximum_objects_per_class=1000000,
    inference_plane="xy",
    orthoplane=True,
    return_panoptic=False,
    pixel_vote_thr=1,
    allow_one_view=False,
):
    """Segment a volume with an empanada model.

    Builds a single- or multi-GPU inference engine, runs it either on one
    plane (``inference_plane``) or on all three orthogonal planes
    (``orthoplane=True``), and returns a generator of
    ``(volume, class_name, instances)`` tuples produced by
    ``stack_postprocessing`` or ``tracker_consensus`` respectively.

    Args:
        image: 3D volume, or a 4D volume with a channel dimension of size
            1, 3 or 4 on the first or last axis (a single channel is
            extracted).
        model_config (str): key into the installed empanada model configs.
        multigpu (bool): use ``MultiGPUEngine3d`` (``use_gpu`` /
            ``use_quantized`` only apply to the single-GPU engine).
        Remaining keywords are forwarded to the engine and the
        post-processing/consensus steps; see ``default_parameters``.

    Raises:
        Exception: if a 4D image has no recognizable channel dimension.
    """
    # resolve the named config to its yaml contents
    model_config = read_yaml(model_configs[model_config])
    min_size = int(min_size)
    min_extent = int(min_extent)
    maximum_objects_per_class = int(maximum_objects_per_class)

    if multigpu:
        engine = MultiGPUEngine3d(
            model_config,
            inference_scale=downsampling,
            median_kernel_size=median_slices,
            nms_kernel=min_distance_object_centers,
            nms_threshold=center_confidence_thr,
            confidence_thr=confidence_thr,
            min_size=min_size,
            min_extent=min_extent,
            fine_boundaries=fine_boundaries,
            label_divisor=maximum_objects_per_class,
            semantic_only=semantic_only,
            save_panoptic=return_panoptic,
        )
    # conditions where model needs to be (re)loaded
    else:
        engine = Engine3d(
            model_config,
            inference_scale=downsampling,
            median_kernel_size=median_slices,
            nms_kernel=min_distance_object_centers,
            nms_threshold=center_confidence_thr,
            confidence_thr=confidence_thr,
            min_size=min_size,
            min_extent=min_extent,
            fine_boundaries=fine_boundaries,
            label_divisor=maximum_objects_per_class,
            use_gpu=use_gpu,
            use_quantized=use_quantized,
            semantic_only=semantic_only,
            save_panoptic=return_panoptic,
        )

    def start_postprocess_worker(stack, axis_name, trackers_dict):
        # BUGFIX: this worker previously took *args and read
        # ``trackers_dict = args[0][2]``, which indexed into the stacked
        # volume; the caller unpacks (stack, axis_name, trackers_dict), so
        # the trackers dict is the third argument.
        for vol, class_name, tracker in stack_postprocessing(
            trackers_dict,
            model_config,
            label_divisor=maximum_objects_per_class,
            min_size=min_size,
            min_extent=min_extent,
            dtype=engine.dtype,
        ):
            print(f"Yielding {class_name} volume of shape {vol.shape}")
            yield vol, class_name, tracker

    def start_consensus_worker(trackers_dict):
        for vol, class_name, tracker in tracker_consensus(
            trackers_dict,
            model_config,
            pixel_vote_thr=pixel_vote_thr,
            allow_one_view=allow_one_view,
            min_size=min_size,
            min_extent=min_extent,
            dtype=engine.dtype,
        ):
            print(f"Yielding {class_name} volume of shape {vol.shape}")
            yield vol, class_name, tracker

    # verify that the image doesn't have extraneous channel dimensions
    assert image.ndim in [3, 4], "Only 3D and 4D input images can be handled!"
    if image.ndim == 4:
        # channel dimensions are commonly 1, 3 and 4
        # check for dimensions on zeroeth and last axes
        shape = image.shape
        if shape[0] in [1, 3, 4]:
            image = image[0]
        elif shape[-1] in [1, 3, 4]:
            image = image[..., 0]
        else:
            raise Exception(f"Image volume must be 3D, got image of shape {shape}")

        print(
            f"Got 4D image of shape {shape}, extracted single channel of size {image.shape}"
        )

    if orthoplane:
        trackers_dict = orthoplane_inference(engine, image)
        return start_consensus_worker(trackers_dict)
    else:
        outputs = stack_inference(engine, image, inference_plane)
        return start_postprocess_worker(*outputs)


def stack_postprocessing(
    trackers,
    model_config,
    label_divisor=1000,
    min_size=200,
    min_extent=4,
    dtype=np.uint32,
):
    r"""Relabel and filter each class defined in ``trackers``.

    For every class in the model config, merges its tracker instances into
    a single-axis ("xy") tracker, applies size/extent filters to instance
    ("thing") classes, and yields ``(volume, class_name, instances)``.
    """
    instance_class_ids = model_config["thing_list"]
    names_by_id = model_config["class_names"]

    # create the final instance segmentations, one class at a time
    for class_id, class_name in names_by_id.items():
        print(f"Creating stack segmentation for class {class_name}...")

        source_tracker = get_axis_trackers_by_class(trackers, class_id)[0]
        volume_shape = source_tracker.shape3d

        # merge instances from orthoplane inference into an "xy" tracker
        merged_tracker = InstanceTracker(class_id, label_divisor, volume_shape, "xy")
        merged_tracker.instances = instance_relabel(source_tracker)

        # in-place filtering only applies to instance ("thing") classes
        if class_id in instance_class_ids:
            filters.remove_small_objects(merged_tracker, min_size=min_size)
            filters.remove_pancakes(merged_tracker, min_span=min_extent)

        print(f"Total {class_name} objects {len(merged_tracker.instances.keys())}")

        # decode the run-length encoded instances into a dense volume
        segmentation = np.zeros(volume_shape, dtype=dtype)
        fill_volume(segmentation, merged_tracker.instances)

        yield segmentation, class_name, merged_tracker.instances


def tracker_consensus(
    trackers,
    model_config,
    pixel_vote_thr=2,
    cluster_iou_thr=0.75,
    allow_one_view=False,
    min_size=200,
    min_extent=4,
    dtype=np.uint32,
):
    r"""Compute the orthoplane consensus from ``trackers``.

    For every class in the model config, merges that class's per-axis
    trackers — instance consensus for "thing" classes (followed by
    size/extent filtering), semantic consensus otherwise — and yields
    ``(volume, class_name, instances)``.
    """
    # NOTE: `labels` is unused below, but the lookup doubles as a check
    # that the config contains the key; kept for behavioral parity.
    labels = model_config["labels"]
    thing_ids = model_config["thing_list"]
    names_by_id = model_config["class_names"]

    # create the final instance segmentations, one class at a time
    for class_id, class_name in names_by_id.items():
        # gather the per-axis trackers for this class
        print(f"Creating consensus segmentation for class {class_name}...")

        axis_trackers = get_axis_trackers_by_class(trackers, class_id)
        volume_shape = axis_trackers[0].shape3d

        if class_id in thing_ids:
            # instance classes: cluster and vote across the three planes
            consensus_tracker = create_instance_consensus(
                axis_trackers, pixel_vote_thr, cluster_iou_thr, allow_one_view
            )
            filters.remove_small_objects(consensus_tracker, min_size=min_size)
            filters.remove_pancakes(consensus_tracker, min_span=min_extent)
        else:
            # semantic classes: simple per-pixel vote
            consensus_tracker = create_semantic_consensus(
                axis_trackers, pixel_vote_thr
            )

        print(f"Total {class_name} objects {len(consensus_tracker.instances.keys())}")

        # decode the run-length encoded instances into a dense volume
        consensus_vol = np.zeros(volume_shape, dtype=dtype)
        fill_volume(consensus_vol, consensus_tracker.instances)

        yield consensus_vol, class_name, consensus_tracker.instances
19 changes: 10 additions & 9 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,16 @@ def start_worker(
def spawn_worker(
run_name: str,
iteration: int,
raw_array_identifier: "LocalArrayIdentifier",
prediction_array_identifier: "LocalArrayIdentifier",
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
):
"""Spawn a worker to predict on a given dataset.

Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
run_name (str): The name of the run to apply.
iteration (int): The training iteration of the model to use for prediction.
input_array_identifier (LocalArrayIdentifier): The raw data to predict on.
output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
"""
compute_context = create_compute_context()

Expand All @@ -201,13 +202,13 @@ def spawn_worker(
"--iteration",
iteration,
"--input_container",
raw_array_identifier.container,
input_array_identifier.container,
"--input_dataset",
raw_array_identifier.dataset,
input_array_identifier.dataset,
"--output_container",
prediction_array_identifier.container,
output_array_identifier.container,
"--output_dataset",
prediction_array_identifier.dataset,
output_array_identifier.dataset,
"--device",
str(compute_context.device),
]
Expand Down
1 change: 0 additions & 1 deletion dacapo/blockwise/relabel_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def start_worker(
nodes, edges = read_cross_block_merges(tmpdir)

components = find_components(nodes, edges)
components = DisjointSet(nodes, edges)

while True:
with client.acquire_block() as block:
Expand Down
10 changes: 4 additions & 6 deletions dacapo/blockwise/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def segment_blockwise(
):
with tempfile.TemporaryDirectory(prefix=tmp_prefix) as tmpdir:
# write parameters to tmpdir
if "parameters" in locals():
if "parameters" in kwargs:
with open(Path(tmpdir, "parameters.yaml"), "w") as f:
yaml.dump(locals()["parameters"], f)
yaml.dump(kwargs.pop("parameters"), f)

# Make the task
task = DaCapoBlockwiseTask(
Expand All @@ -112,19 +112,17 @@ def segment_blockwise(
timeout,
upstream_tasks,
tmpdir=tmpdir,
function_path=segment_function_file,
function_path=str(segment_function_file),
*args,
**kwargs,
)

daisy.run_blockwise([task])
success = daisy.run_blockwise([task])

# give a second for the first task to finish
time.sleep(1)
read_roi = write_roi

success = daisy.run_blockwise([task])

# Make the task
task = DaCapoBlockwiseTask(
str(Path(Path(dacapo.blockwise.__file__).parent, "relabel_worker.py")),
Expand Down
4 changes: 2 additions & 2 deletions dacapo/blockwise/segment_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def start_worker(

assert (
segmentation.dtype == np.uint64
), "Instance segmentations is expected to be uint64"
), "Instance segmentations returned by segment_function is expected to be uint64"

id_bump = block.block_id[1] * num_voxels_in_block
segmentation += id_bump
Expand All @@ -98,7 +98,7 @@ def start_worker(
)

# store segmentation in out array
output_array._daisy_array[block.write_roi] = segmentation[block.write_roi]
output_array[block.write_roi] = segmentation[block.write_roi]

neighbor_roi = block.write_roi.grow(
input_array.voxel_size, input_array.voxel_size
Expand Down
Loading
Loading