From 6a6ee1b2fab36c9e5545dedc2ed5b884c0285b45 Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Tue, 23 Jul 2024 11:42:54 -0400
Subject: [PATCH 01/14] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fix=20validation=20?=
 =?UTF-8?q?model=20loading.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

"validate_run" must now be called to run validation from a run name. The Run
must be preloaded and passed in to call "validate" directly.
---
 dacapo/__init__.py | 2 +-
 dacapo/blockwise/predict_worker.py | 2 +-
 dacapo/cli.py | 4 +--
 dacapo/predict.py | 3 +-
 dacapo/validate.py | 54 +++++++++++++++++++-----------
 5 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/dacapo/__init__.py b/dacapo/__init__.py
index fcca5ce7b..f54a1e06d 100644
--- a/dacapo/__init__.py
+++ b/dacapo/__init__.py
@@ -5,6 +5,6 @@
 from . import experiments, utils  # noqa
 from .apply import apply  # noqa
 from .train import train  # noqa
-from .validate import validate  # noqa
+from .validate import validate, validate_run  # noqa
 from .predict import predict  # noqa
 from .blockwise import run_blockwise, segment_blockwise  # noqa

diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index bdaa8fc53..9b5cdbf33 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -121,7 +121,7 @@ def start_worker_fn(
         run_config = config_store.retrieve_run_config(run_name)
         run = Run(run_config)

-    if iteration is not None:
+    if iteration is not None and compute_context.distribute_workers:
         # create weights store
         weights_store = create_weights_store()

diff --git a/dacapo/cli.py b/dacapo/cli.py
index f05c32012..2af9aea77 100644
--- a/dacapo/cli.py
+++ b/dacapo/cli.py
@@ -94,8 +94,8 @@ def train(run_name):
 @click.option("-w", "--num_workers", type=int, default=30)
 @click.option("-dt", "--output_dtype", type=str, default="uint8")
 @click.option("-ow", "--overwrite", is_flag=True)
-def validate(run_name, iteration):
-    dacapo.validate(run_name, iteration)
+def validate(run_name, iteration, num_workers, output_dtype, overwrite):
+    dacapo.validate_run(run_name, iteration, num_workers, output_dtype, overwrite)


 @cli.command()
diff --git a/dacapo/predict.py b/dacapo/predict.py
index 8197b2e63..09ac848cc 100644
--- a/dacapo/predict.py
+++ b/dacapo/predict.py
@@ -72,14 +72,13 @@ def predict(
         output_container, f"prediction_{run_name}_{iteration}"
     )

-    # get the model's input and output size
     compute_context = create_compute_context()
     if isinstance(compute_context, LocalTorch):
         num_workers = 1

     model = run.model.eval()

-    if iteration is not None:
+    if iteration is not None and not compute_context.distribute_workers:
         # create weights store
         weights_store = create_weights_store()

diff --git a/dacapo/validate.py b/dacapo/validate.py
index b49ffe4c1..398308df2 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -1,3 +1,4 @@
+from dacapo.compute_context import create_compute_context
 from .predict import predict
 from .experiments import Run, ValidationIterationScores
 from .experiments.datasplits.datasets.arrays import ZarrArray
@@ -16,21 +17,41 @@


 def validate_run(
-    run: Run,
+    run_name: str,
     iteration: int,
     num_workers: int = 1,
     output_dtype: str = "uint8",
     overwrite: bool = True,
 ):
     """
-    validate_run is deprecated and will be removed in a future version. Please use validate instead.
+    Validate a run at a given iteration. Loads the weights from a previously
+    stored checkpoint. Returns the best parameters and scores for this
+    iteration.
+
+    Args:
+        run_name: The name of the run to validate.
+        iteration: The iteration to validate.
+        num_workers: The number of workers to use for validation.
+        output_dtype: The dtype to use for the output arrays.
+        overwrite: Whether to overwrite existing output arrays.
+
     """
-    warn(
-        "validate_run is deprecated and will be removed in a future version. Please use validate instead.",
-        DeprecationWarning,
-    )
+    # Load the model and weights
+    config_store = create_config_store()
+    run_config = config_store.retrieve_run_config(run_name)
+    run = Run(run_config)
+    compute_context = create_compute_context()
+    if iteration is not None and not compute_context.distribute_workers:
+        # create weights store
+        weights_store = create_weights_store()
+
+        # load weights
+        run.model.load_state_dict(
+            weights_store.retrieve_weights(run_name, iteration).model
+        )
+
     return validate(
-        run_name=run,
+        run=run,
         iteration=iteration,
         num_workers=num_workers,
         output_dtype=output_dtype,
@@ -39,7 +60,7 @@


 def validate(
-    run_name: str | Run,
+    run: Run,
     iteration: int,
     num_workers: int = 1,
     output_dtype: str = "uint8",
@@ -51,7 +72,7 @@
     iteration.

     Args:
-        run_name: The name of the run to validate.
+        run: The run to validate.
         iteration: The iteration to validate.
         num_workers: The number of workers to use for validation.
         output_dtype: The dtype to use for the output arrays.
@@ -61,18 +82,12 @@
     Raises:
         ValueError: If the run does not have a validation dataset or the dataset does not have ground truth.
     Example:
-        validate("my_run", 1000)
+        validate(my_run, 1000)
     """
-    print(f"Validating run {run_name} at iteration {iteration}...")
+    print(f"Validating run {run.name} at iteration {iteration}...")

-    if isinstance(run_name, Run):
-        run = run_name
-        run_name = run.name
-    else:
-        config_store = create_config_store()
-        run_config = config_store.retrieve_run_config(run_name)
-        run = Run(run_config)
+    run_name = run.name

     # read in previous training/validation stats
     stats_store = create_stats_store()
@@ -170,9 +185,10 @@
     prediction_array_identifier = array_store.validation_prediction_array(
         run.name, iteration, validation_dataset.name
     )
+    compute_context = create_compute_context()
     predict(
         run,
-        iteration,
+        iteration if compute_context.distribute_workers else None,
         input_container=input_raw_array_identifier.container,
         input_dataset=input_raw_array_identifier.dataset,
         output_path=prediction_array_identifier,

From e89e2ec0e88bdc05afa1eb8b1197c228de10ee21 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 29 Jul 2024 13:10:49 -0400
Subject: [PATCH 02/14] support global_run for local compute context, solve
 TCP buffer error

---
 dacapo/blockwise/__init__.py | 1 +
 dacapo/blockwise/global_vars.py | 2 +
 dacapo/blockwise/predict_worker.py | 181 ++++++++++--------
 .../datasplits/datasets/arrays/zarr_array.py | 16 +-
 dacapo/predict.py | 12 +-
 dacapo/validate.py | 8 +-
 6 files changed, 126 insertions(+), 94 deletions(-)
 create mode 100644 dacapo/blockwise/global_vars.py

diff --git a/dacapo/blockwise/__init__.py b/dacapo/blockwise/__init__.py
index 6027a9115..aa198e0d0 100644
--- a/dacapo/blockwise/__init__.py
+++ b/dacapo/blockwise/__init__.py
@@ -1,2 +1,3 @@
 from .blockwise_task import DaCapoBlockwiseTask
 from .scheduler import run_blockwise, segment_blockwise
+from . import global_vars
diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
new file mode 100644
index 000000000..4d3771721
--- /dev/null
+++ b/dacapo/blockwise/global_vars.py
@@ -0,0 +1,2 @@
+current_run = None
+
diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index 9b5cdbf33..787787034 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -17,6 +17,7 @@
 import numpy as np
 import click

+from dacapo.blockwise import global_vars

 import logging

@@ -27,6 +28,14 @@
 path = __file__


+def is_global_run_set(run_name) -> bool:
+    found = global_vars.current_run is not None
+    if found:
+        found = global_vars.current_run.name == run_name
+        if not found:
+            logger.error(f"Found global run {global_vars.current_run.name} but looking for {run_name}")
+    return found
+
 @click.group()
 @click.option(
     "--log-level",
@@ -70,7 +79,7 @@ def cli(log_level):
 )
 @click.option("-od", "--output_dataset", required=True, type=str)
 def start_worker(
-    run_name: str | Run,
+    run_name: str,
     iteration: int | None,
     input_container: Path | str,
     input_dataset: str,
@@ -90,7 +99,7 @@ def start_worker(


 def start_worker_fn(
-    run_name: str | Run,
+    run_name: str,
     iteration: int | None,
     input_container: Path | str,
     input_dataset: str,
@@ -109,93 +118,95 @@ def start_worker_fn(
         output_container (Path | str): The output container.
         output_dataset (str): The output dataset.
     """
-    compute_context = create_compute_context()
-    device = compute_context.device
+    def io_loop():
+        daisy_client = daisy.Client()

-    # retrieving run
-    if isinstance(run_name, Run):
-        run = run_name
-        run_name = run.name
-    else:
-        config_store = create_config_store()
-        run_config = config_store.retrieve_run_config(run_name)
-        run = Run(run_config)
+        compute_context = create_compute_context()
+        device = compute_context.device
+
+        if is_global_run_set(run_name):
+            logger.warning("Using global run variable")
+            run = global_vars.current_run
+        else:
+            logger.warning("initiating local run in predict_worker")
+            config_store = create_config_store()
+            run_config = config_store.retrieve_run_config(run_name)
+            run = Run(run_config)
+
+        if iteration is not None and compute_context.distribute_workers:
+            # create weights store
+            weights_store = create_weights_store()
+
+            # load weights
+            run.model.load_state_dict(
+                weights_store.retrieve_weights(run_name, iteration).model
+            )

-    if iteration is not None and compute_context.distribute_workers:
-        # create weights store
-        weights_store = create_weights_store()
+        # get arrays
+        input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
+        raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

-    # load weights
-    run.model.load_state_dict(
-        weights_store.retrieve_weights(run_name, iteration).model
-    )
+        output_array_identifier = LocalArrayIdentifier(
+            Path(output_container), output_dataset
+        )
+        output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
+
+        # set benchmark flag to True for performance
+        torch.backends.cudnn.benchmark = True
+
+        # get the model's input and output size
+        model = run.model.eval()
+        # .to(device)
+        input_voxel_size = Coordinate(raw_array.voxel_size)
+        output_voxel_size = model.scale(input_voxel_size)
+        input_shape = Coordinate(model.eval_input_shape)
+        input_size = input_voxel_size * input_shape
+        output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]
+
+        print(f"Predicting with input size {input_size}, output size {output_size}")
+
+        # create gunpowder keys
+
+        raw = gp.ArrayKey("RAW")
+        prediction = gp.ArrayKey("PREDICTION")
+
+        # assemble prediction pipeline
+
+        # prepare data source
+        pipeline = DaCapoArraySource(raw_array, raw)
+        # raw: (c, d, h, w)
+        pipeline += gp.Pad(raw, None)
+        # raw: (c, d, h, w)
+        pipeline += gp.Unsqueeze([raw])
+        # raw: (1, c, d, h, w)
+
+        pipeline += gp.Normalize(raw)
+
+        # predict
+        # model.eval()
+        pipeline += gp_torch.Predict(
+            model=model,
+            inputs={"x": raw},
+            outputs={0: prediction},
+            array_specs={
+                prediction: gp.ArraySpec(
+                    voxel_size=output_voxel_size,
+                    dtype=np.float32,  # assumes network output is float32
+                )
+            },
+            spawn_subprocess=False,
+            device=str(device),
+        )

-    # get arrays
-    input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
-    raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)
-
-    output_array_identifier = LocalArrayIdentifier(
-        Path(output_container), output_dataset
-    )
-    output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
-
-    # set benchmark flag to True for performance
-    torch.backends.cudnn.benchmark = True
-
-    # get the model's input and output size
-    model = run.model.eval().to(device)
-    input_voxel_size = Coordinate(raw_array.voxel_size)
-    output_voxel_size = model.scale(input_voxel_size)
-    input_shape = Coordinate(model.eval_input_shape)
-    input_size = input_voxel_size * input_shape
-    output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]
-
-    print(f"Predicting with input size {input_size}, output size {output_size}")
-
-    # create gunpowder keys
-
-    raw = gp.ArrayKey("RAW")
-    prediction = gp.ArrayKey("PREDICTION")
-
-    # assemble prediction pipeline
-
-    # prepare data source
-    pipeline = DaCapoArraySource(raw_array, raw)
-    # raw: (c, d, h, w)
-    pipeline += gp.Pad(raw, None)
-    # raw: (c, d, h, w)
-    pipeline += gp.Unsqueeze([raw])
-    # raw: (1, c, d, h, w)
-
-    pipeline += gp.Normalize(raw)
-
-    # predict
-    # model.eval()
-    pipeline += gp_torch.Predict(
-        model=model,
-        inputs={"x": raw},
-        outputs={0: prediction},
-        array_specs={
-            prediction: gp.ArraySpec(
-                voxel_size=output_voxel_size,
-                dtype=np.float32,  # assumes network output is float32
-            )
-        },
-        spawn_subprocess=False,
-        device=str(device),
-    )
-
-    # make reference batch request
-    request = gp.BatchRequest()
-    request.add(raw, input_size, voxel_size=input_voxel_size)
-    request.add(
-        prediction,
-        output_size,
-        voxel_size=output_voxel_size,
-    )
+        # make reference batch request
+        request = gp.BatchRequest()
+        request.add(raw, input_size, voxel_size=input_voxel_size)
+        request.add(
+            prediction,
+            output_size,
+            voxel_size=output_voxel_size,
+        )

