Skip to content

Commit

Permalink
Mask shape bug fix (#358)
Browse files Browse the repository at this point in the history
Fix two bugs:
1) The mask was not being processed appropriately for saving into the
snapshot array
2) The prepare_ds method was getting the offset passed in voxels, not
physical units

Both related to the errors seen in
#357
  • Loading branch information
mzouink authored Jan 2, 2025
2 parents e864b33 + ae59ec3 commit e9f255c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
4 changes: 3 additions & 1 deletion dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ def iterate(self, num_iterations, model, optimizer, device):
),
}
if mask is not None:
snapshot_arrays["volumes/mask"] = mask
snapshot_arrays["volumes/mask"] = np_to_funlib_array(
mask[0], offset=target.offset, voxel_size=target.voxel_size
)
logger.warning(
f"Saving Snapshot. Iteration: {iteration}, "
f"Loss: {loss.detach().cpu().numpy().item()}!"
Expand Down
2 changes: 1 addition & 1 deletion dacapo/tmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def create_from_identifier(
return prepare_ds(
out_path,
shape=(*list_num_channels, *roi.shape / voxel_size),
offset=roi.offset / voxel_size,
offset=roi.offset,
voxel_size=voxel_size,
axis_names=axis_names,
dtype=dtype,
Expand Down
18 changes: 14 additions & 4 deletions tests/operations/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from funlib.persistence import prepare_ds
from funlib.geometry import Coordinate

from dacapo.experiments.trainers import GunpowderTrainerConfig
from dacapo.experiments.datasplits import SimpleDataSplitConfig
from dacapo.experiments.tasks import (
DistanceTaskConfig,
Expand All @@ -13,6 +14,19 @@
from pathlib import Path


def build_test_train_config(multiprocessing: bool):
    """
    Create a minimal ``GunpowderTrainerConfig`` for use in tests.

    Args:
        multiprocessing: Added to the base count of one data fetcher, so
            ``True`` requests two fetchers and ``False`` requests one.

    Returns:
        A ``GunpowderTrainerConfig`` named ``"test_trainer"`` with a batch
        size of 1, a learning rate of 0.0001 and a snapshot interval of 1.
    """
    # bool is an int subclass, so the flag contributes 0 or 1 extra fetcher.
    fetcher_count = 1 + multiprocessing
    trainer_kwargs = dict(
        name="test_trainer",
        batch_size=1,
        learning_rate=0.0001,
        num_data_fetchers=fetcher_count,
        snapshot_interval=1,
    )
    return GunpowderTrainerConfig(**trainer_kwargs)


def build_test_data_config(
tmpdir: Path, data_dims: int, channels: bool, upsample: bool, task_type: str
):
Expand Down Expand Up @@ -104,9 +118,7 @@ def build_test_architecture_config(
data_dims: int,
architecture_dims: int,
channels: bool,
batch_norm: bool,
upsample: bool,
use_attention: bool,
padding: str,
):
"""
Expand Down Expand Up @@ -160,7 +172,5 @@ def build_test_architecture_config(
kernel_size_up=kernel_size_up,
constant_upsample=True,
upsample_factors=upsample_factors,
batch_norm=batch_norm,
use_attention=use_attention,
padding=padding,
)
36 changes: 24 additions & 12 deletions tests/operations/test_mini.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from ..fixtures import *
from .helpers import (
build_test_train_config,
build_test_data_config,
build_test_task_config,
build_test_architecture_config,
)

from dacapo.store.create_store import create_array_store
from dacapo.experiments import Run
from dacapo.train import train_run
from dacapo.validate import validate_run

import zarr

import pytest
from pytest_lazy_fixtures import lf

from dacapo.experiments.run_config import RunConfig

Expand All @@ -22,34 +25,30 @@
@pytest.mark.parametrize("data_dims", [2, 3])
@pytest.mark.parametrize("channels", [True, False])
@pytest.mark.parametrize("task", ["distance", "onehot", "affs"])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("architecture_dims", [2, 3])
@pytest.mark.parametrize("upsample", [True, False])
# @pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("batch_norm", [False])
# @pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("use_attention", [False])
@pytest.mark.parametrize("padding", ["valid", "same"])
@pytest.mark.parametrize("func", ["train", "validate"])
@pytest.mark.parametrize("multiprocessing", [False])
def test_mini(
tmpdir,
data_dims,
channels,
task,
trainer,
architecture_dims,
batch_norm,
upsample,
use_attention,
padding,
func,
multiprocessing,
):
# Invalid configurations:
if data_dims == 2 and architecture_dims == 3:
# cannot train a 3D model on 2D data
# TODO: maybe check that an appropriate warning is raised somewhere
return

trainer_config = build_test_train_config(multiprocessing)

data_config = build_test_data_config(
tmpdir,
data_dims,
Expand All @@ -62,17 +61,15 @@ def test_mini(
data_dims,
architecture_dims,
channels,
batch_norm,
upsample,
use_attention,
padding,
)

run_config = RunConfig(
name=f"test_{func}",
task_config=task_config,
architecture_config=architecture_config,
trainer_config=trainer,
trainer_config=trainer_config,
datasplit_config=data_config,
repetition=0,
num_iterations=1,
Expand All @@ -81,5 +78,20 @@ def test_mini(

if func == "train":
train_run(run)
array_store = create_array_store()
snapshot_container = array_store.snapshot_container(run.name).container
assert snapshot_container.exists()
assert all(
x in zarr.open(snapshot_container)
for x in [
"0/volumes/raw",
"0/volumes/gt",
"0/volumes/target",
"0/volumes/weight",
"0/volumes/prediction",
"0/volumes/gradients",
"0/volumes/mask",
]
)
elif func == "validate":
validate_run(run, 1)

0 comments on commit e9f255c

Please sign in to comment.