Skip to content

Commit

Permalink
organize
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 13, 2024
1 parent f3d0508 commit 6d87d38
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 50 deletions.
2 changes: 1 addition & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .db import options
from .architectures import dummy_architecture, unet_architecture
from .architectures import dummy_architecture, unet_architecture, unet_architecture_builder
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
from .evaluators import binary_3_channel_evaluator
Expand Down
43 changes: 43 additions & 0 deletions tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,46 @@ def unet_architecture():
constant_upsample=True,
padding="valid",
)



def unet_architecture_builder(batch_norm, upsample, use_attention, three_d):
    """Build a ``CNNectomeUNetConfig`` test fixture for a 2D or 3D U-Net.

    The config name encodes the enabled options as suffixes
    (``_bn``, ``_up``, ``_att``) appended to ``3d_unet``/``2d_unet``.

    Args:
        batch_norm: enable batch normalization in the config.
        upsample: add an extra upsample stage (non-empty ``upsample_factors``).
        use_attention: enable attention gating in the config.
        three_d: pick the 3D variant; otherwise a pseudo-2D (thin-z) variant.

    Returns:
        A ``CNNectomeUNetConfig`` describing the requested architecture.
    """
    name = "3d_unet" if three_d else "2d_unet"
    # Append one suffix per enabled option so each parameter combination
    # yields a unique, self-describing fixture name.
    for suffix, enabled in (
        ("_bn", batch_norm),
        ("_up", upsample),
        ("_att", use_attention),
    ):
        if enabled:
            name += suffix

    # Options shared by both variants.
    common = dict(
        name=name,
        fmaps_in=1,
        fmap_inc_factor=2,
        constant_upsample=True,
        batch_norm=batch_norm,
        use_attention=use_attention,
    )

    if three_d:
        return CNNectomeUNetConfig(
            input_shape=(188, 188, 188),
            eval_shape_increase=(72, 72, 72),
            num_fmaps=6,
            fmaps_out=6,
            downsample_factors=[(2, 2, 2)] * 3,
            upsample_factors=[(2, 2, 2)] if upsample else [],
            **common,
        )

    # Pseudo-2D: a thin z-axis with in-plane (1, k, k) kernels/downsampling.
    return CNNectomeUNetConfig(
        input_shape=(2, 132, 132),
        eval_shape_increase=(8, 32, 32),
        num_fmaps=8,
        fmaps_out=8,
        downsample_factors=[(1, 4, 4), (1, 4, 4)],
        kernel_size_down=[[(1, 3, 3)] * 2] * 3,
        kernel_size_up=[[(1, 3, 3)] * 2] * 2,
        padding="valid",
        upsample_factors=[(1, 2, 2)] if upsample else [],
        **common,
    )

79 changes: 30 additions & 49 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,9 @@

logging.basicConfig(level=logging.INFO)

from dacapo.experiments.architectures import (
DummyArchitectureConfig,
CNNectomeUNetConfig,
)

import pytest


def unet_architecture(batch_norm, upsample, use_attention, three_d):
    """Return a ``CNNectomeUNetConfig`` for one test parameter combination.

    The config name is ``3d_unet``/``2d_unet`` plus a ``_bn``/``_up``/``_att``
    suffix for each enabled option, so every combination is uniquely named.

    Args:
        batch_norm: enable batch normalization.
        upsample: request an extra upsample stage.
        use_attention: enable attention gating.
        three_d: choose the 3D variant instead of the thin-z 2D one.
    """
    base = "3d_unet" if three_d else "2d_unet"
    tags = "".join(
        tag
        for tag, on in (("_bn", batch_norm), ("_up", upsample), ("_att", use_attention))
        if on
    )
    name = base + tags

    if not three_d:
        # Thin-z "2D" network: kernels and factors leave the z-axis untouched.
        return CNNectomeUNetConfig(
            name=name,
            input_shape=(2, 132, 132),
            eval_shape_increase=(8, 32, 32),
            fmaps_in=1,
            num_fmaps=8,
            fmaps_out=8,
            fmap_inc_factor=2,
            downsample_factors=[(1, 4, 4), (1, 4, 4)],
            kernel_size_down=[[(1, 3, 3)] * 2] * 3,
            kernel_size_up=[[(1, 3, 3)] * 2] * 2,
            constant_upsample=True,
            padding="valid",
            batch_norm=batch_norm,
            use_attention=use_attention,
            upsample_factors=[(1, 2, 2)] if upsample else [],
        )

    return CNNectomeUNetConfig(
        name=name,
        input_shape=(188, 188, 188),
        eval_shape_increase=(72, 72, 72),
        fmaps_in=1,
        num_fmaps=6,
        fmaps_out=6,
        fmap_inc_factor=2,
        downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
        constant_upsample=True,
        upsample_factors=[(2, 2, 2)] if upsample else [],
        batch_norm=batch_norm,
        use_attention=use_attention,
    )


# skip the test for the Apple Paravirtual device
# that does not support Metal 2.0
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
Expand Down Expand Up @@ -115,7 +69,34 @@ def test_train(
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("upsample", [True, False])
@pytest.mark.parametrize("upsample", [False])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_train_unet(
    datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
    """Smoke-test that a short training run completes for each U-Net variant.

    Builds the architecture for the current parameter combination, wraps it
    in a two-iteration ``RunConfig``, and runs training end to end.
    """
    arch = unet_architecture(batch_norm, upsample, use_attention, three_d)

    cfg = RunConfig(
        # Derive the run name from the architecture so runs are distinguishable.
        name=f"{arch.name}_run",
        task_config=task,
        architecture_config=arch,
        trainer_config=trainer,
        datasplit_config=datasplit,
        repetition=0,
        num_iterations=2,  # just enough to exercise the training loop
    )
    train_run(Run(cfg))


@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("upsample", [False])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_train_unet(
Expand All @@ -125,7 +106,7 @@ def test_train_unet(
stats_store = create_stats_store()
weights_store = create_weights_store()

architecture_config = unet_architecture(
architecture_config = unet_architecture_builder(
batch_norm, upsample, use_attention, three_d
)

Expand All @@ -136,7 +117,7 @@ def test_train_unet(
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=2,
num_iterations=30,
)
try:
store.store_run_config(run_config)
Expand Down

0 comments on commit 6d87d38

Please sign in to comment.