Skip to content

Commit

Permalink
add unet2d/3d validation test
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 14, 2024
1 parent 7b7c956 commit 200ff20
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .architectures import (
dummy_architecture,
unet_architecture,
unet_3d_architecture,
unet_architecture_builder,
)
from .arrays import dummy_array, zarr_array, cellmap_array
Expand All @@ -16,5 +17,5 @@
from .post_processors import argmax, threshold
from .predictors import distance_predictor, onehot_predictor
from .runs import dummy_run, distance_run, onehot_run
from .tasks import dummy_task, distance_task, onehot_task
from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task
from .trainers import dummy_trainer, gunpowder_trainer
16 changes: 16 additions & 0 deletions tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ def unet_architecture():
)



@pytest.fixture()
def unet_3d_architecture():
yield CNNectomeUNetConfig(
name="tmp_unet_3d_architecture",
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
)


def unet_architecture_builder(batch_norm, upsample, use_attention, three_d):
name = "3d_unet" if three_d else "2d_unet"
name = f"{name}_bn" if batch_norm else name
Expand Down
7 changes: 7 additions & 0 deletions tests/fixtures/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ def onehot_task():
name="one_hot_task",
classes=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"],
)

@pytest.fixture()
def six_onehot_task():
yield OneHotTaskConfig(
name="one_hot_task",
classes=["a", "b", "c", "d", "e", "f"],
)
4 changes: 2 additions & 2 deletions tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def test_validate_run(


@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("task", [lf("distance_task"), lf("six_onehot_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("architecture", [lf("unet_architecture")])
@pytest.mark.parametrize("architecture", [lf("unet_architecture"), lf("unet_3d_architecture")])
def test_validate_unet(datasplit, task, trainer, architecture):
store = create_config_store()
weights_store = create_weights_store()
Expand Down

0 comments on commit 200ff20

Please sign in to comment.