Skip to content

Commit

Permalink
all
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Apr 15, 2024
1 parent 3d3cfdf commit 054367d
Show file tree
Hide file tree
Showing 15 changed files with 81 additions and 47 deletions.
10 changes: 8 additions & 2 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import lazy_property
import numpy as np
import zarr
from zarr.n5 import N5FSStore

from collections import OrderedDict
import logging
from pathlib import Path
from upath import UPath as Path
import json
from typing import Dict, Tuple, Any, Optional, List

Expand Down Expand Up @@ -358,7 +359,12 @@ def data(self) -> Any:
Notes:
This method is used to return the data of the array.
"""
zarr_container = zarr.open(str(self.file_name))
file_name = str(self.file_name)
# Zarr library does not detect the store for N5 datasets
if file_name.endswith(".n5"):
zarr_container = zarr.open(N5FSStore(str(file_name)), mode="r")
else:
zarr_container = zarr.open(str(file_name), mode="r")
return zarr_container[self.dataset]

def __getitem__(self, roi: Roi) -> np.ndarray:
Expand Down
17 changes: 10 additions & 7 deletions dacapo/experiments/datasplits/datasets/raw_gt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ def __init__(self, dataset_config):
This method is used to initialize the dataset.
"""
self.name = dataset_config.name
self.raw = dataset_config.raw_config.array_type(dataset_config.raw_config)
self.gt = dataset_config.gt_config.array_type(dataset_config.gt_config)
self.mask = (
dataset_config.mask_config.array_type(dataset_config.mask_config)
if dataset_config.mask_config is not None
else None
)
try:
self.raw = dataset_config.raw_config.array_type(dataset_config.raw_config)
self.gt = dataset_config.gt_config.array_type(dataset_config.gt_config)
self.mask = (
dataset_config.mask_config.array_type(dataset_config.mask_config)
if dataset_config.mask_config is not None
else None
)
except Exception as e:
raise Exception(f"Error loading arrays for dataset {self.name}: {e} \n {dataset_config}")
self.sample_points = dataset_config.sample_points
self.weight = dataset_config.weight
14 changes: 9 additions & 5 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dacapo.experiments.tasks import TaskConfig
from pathlib import Path
from upath import UPath as Path
from typing import List
from enum import Enum, EnumMeta
from funlib.geometry import Coordinate
from typing import Union, Optional

