diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index fd22b426a..3f658d9ef 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -3,6 +3,10 @@ from dacapo.blockwise.scheduler import segment_blockwise from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier +from dacapo.utils.array_utils import to_ndarray, save_ndarray +from funlib.persistence import open_ds +import daisy +import mwatershed as mws from .watershed_post_processor_parameters import WatershedPostProcessorParameters from .post_processor import PostProcessor @@ -123,29 +127,15 @@ def process( np.uint64, block_size * self.prediction_array.voxel_size, ) + input_array = open_ds( + self.prediction_array_identifier.container.path, + self.prediction_array_identifier.dataset, + ) - read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size) - # run blockwise prediction - pars = { - "offsets": self.offsets, - "bias": parameters.bias, - "context": parameters.context, - } - segment_blockwise( - segment_function_file=str( - Path(Path(dacapo.blockwise.__file__).parent, "watershed_function.py") - ), - context=parameters.context, - total_roi=self.prediction_array.roi, - read_roi=read_roi.grow(parameters.context, parameters.context), - write_roi=read_roi, - num_workers=num_workers, - max_retries=2, # TODO: make this an option - timeout=None, # TODO: make this an option - ###### - input_array_identifier=self.prediction_array_identifier, - output_array_identifier=output_array_identifier, - parameters=pars, + data = to_ndarray(input_array, output_array.roi) + segmentation = mws.agglom( + data - parameters.bias, offsets=self.offsets, randomized_strides=True ) + save_ndarray(segmentation, self.prediction_array.roi, output_array) - return output_array + return output_array_identifier diff --git a/docs/source/notebooks/minimal_tutorial.py b/docs/source/notebooks/minimal_tutorial.py index d78b5b66a..411612f4d 100644 --- a/docs/source/notebooks/minimal_tutorial.py +++ b/docs/source/notebooks/minimal_tutorial.py @@ -376,3 +376,42 @@ ax[snapshot, 2].imshow(prediction[prediction.shape[0] // 2]) ax[snapshot, 0].set_ylabel(f"Snapshot {snapshot_it}") plt.show() + +# %% +# Visualize validations +import zarr + +num_validations = run_config.num_iterations // run_config.validation_interval +fig, ax = plt.subplots(num_validations, 4, figsize=(10, 2 * num_validations)) + +# Set column titles +column_titles = ["Raw", "Ground Truth", "Prediction", "Segmentation"] +for col in range(len(column_titles)): + ax[0, col].set_title(column_titles[col]) + +for validation in range(1, num_validations + 1): + dataset = run.datasplit.validate[0].name + validation_it = validation * run_config.validation_interval + # break + raw = zarr.open( + f"/Users/pattonw/dacapo/example_run/validation.zarr/inputs/{dataset}/raw" + )[:] + gt = zarr.open( + f"/Users/pattonw/dacapo/example_run/validation.zarr/inputs/{dataset}/gt" + )[0] + pred_path = f"/Users/pattonw/dacapo/example_run/validation.zarr/{validation_it}/ds_{dataset}/prediction" + out_path = f"/Users/pattonw/dacapo/example_run/validation.zarr/{validation_it}/ds_{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))" + output = zarr.open( + out_path + )[:] + prediction = zarr.open(pred_path)[0] + print(raw.shape, gt.shape, output.shape) + c = (raw.shape[1] - gt.shape[1]) // 2 + if c != 0: + raw = raw[:, c:-c, c:-c] + ax[validation - 1, 0].imshow(raw[raw.shape[0] // 2]) + ax[validation - 1, 1].imshow(gt[gt.shape[0] // 2]) + ax[validation - 1, 2].imshow(prediction[prediction.shape[0] // 2]) + ax[validation - 1, 3].imshow(output[output.shape[0] // 2]) + ax[validation - 1, 0].set_ylabel(f"Validation {validation_it}") +plt.show() diff --git a/pyproject.toml b/pyproject.toml index ade6bb47c..866c9fffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,8 @@ docs = [ "sphinx-click", "sphinx-rtd-theme", "myst-parser", + "matplotlib", + "pooch", ] examples = [ "ipython",