-    def io_loop():
-        daisy_client = daisy.Client()
     while True:
         with daisy_client.acquire_block() as block:
             if block is None:
@@ -231,7 +242,7 @@ def io_loop():


 def spawn_worker(
-    run_name: str | Run,
+    run_name: str,
     iteration: int | None,
     input_array_identifier: "LocalArrayIdentifier",
     output_array_identifier: "LocalArrayIdentifier",
@@ -248,6 +259,8 @@ def spawn_worker(
         Callable: The function to run the worker.
""" compute_context = create_compute_context() + + if not compute_context.distribute_workers: return start_worker_fn( run_name=run_name, diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index f61bf0cd4..30c6ac693 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -369,11 +369,17 @@ def data(self) -> Any: """ file_name = str(self.file_name) # Zarr library does not detect the store for N5 datasets - if file_name.endswith(".n5"): - zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode) - else: - zarr_container = zarr.open(str(file_name), mode=self.mode) - return zarr_container[self.dataset] + try: + if file_name.endswith(".n5"): + zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode) + else: + zarr_container = zarr.open(str(file_name), mode=self.mode) + return zarr_container[self.dataset] + except Exception as e: + logger.error( + f"Could not open dataset {self.dataset} in file {file_name} in mode {self.mode}" + ) + raise e def __getitem__(self, roi: Roi) -> np.ndarray: """ diff --git a/dacapo/predict.py b/dacapo/predict.py index 09ac848cc..f79f26a1e 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,5 +1,5 @@ from upath import UPath as Path - +from dacapo.blockwise import global_vars from dacapo.blockwise import run_blockwise import dacapo.blockwise from dacapo.experiments import Run @@ -24,7 +24,7 @@ def predict( input_dataset: str, output_path: LocalArrayIdentifier | Path | str, output_roi: Optional[Roi | str] = None, - num_workers: int = 12, + num_workers: int = 1, output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): @@ -136,10 +136,13 @@ def predict( write_size=output_size, ) + global_vars.current_run = run + + # run blockwise prediction worker_file = str(Path(Path(dacapo.blockwise.__file__).parent, "predict_worker.py")) print("Running blockwise prediction with worker_file: ", worker_file) - run_blockwise( + success = run_blockwise( worker_file=worker_file, total_roi=_input_roi, read_roi=Roi((0, 0, 0), input_size), @@ -148,9 +151,10 @@ def predict( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - run_name=run, + run_name=run.name, iteration=iteration, input_array_identifier=input_array_identifier, output_array_identifier=output_array_identifier, ) print("Done predicting.") + return success diff --git a/dacapo/validate.py b/dacapo/validate.py index 398308df2..0da9dfa30 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -186,7 +186,7 @@ def validate( run.name, iteration, validation_dataset.name ) compute_context = create_compute_context() - predict( + sucess = predict( run, iteration if compute_context.distribute_workers else None, input_container=input_raw_array_identifier.container, @@ -198,6 +198,12 @@ def validate( overwrite=overwrite, ) + if not sucess: + logger.error( + f"Could not predict run {run.name} on dataset {validation_dataset.name}." 
+        )
+        continue
+
     print(f"Predicted on dataset {validation_dataset.name}")

     post_processor.set_prediction(prediction_array_identifier)

From 9ea893c2bf896c5794bff48af466b01b4f00f666 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 29 Jul 2024 13:17:58 -0400
Subject: [PATCH 03/14] black format

---
 dacapo/blockwise/global_vars.py | 1 -
 dacapo/blockwise/predict_worker.py | 12 ++++++++----
 dacapo/experiments/datasplits/datasplit_generator.py | 2 +-
 dacapo/predict.py | 1 -
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
index 4d3771721..0c804e3ff 100644
--- a/dacapo/blockwise/global_vars.py
+++ b/dacapo/blockwise/global_vars.py
@@ -1,2 +1 @@
 current_run = None
-
diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index 787787034..867c9554b 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -33,9 +33,12 @@ def is_global_run_set(run_name) -> bool:
     if found:
         found = global_vars.current_run.name == run_name
         if not found:
-            logger.error(f"Found global run {global_vars.current_run.name} but looking for {run_name}")
+            logger.error(
+                f"Found global run {global_vars.current_run.name} but looking for {run_name}"
+            )
     return found

+
 @click.group()
 @click.option(
     "--log-level",
@@ -118,6 +121,7 @@ def start_worker_fn(
         output_container (Path | str): The output container.
         output_dataset (str): The output dataset.
     """
+
     def io_loop():
         daisy_client = daisy.Client()

@@ -143,7 +147,9 @@ def io_loop():
         )

         # get arrays
-        input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
+        input_array_identifier = LocalArrayIdentifier(
+            Path(input_container), input_dataset
+        )
         raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

         output_array_identifier = LocalArrayIdentifier(
@@ -207,7 +213,6 @@ def io_loop():
             voxel_size=output_voxel_size,
         )

-
     while True:
         with daisy_client.acquire_block() as block:
             if block is None:
@@ -260,7 +265,6 @@ def spawn_worker(
     """
     compute_context = create_compute_context()

-
     if not compute_context.distribute_workers:
         return start_worker_fn(
             run_name=run_name,
diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py
index ce229deee..d3a6cb7d6 100644
--- a/dacapo/experiments/datasplits/datasplit_generator.py
+++ b/dacapo/experiments/datasplits/datasplit_generator.py
@@ -757,7 +757,7 @@ def __generate_semantic_seg_datasplit(self):
                 mask_config=mask_config,
             )
         )
-    
+
     return TrainValidateDataSplitConfig(
         name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
         train_configs=train_dataset_configs,
diff --git a/dacapo/predict.py b/dacapo/predict.py
index f79f26a1e..f28e97663 100644
--- a/dacapo/predict.py
+++ b/dacapo/predict.py
@@ -138,7 +138,6 @@ def predict(

     global_vars.current_run = run

-
     # run blockwise prediction
     worker_file = str(Path(Path(dacapo.blockwise.__file__).parent, "predict_worker.py"))
     print("Running blockwise prediction with worker_file: ", worker_file)

From 0cf9e2557a6831c7b2dd23580bf998b9e35cd787 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Tue, 30 Jul 2024 11:37:02 -0400
Subject: [PATCH 04/14] fix plot hack

---
 .../threshold_post_processor.py | 7 ++-
 dacapo/plot.py | 49 +++++++++---------- 
 2 files changed, 29 insertions(+), 27 deletions(-)

diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
index c0e10418c..f99c64d3a 100644
--- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
@@ -68,7 +68,7 @@ def process(
         self,
         parameters: "ThresholdPostProcessorParameters",  # type: ignore[override]
         output_array_identifier: "LocalArrayIdentifier",
-        num_workers: int = 16,
+        num_workers: int = 12,
         block_size: Coordinate = Coordinate((256, 256, 256)),
     ) -> ZarrArray:
         """
@@ -122,7 +122,7 @@ def process(
         read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])

         # run blockwise post-processing
-        run_blockwise(
+        success = run_blockwise(
             worker_file=str(
                 Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py")
             ),
@@ -138,4 +138,7 @@ def process(
             threshold=parameters.threshold,
         )

+        if not success:
+            raise RuntimeError("Blockwise post-processing failed.")
+
         return output_array
diff --git a/dacapo/plot.py b/dacapo/plot.py
index e86f697b3..3b12a52d8 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -3,9 +3,8 @@
 from dacapo.store.create_store import create_config_store, create_stats_store
 from dacapo.experiments.run import Run

-from bokeh.palettes import Category20 as palette
-import bokeh.layouts
-import bokeh.plotting
+from dacapo.plotting.plot_handler import PlotHandler, RunInfo
+from bokeh.plotting.matplot_plot_handler import MatplotPlotHandler
 import numpy as np

 from collections import namedtuple
@@ -104,7 +103,7 @@ def get_runs_info(
             run_config.trainer_config.name,
             run_config.datasplit_config.name,
             (
-                stats_store.retrieve_training_stats(run_config_name, subsample=True)
+                stats_store.retrieve_training_stats(run_config_name)
                 if plot_loss
                 else None
             ),
@@ -159,7 +158,7 @@ def plot_runs(
         tools="pan, wheel_zoom, reset, save, hover",
         x_axis_label="iterations",
         tooltips=loss_tooltips,
-        plot_width=2048,
+        # plot_width=2048,
     )
     loss_figure.background_fill_color = "#efefef"

@@ -202,7 +201,7 @@ def plot_runs(
         tools="pan, wheel_zoom, reset, save, hover",
         x_axis_label="iterations",
         tooltips=validation_tooltips,
-        plot_width=2048,
+        # plot_width=2048,
     )
     validation_figure.background_fill_color = "#efefef"
     validation_figures[dataset.name] = validation_figure
@@ -226,7 +225,7 @@ def plot_runs(
         x_axis_label="model size",
         y_axis_label="best validation",
         tooltips=summary_tooltips,
-        plot_width=2048,
+        # plot_width=2048,
     )
     summary_figure.background_fill_color = "#efefef"

@@ -297,24 +296,24 @@ def plot_runs(
                     "run": [run.name] * len(x),
                 }
                 # TODO: get_best: higher_is_better is not true for all scores
-                best_parameters, best_scores = run.validation_scores.get_best(
-                    dataset_data, dim="parameters"
-                )
-
-                source_dict.update(
-                    {
-                        name: np.array(
-                            [
-                                getattr(best_parameter, name)
-                                for best_parameter in best_parameters.values
-                            ]
-                        )
-                        for name in run.validation_scores.parameter_names
-                    }
-                )
-                source_dict.update(
-                    {run.validation_score_name: np.array(best_scores.values)}
-                )
+                # best_parameters, best_scores = run.validation_scores.get_best(
+                #     dataset_data, dim="parameters"
+                # )
+
+                # source_dict.update(
+                #     {
+                #         name: np.array(
+                #             [
+                #                 getattr(best_parameter, name)
+                #                 for best_parameter in best_parameters.values
+                #             ]
+                #         )
+                #         for name in run.validation_scores.parameter_names
+                #     }
+                # )
+                # source_dict.update(
+                #     {run.validation_score_name: np.array(best_scores.values)}
+                # )

                 source = bokeh.plotting.ColumnDataSource(source_dict)
                 validation_figures[dataset.name].line(

From 276d48138f60878bc6eac4c0be61ead204ce9d6d Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Tue, 30 Jul 2024 11:39:37 -0400
Subject: [PATCH 05/14] fix import

---
 dacapo/plot.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/dacapo/plot.py b/dacapo/plot.py
index 3b12a52d8..9829dfd60 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -3,8 +3,9 @@
 from dacapo.store.create_store import create_config_store, create_stats_store
 from dacapo.experiments.run import Run

-from dacapo.plotting.plot_handler import PlotHandler, RunInfo
-from bokeh.plotting.matplot_plot_handler import MatplotPlotHandler
+from bokeh.palettes import Category20 as palette
+import bokeh.layouts
+import bokeh.plotting
 import numpy as np

 from collections import namedtuple

From 24c963aec25b80eb9704a7a07e9dd4f8fa145dbe Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Wed, 31 Jul 2024 16:41:57 -0400
Subject: [PATCH 06/14] matplotlib plot

---
 dacapo/plot.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 89 insertions(+), 1 deletion(-)

diff --git a/dacapo/plot.py b/dacapo/plot.py
index 9829dfd60..8d9d3c2dd 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -7,11 +7,18 @@
 import bokeh.layouts
 import bokeh.plotting
 import numpy as np
+from tqdm import tqdm

 from collections import namedtuple
 import itertools

 from typing import List

+import matplotlib.pyplot as plt
+
+import os
+
+
+
 RunInfo = namedtuple(
     "RunInfo",
     [
@@ -117,7 +124,7 @@ def get_runs_info(
     return runs


-def plot_runs(
+def bokeh_plot_runs(
     run_config_base_names,
     smooth=100,
     validation_scores=None,
@@ -384,3 +391,84 @@ def plot_runs(
     else:
         bokeh.plotting.output_file("performance_plots.html")
         bokeh.plotting.save(plot)
+
+
+def plot_runs(
+    run_config_base_names,
+    smooth=100,
+    validation_scores=None,
+    higher_is_betters=None,
+    plot_losses=None,
+):
+    """
+    Plot runs.
+
+    Args:
+        run_config_base_names: Names of run configs to plot
+        smooth: Smoothing factor
+        validation_scores: Validation scores to plot
+        higher_is_betters: Whether higher is better
+        plot_losses: Whether to plot losses
+    Returns:
+        None
+    """
+    print("PLOTTING RUNS")
+    runs = get_runs_info(run_config_base_names, validation_scores, plot_losses)
+    print("GOT RUNS INFO")
+
+    colors = itertools.cycle(plt.cm.tab20.colors)
+    include_validation_figure = False
+    include_loss_figure = False
+
+    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))
+    loss_ax = axes[0]
+    validation_ax = axes[1]
+
+    for run, color in zip(runs, colors):
+        name = run.name
+
+        if run.plot_loss:
+            iterations = [stat.iteration for stat in run.training_stats.iteration_stats]
+            losses = [stat.loss for stat in run.training_stats.iteration_stats]
+
+            print(f"Run {run.name} has {len(losses)} iterations")
+
+            if run.plot_loss:
+                include_loss_figure = True
+                smooth = int(np.maximum(len(iterations) / 2500, 1))
+                print(f"smoothing: {smooth}")
+                x, _ = smooth_values(iterations, smooth, stride=smooth)
+                y, s = smooth_values(losses, smooth, stride=smooth)
+                print(x, y)
+                print(f"plotting {(len(x), len(y))} points")
+                loss_ax.plot(x, y, label=name, color=color)
+                print("LOSS PLOTTED")
+
+        if run.validation_score_name and run.validation_scores.validated_until() > 0:
+            validation_score_data = run.validation_scores.to_xarray().sel(
+                criteria=run.validation_score_name
+            )
+            colors_val = itertools.cycle(plt.cm.tab20.colors)
+            for dataset,color_v in zip(run.validation_scores.datasets,colors_val):
+                dataset_data = validation_score_data.sel(datasets=dataset)
+                include_validation_figure = True
+                x = [score.iteration for score in run.validation_scores.scores]
+                cc = next(colors_val)
+                for i in range(dataset_data.data.shape[1]):
+                    current_name = f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
+                    validation_ax.plot(x, dataset_data.data[:,i] , label=current_name, color=cc, alpha=0.5+0.2*i)
+            print("VALIDATION PLOTTED")
+
+    if include_loss_figure:
+        loss_ax.set_title("Training")
+        loss_ax.set_xlabel("Iterations")
+        loss_ax.set_ylabel("Loss")
+        loss_ax.legend()
+
+    if include_validation_figure:
+        validation_ax.set_title("Validation")
+        validation_ax.set_xlabel("Iterations")
+        validation_ax.set_ylabel("Validation Score")
+        validation_ax.legend()
+
+    plt.tight_layout()
+    plt.show()
\ No newline at end of file

From 7f534cfe22debd002185b97acbff49260ffd0931 Mon Sep 17 00:00:00 2001
From: mzouink
Date: Wed, 31 Jul 2024 20:42:31 +0000
Subject: [PATCH 07/14] :art: Format Python code with psf/black

---
 dacapo/plot.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/dacapo/plot.py b/dacapo/plot.py
index 8d9d3c2dd..d5bfe1d28 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -17,8 +17,7 @@
 import os

-
-
+
 RunInfo = namedtuple(
     "RunInfo",
     [
@@ -448,14 +447,22 @@ def plot_runs(
                 criteria=run.validation_score_name
             )
             colors_val = itertools.cycle(plt.cm.tab20.colors)
-            for dataset,color_v in zip(run.validation_scores.datasets,colors_val):
+            for dataset, color_v in zip(run.validation_scores.datasets, colors_val):
                 dataset_data = validation_score_data.sel(datasets=dataset)
                 include_validation_figure = True
                 x = [score.iteration for score in run.validation_scores.scores]
                 cc = next(colors_val)
                 for i in range(dataset_data.data.shape[1]):
-                    current_name = f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
+                    current_name = (
+                        f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
+                    )
+                    validation_ax.plot(
+                        x,
+                        dataset_data.data[:, i],
+                        label=current_name,
+                        color=cc,
+                        alpha=0.5 + 0.2 * i,
+                    )
             print("VALIDATION PLOTTED")

From a64c80fceb5f1f87ccb87f872368a25a6f45399a Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 5 Aug 2024 13:23:40 -0400
Subject: [PATCH 08/14] black format

---
 dacapo/experiments/datasplits/datasplit_generator.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py
index 43f2f3606..d3a6cb7d6 100644
--- a/dacapo/experiments/datasplits/datasplit_generator.py
+++ b/dacapo/experiments/datasplits/datasplit_generator.py
@@ -758,7 +758,6 @@ def __generate_semantic_seg_datasplit(self):
             )
         )

-
     return TrainValidateDataSplitConfig(
         name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
         train_configs=train_dataset_configs,

From e2e9da7e0eeea8ef936f315a718ce915cf0d2dc4 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 5 Aug 2024 14:22:15 -0400
Subject: [PATCH 09/14] mypy fix

---
 dacapo/blockwise/global_vars.py | 4 +++-
 dacapo/blockwise/predict_worker.py | 13 ++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
index 0c804e3ff..170f13d68 100644
--- a/dacapo/blockwise/global_vars.py
+++ b/dacapo/blockwise/global_vars.py
@@ -1 +1,3 @@
-current_run = None
+from dacapo.experiments import Run
+from typing import Optional
+current_run: Optional[Run] = None
diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index 867c9554b..4485cf80a 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -29,14 +29,17 @@


 def is_global_run_set(run_name) -> bool:
-    found = global_vars.current_run is not None
-    if found:
-        found = global_vars.current_run.name == run_name
-        if not found:
+    if global_vars.current_run is not None:
+        if global_vars.current_run.name == run_name:
+            return True
+        else:
             logger.error(
                 f"Found global run {global_vars.current_run.name} but looking for {run_name}"
             )
-    return found
+            return False
+    else:
+        logger.error("No global run is set.")
+        return False


 @click.group()

From 4919c91d64d56e69652104e09ae96b9764e3be68 Mon Sep 17 00:00:00 2001
From: mzouink
Date: Mon, 5 Aug 2024 18:22:47 +0000
Subject: [PATCH 10/14] :art: Format Python code with psf/black

---
 dacapo/blockwise/global_vars.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
index 170f13d68..1d5410daa 100644
--- a/dacapo/blockwise/global_vars.py
+++ b/dacapo/blockwise/global_vars.py
@@ -1,3 +1,4 @@
 from dacapo.experiments import Run
 from typing import Optional
+
 current_run: Optional[Run] = None

From 8be4941336bb253027a0d9344b3471b4f4f7022d Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 5 Aug 2024 15:57:58 -0400
Subject: [PATCH 11/14] fix circular import

---
 dacapo/blockwise/global_vars.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
index 170f13d68..0c804e3ff 100644
--- a/dacapo/blockwise/global_vars.py
+++ b/dacapo/blockwise/global_vars.py
@@ -1,3 +1 @@
-from dacapo.experiments import Run
-from typing import Optional
-current_run: Optional[Run] = None
+current_run = None

From cab116974e40e341a3b4eca55c8b3be02b2dddf1 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 5 Aug 2024 16:20:03 -0400
Subject: [PATCH 12/14] fix validate

---
 tests/operations/test_validate.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py
index 0ebdd5e03..c8f65a05b 100644
--- a/tests/operations/test_validate.py
+++ b/tests/operations/test_validate.py
@@ -56,13 +56,13 @@ def test_validate(

     # test validating iterations for which we know there are weights
     weights_store.store_weights(run, 0)
-    validate(run_config.name, 0, num_workers=4)
+    validate(run.name, 0, num_workers=4)
     # weights_store.store_weights(run, 1)
     # validate(run_config.name, 1, num_workers=4)

     # test validating weights that don't exist
     with pytest.raises(FileNotFoundError):
-        validate(run_config.name, 2, num_workers=4)
+        validate(run.name, 2, num_workers=4)

     if debug:
         os.chdir(old_path)

From 5e59778968b8a7bec491950ad414bb1686f52853 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 5 Aug 2024 16:45:47 -0400
Subject: [PATCH 13/14] fix tests

---
 tests/operations/test_validate.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py
index c8f65a05b..1fc2a6e8b 100644
--- a/tests/operations/test_validate.py
+++ b/tests/operations/test_validate.py
@@ -5,7 +5,7 @@

 from dacapo.experiments import Run
 from dacapo.store.create_store import create_config_store, create_weights_store
-from dacapo import validate
+from dacapo import validate_run

 import pytest
 from pytest_lazy_fixtures import lf
@@ -56,13 +56,13 @@ def test_validate(

     # test validating iterations for which we know there are weights
     weights_store.store_weights(run, 0)
-    validate(run.name, 0, num_workers=4)
+    validate_run(run_config.name, 0, num_workers=4)
     # weights_store.store_weights(run, 1)
-    # validate(run_config.name, 1, num_workers=4)
+    # validate_run(run_config.name, 1, num_workers=4)

     # test validating weights that don't exist
     with pytest.raises(FileNotFoundError):
-        validate(run.name, 2, num_workers=4)
+        validate_run(run_config.name, 2, num_workers=4)

     if debug:
         os.chdir(old_path)

From 5a6440b1438942c961a085e740599cfc991eb847 Mon Sep 17 00:00:00 2001
From: mzouink
Date: Mon, 5 Aug 2024 20:46:20 +0000
Subject: [PATCH 14/14] :art: Format Python code with psf/black

---
 dacapo/blockwise/global_vars.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
index 054b0ce37..0c804e3ff 100644
--- a/dacapo/blockwise/global_vars.py
+++ b/dacapo/blockwise/global_vars.py
@@ -1,2 +1 @@
-
 current_run = None