import zarr
from zarr.n5 import N5FSStore
from dacapo.experiments.datasplits.datasets.arrays import (
ZarrArrayConfig,
ZarrArray,
Expand All @@ -21,7 +22,7 @@
logger = logging.getLogger(__name__)


def is_zarr_group(file_name: str, dataset: str):
def is_zarr_group(file_name: Path, dataset: str):
"""
Check if the dataset is a Zarr group. If the dataset is a Zarr group, it will return True, otherwise False.
Expand All @@ -40,7 +41,10 @@ def is_zarr_group(file_name: str, dataset: str):
Notes:
This function is used to check if the dataset is a Zarr group.
"""
zarr_file = zarr.open(str(file_name))
if file_name.suffix == ".n5":
zarr_file = zarr.open(N5FSStore(str(file_name)), mode="r")
else:
zarr_file = zarr.open(str(file_name), mode="r")
return isinstance(zarr_file[dataset], zarr.hierarchy.Group)


Expand Down Expand Up @@ -762,7 +766,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
# f"Processing raw_container:{raw_container} raw_dataset:{raw_dataset} gt_path:{gt_path} gt_dataset:{gt_dataset}"
# )

if is_zarr_group(str(raw_container), raw_dataset):
if is_zarr_group(raw_container, raw_dataset):
raw_config = get_right_resolution_array_config(
raw_container, raw_dataset, self.input_resolution, "raw"
)
Expand All @@ -789,7 +793,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
raise FileNotFoundError(
f"GT path {gt_path/current_class_dataset} does not exist."
)
if is_zarr_group(str(gt_path), current_class_dataset):
if is_zarr_group(gt_path, current_class_dataset):
gt_config = get_right_resolution_array_config(
gt_path, current_class_dataset, self.output_resolution, "gt"
)
Expand Down
26 changes: 22 additions & 4 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(self, run_config, load_starter_model: bool = True):
"""
self.name = run_config.name
self._config = run_config
self.train_until = run_config.num_iterations
self.validation_interval = run_config.validation_interval

Expand All @@ -106,17 +107,18 @@ def __init__(self, run_config, load_starter_model: bool = True):
self.task = task_type(run_config.task_config)
self.architecture = architecture_type(run_config.architecture_config)
self.trainer = trainer_type(run_config.trainer_config)
self.datasplit = datasplit_type(run_config.datasplit_config)

# lazy load datasplit
self._datasplit = None


# combined pieces
self.model = self.task.create_model(self.architecture)
self.optimizer = self.trainer.create_optimizer(self.model)

# tracking
self.training_stats = TrainingStats()
self.validation_scores = ValidationScores(
self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
)
self._validation_scores = None

if not load_starter_model:
self.start = None
Expand All @@ -142,6 +144,22 @@ def __init__(self, run_config, load_starter_model: bool = True):

self.start.initialize_weights(self.model, new_head=new_head)

@property
def datasplit(self):
if self._datasplit is None:
self._datasplit = self._config.datasplit_config.datasplit_type(
self._config.datasplit_config
)
return self._datasplit

@property
def validation_scores(self):
if self._validation_scores is None:
self._validation_scores = ValidationScores(
self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
)
return self._validation_scores

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]:
Note:
This method should return a generator of instances of ``ThresholdPostProcessorParameters``.
"""
for i, threshold in enumerate([100, 127, 150]):
for i, threshold in enumerate([127]):
yield ThresholdPostProcessorParameters(id=i, threshold=threshold)

def set_prediction(self, prediction_array_identifier):
Expand Down
4 changes: 3 additions & 1 deletion dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def validate(

config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
logger.warning(f"run_config: {run_config}")
run = Run(run_config)

# read in previous training/validation stats
Expand Down Expand Up @@ -97,7 +98,8 @@ def validate_run(
or len(run.datasplit.validate) == 0
or run.datasplit.validate[0].gt is None
):
raise ValueError(f"Cannot validate run {run.name} at iteration {iteration}.")
run.datasplit.validate = run.datasplit.train[-2:]
# raise ValueError(f"Cannot validate run {run.name} at iteration {iteration} run.datasplit: {run.datasplit.validate}")

# get array and weight store
array_store = create_array_store()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies = [
"click",
"pyyaml",
"scipy",
"universal_pathlib",
]

# extras
Expand Down
8 changes: 4 additions & 4 deletions tests/components/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from dacapo.store.create_store import create_config_store

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf


@pytest.mark.parametrize(
"array_config",
[
lazy_fixture("cellmap_array"),
lazy_fixture("zarr_array"),
lazy_fixture("dummy_array"),
lf("cellmap_array"),
lf("zarr_array"),
lf("dummy_array"),
],
)
def test_array_api(options, array_config):
Expand Down
8 changes: 4 additions & 4 deletions tests/components/test_gp_arraysource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import gunpowder as gp

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf


@pytest.mark.parametrize(
"array_config",
[
lazy_fixture("cellmap_array"),
lazy_fixture("zarr_array"),
lazy_fixture("dummy_array"),
lf("cellmap_array"),
lf("zarr_array"),
lf("dummy_array"),
],
)
def test_gp_dacapo_array_source(array_config):
Expand Down
6 changes: 3 additions & 3 deletions tests/components/test_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from dacapo.store.create_store import create_config_store

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf


@pytest.mark.parametrize(
"trainer_config",
[
lazy_fixture("dummy_trainer"),
lazy_fixture("gunpowder_trainer"),
lf("dummy_trainer"),
lf("gunpowder_trainer"),
],
)
def test_trainer(
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
def dummy_architecture():
yield DummyArchitectureConfig(
name="dummy_architecture", num_in_channels=1, num_out_channels=12
)
)
8 changes: 4 additions & 4 deletions tests/operations/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dacapo import apply

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf

import logging

Expand All @@ -18,9 +18,9 @@
@pytest.mark.parametrize(
"run_config",
[
# lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
# lazy_fixture("onehot_run"),
lf("distance_run"),
lf("dummy_run"),
lf("onehot_run"),
],
)
def test_apply(options, run_config, zarr_array, tmp_path):
Expand Down
8 changes: 4 additions & 4 deletions tests/operations/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dacapo import predict

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf

import logging

Expand All @@ -18,9 +18,9 @@
@pytest.mark.parametrize(
"run_config",
[
# lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
# lazy_fixture("onehot_run"),
# lf("distance_run"),
lf("dummy_run"),
# lf("onehot_run"),
],
)
def test_predict(options, run_config, zarr_array, tmp_path):
Expand Down
8 changes: 4 additions & 4 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dacapo.train import train_run

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf

import logging

Expand All @@ -20,9 +20,9 @@
@pytest.mark.parametrize(
"run_config",
[
lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
lazy_fixture("onehot_run"),
lf("distance_run"),
lf("dummy_run"),
lf("onehot_run"),
],
)
def test_train(
Expand Down
6 changes: 3 additions & 3 deletions tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dacapo import validate

import pytest
from pytest_lazyfixture import lazy_fixture
from pytest_lazy_fixtures import lf

import logging

Expand All @@ -18,8 +18,8 @@
@pytest.mark.parametrize(
"run_config",
[
lazy_fixture("distance_run"),
# lazy_fixture("onehot_run"),
lf("distance_run"),
lf("onehot_run"),
],
)
def test_validate(
Expand Down

0 comments on commit 054367d

Please sign in to comment.