From 054367d47a6a9465aafe795abf347379342d0369 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 15 Apr 2024 17:01:59 -0400 Subject: [PATCH] all --- .../datasplits/datasets/arrays/zarr_array.py | 10 +++++-- .../datasplits/datasets/raw_gt_dataset.py | 17 +++++++----- .../datasplits/datasplit_generator.py | 14 ++++++---- dacapo/experiments/run.py | 26 ++++++++++++++++--- .../threshold_post_processor.py | 2 +- dacapo/validate.py | 4 ++- pyproject.toml | 1 + tests/components/test_arrays.py | 8 +++--- tests/components/test_gp_arraysource.py | 8 +++--- tests/components/test_trainers.py | 6 ++--- tests/fixtures/architectures.py | 2 +- tests/operations/test_apply.py | 8 +++--- tests/operations/test_predict.py | 8 +++--- tests/operations/test_train.py | 8 +++--- tests/operations/test_validate.py | 6 ++--- 15 files changed, 81 insertions(+), 47 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 4fcdc2d27..de9623cc3 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -9,10 +9,11 @@ import lazy_property import numpy as np import zarr +from zarr.n5 import N5FSStore from collections import OrderedDict import logging -from pathlib import Path +from upath import UPath as Path import json from typing import Dict, Tuple, Any, Optional, List @@ -358,7 +359,12 @@ def data(self) -> Any: Notes: This method is used to return the data of the array. """ - zarr_container = zarr.open(str(self.file_name)) + file_name = str(self.file_name) + # Zarr library does not detect the store for N5 datasets + if file_name.endswith(".n5"): + zarr_container = zarr.open(N5FSStore(str(file_name)), mode="r") + else: + zarr_container = zarr.open(str(file_name), mode="r") return zarr_container[self.dataset] def __getitem__(self, roi: Roi) -> np.ndarray: diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py index 8539e8339..6174b2aaf 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py @@ -49,12 +49,15 @@ def __init__(self, dataset_config): This method is used to initialize the dataset. """ self.name = dataset_config.name - self.raw = dataset_config.raw_config.array_type(dataset_config.raw_config) - self.gt = dataset_config.gt_config.array_type(dataset_config.gt_config) - self.mask = ( - dataset_config.mask_config.array_type(dataset_config.mask_config) - if dataset_config.mask_config is not None - else None - ) + try: + self.raw = dataset_config.raw_config.array_type(dataset_config.raw_config) + self.gt = dataset_config.gt_config.array_type(dataset_config.gt_config) + self.mask = ( + dataset_config.mask_config.array_type(dataset_config.mask_config) + if dataset_config.mask_config is not None + else None + ) + except Exception as e: + raise Exception(f"Error loading arrays for dataset {self.name}: {e} \n {dataset_config}") self.sample_points = dataset_config.sample_points self.weight = dataset_config.weight diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index fac37ed7e..5424fd991 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -1,11 +1,12 @@ from dacapo.experiments.tasks import TaskConfig -from pathlib import Path +from upath import UPath as Path from typing import List from enum import Enum, EnumMeta from funlib.geometry import Coordinate from typing import Union, Optional import zarr +from zarr.n5 import N5FSStore from dacapo.experiments.datasplits.datasets.arrays import ( ZarrArrayConfig, ZarrArray, @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) -def is_zarr_group(file_name: str, dataset: str): +def is_zarr_group(file_name: Path, dataset: str): """ Check if the dataset is a Zarr group. If the dataset is a Zarr group, it will return True, otherwise False. @@ -40,7 +41,10 @@ def is_zarr_group(file_name: str, dataset: str): Notes: This function is used to check if the dataset is a Zarr group. """ - zarr_file = zarr.open(str(file_name)) + if file_name.suffix == ".n5": + zarr_file = zarr.open(N5FSStore(str(file_name)), mode="r") + else: + zarr_file = zarr.open(str(file_name), mode="r") return isinstance(zarr_file[dataset], zarr.hierarchy.Group) @@ -762,7 +766,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): # f"Processing raw_container:{raw_container} raw_dataset:{raw_dataset} gt_path:{gt_path} gt_dataset:{gt_dataset}" # ) - if is_zarr_group(str(raw_container), raw_dataset): + if is_zarr_group(raw_container, raw_dataset): raw_config = get_right_resolution_array_config( raw_container, raw_dataset, self.input_resolution, "raw" ) @@ -789,7 +793,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): raise FileNotFoundError( f"GT path {gt_path/current_class_dataset} does not exist." ) - if is_zarr_group(str(gt_path), current_class_dataset): + if is_zarr_group(gt_path, current_class_dataset): gt_config = get_right_resolution_array_config( gt_path, current_class_dataset, self.output_resolution, "gt" ) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index a0089e712..e88ac4200 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -93,6 +93,7 @@ def __init__(self, run_config, load_starter_model: bool = True): """ self.name = run_config.name + self._config = run_config self.train_until = run_config.num_iterations self.validation_interval = run_config.validation_interval @@ -106,7 +107,10 @@ def __init__(self, run_config, load_starter_model: bool = True): self.task = task_type(run_config.task_config) self.architecture = architecture_type(run_config.architecture_config) self.trainer = trainer_type(run_config.trainer_config) - self.datasplit = datasplit_type(run_config.datasplit_config) + + # lazy load datasplit + self._datasplit = None + # combined pieces self.model = self.task.create_model(self.architecture) @@ -114,9 +118,7 @@ def __init__(self, run_config, load_starter_model: bool = True): # tracking self.training_stats = TrainingStats() - self.validation_scores = ValidationScores( - self.task.parameters, self.datasplit.validate, self.task.evaluation_scores - ) + self._validation_scores = None if not load_starter_model: self.start = None @@ -142,6 +144,22 @@ def __init__(self, run_config, load_starter_model: bool = True): self.start.initialize_weights(self.model, new_head=new_head) + @property + def datasplit(self): + if self._datasplit is None: + self._datasplit = self._config.datasplit_config.datasplit_type( + self._config.datasplit_config + ) + return self._datasplit + + @property + def validation_scores(self): + if self._validation_scores is None: + self._validation_scores = ValidationScores( + self.task.parameters, self.datasplit.validate, self.task.evaluation_scores + ) + return self._validation_scores + @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index a71d409e8..5f67ece9c 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -43,7 +43,7 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: Note: This method should return a generator of instances of ``ThresholdPostProcessorParameters``. """ - for i, threshold in enumerate([100, 127, 150]): + for i, threshold in enumerate([127]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) def set_prediction(self, prediction_array_identifier): diff --git a/dacapo/validate.py b/dacapo/validate.py index 79da393e2..819c20724 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -46,6 +46,7 @@ def validate( config_store = create_config_store() run_config = config_store.retrieve_run_config(run_name) + logger.warning(f"run_config: {run_config}") run = Run(run_config) # read in previous training/validation stats @@ -97,7 +98,8 @@ def validate_run( or len(run.datasplit.validate) == 0 or run.datasplit.validate[0].gt is None ): - raise ValueError(f"Cannot validate run {run.name} at iteration {iteration}.") + run.datasplit.validate = run.datasplit.train[-2:] + # raise ValueError(f"Cannot validate run {run.name} at iteration {iteration} run.datasplit: {run.datasplit.validate}") # get array and weight store array_store = create_array_store() diff --git a/pyproject.toml b/pyproject.toml index fd5f31638..f9acee29c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "click", "pyyaml", "scipy", + "universal_pathlib", ] # extras diff --git a/tests/components/test_arrays.py b/tests/components/test_arrays.py index d62dcb973..d91863ad7 100644 --- a/tests/components/test_arrays.py +++ b/tests/components/test_arrays.py @@ -3,15 +3,15 @@ from dacapo.store.create_store import create_config_store import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf @pytest.mark.parametrize( "array_config", [ - lazy_fixture("cellmap_array"), - lazy_fixture("zarr_array"), - lazy_fixture("dummy_array"), + lf("cellmap_array"), + lf("zarr_array"), + lf("dummy_array"), ], ) def test_array_api(options, array_config): diff --git a/tests/components/test_gp_arraysource.py b/tests/components/test_gp_arraysource.py index 7ff626036..69fee515f 100644 --- a/tests/components/test_gp_arraysource.py +++ b/tests/components/test_gp_arraysource.py @@ -5,15 +5,15 @@ import gunpowder as gp import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf @pytest.mark.parametrize( "array_config", [ - lazy_fixture("cellmap_array"), - lazy_fixture("zarr_array"), - lazy_fixture("dummy_array"), + lf("cellmap_array"), + lf("zarr_array"), + lf("dummy_array"), ], ) def test_gp_dacapo_array_source(array_config): diff --git a/tests/components/test_trainers.py b/tests/components/test_trainers.py index 172a89b75..f3f9b07ac 100644 --- a/tests/components/test_trainers.py +++ b/tests/components/test_trainers.py @@ -3,14 +3,14 @@ from dacapo.store.create_store import create_config_store import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf @pytest.mark.parametrize( "trainer_config", [ - lazy_fixture("dummy_trainer"), - lazy_fixture("gunpowder_trainer"), + lf("dummy_trainer"), + lf("gunpowder_trainer"), ], ) def test_trainer( diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py index 6980c8f6b..ea6c55481 100644 --- a/tests/fixtures/architectures.py +++ b/tests/fixtures/architectures.py @@ -7,4 +7,4 @@ def dummy_architecture(): yield DummyArchitectureConfig( name="dummy_architecture", num_in_channels=1, num_out_channels=12 - ) + ) \ No newline at end of file diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 5ce608e1e..f1e3693fe 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -8,7 +8,7 @@ from dacapo import apply import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf import logging @@ -18,9 +18,9 @@ @pytest.mark.parametrize( "run_config", [ - # lazy_fixture("distance_run"), - lazy_fixture("dummy_run"), - # lazy_fixture("onehot_run"), + lf("distance_run"), + lf("dummy_run"), + lf("onehot_run"), ], ) def test_apply(options, run_config, zarr_array, tmp_path): diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py index cd8f6a6c1..4a4b53478 100644 --- a/tests/operations/test_predict.py +++ b/tests/operations/test_predict.py @@ -8,7 +8,7 @@ from dacapo import predict import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf import logging @@ -18,9 +18,9 @@ @pytest.mark.parametrize( "run_config", [ - # lazy_fixture("distance_run"), - lazy_fixture("dummy_run"), - # lazy_fixture("onehot_run"), + # lf("distance_run"), + lf("dummy_run"), + # lf("onehot_run"), ], ) def test_predict(options, run_config, zarr_array, tmp_path): diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index d36655ea7..a852101be 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -7,7 +7,7 @@ from dacapo.train import train_run import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf import logging @@ -20,9 +20,9 @@ @pytest.mark.parametrize( "run_config", [ - lazy_fixture("distance_run"), - lazy_fixture("dummy_run"), - lazy_fixture("onehot_run"), + lf("distance_run"), + lf("dummy_run"), + lf("onehot_run"), ], ) def test_train( diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index fa2cc6b9a..0ebdd5e03 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -8,7 +8,7 @@ from dacapo import validate import pytest -from pytest_lazyfixture import lazy_fixture +from pytest_lazy_fixtures import lf import logging @@ -18,8 +18,8 @@ @pytest.mark.parametrize( "run_config", [ - lazy_fixture("distance_run"), - # lazy_fixture("onehot_run"), + lf("distance_run"), + lf("onehot_run"), ], ) def test_validate(