Skip to content

Commit

Permalink
add validate test
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 13, 2024
1 parent 6d87d38 commit 2b80cfa
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .db import options
from .architectures import dummy_architecture, unet_architecture,unet_architecture_builder
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit, upsample_six_class_datasplit
from .evaluators import binary_3_channel_evaluator
from .losses import dummy_loss
from .post_processors import argmax, threshold
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def unet_architecture():
name="tmp_unet_architecture",
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
fmaps_in=2,
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
Expand Down
119 changes: 119 additions & 0 deletions tests/fixtures/datasplits.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,122 @@ def six_class_datasplit(tmp_path):
validate_configs=[crop3],
)
return six_class_distances_datasplit_config




@pytest.fixture()
def upsample_six_class_datasplit(tmp_path):
"""
two crops for training, one for validation. Raw data is normally distributed
around 0 with std 1.
gt is provided as distances. First, gt is generated as a 12 class problem:
gt has 12 classes where class i in [0, 11] is all voxels with raw intensity
between (raw.min() + i(raw.max()-raw.min())/12, raw.min() + (i+1)(raw.max()-raw.min())/12).
Then we pair up classes (i, i+1) for i in [0,2,4,6,8,10], and compute distances to
the nearest voxel in the pair. This leaves us with 6 distance channels.
"""
twelve_class_zarr = zarr.open(tmp_path / "twelve_class.zarr", "w")
crop1_raw = ZarrArrayConfig(
name="crop1_raw",
file_name=tmp_path / "twelve_class.zarr",
dataset=f"volumes/crop1/raw",
)
crop1_gt = ZarrArrayConfig(
name="crop1_gt",
file_name=tmp_path / "twelve_class.zarr",
dataset=f"volumes/crop1/gt",
)
crop1_distances = BinarizeArrayConfig(
"crop1_distances",
source_array_config=crop1_gt,
groupings=[
("a", [0, 1]),
("b", [2, 3]),
("c", [4, 5]),
("d", [6, 7]),
("e", [8, 9]),
("f", [10, 11]),
],
)
crop2_raw = ZarrArrayConfig(
name="crop2_raw",
file_name=tmp_path / "twelve_class.zarr",
dataset=f"volumes/crop2/raw",
)
crop2_gt = ZarrArrayConfig(
name="crop2_gt",
file_name=tmp_path / "twelve_class.zarr",
dataset=f"volumes/crop2/gt",
)
crop2_distances = BinarizeArrayConfig(
"crop2_distances",
source_array_config=crop2_gt,
groupings=[
("a", [0, 1]),
("b", [2, 3]),
("c", [4, 5]),
("d", [6, 7]),
("e", [8, 9]),
("f", [10, 11]),
],
)
crop3_raw = ZarrArrayConfig(
name="crop3_raw",
file_name=tmp_path / "twelve_class.zarr",
dataset=f"volumes/crop3/raw",
)
crop3_gt = ZarrArrayConfig(
name="crop3_gt",
file_name=tmp_path / "twelve_class.zarr",
dataset=f"volumes/crop3/gt",
)
crop3_distances = BinarizeArrayConfig(
"crop3_distances",
source_array_config=crop3_gt,
groupings=[
("a", [0, 1]),
("b", [2, 3]),
("c", [4, 5]),
("d", [6, 7]),
("e", [8, 9]),
("f", [10, 11]),
],
)
for raw, gt in zip(
[crop1_raw, crop2_raw, crop3_raw], [crop1_gt, crop2_gt, crop3_gt]
):
raw_dataset = twelve_class_zarr.create_dataset(
raw.dataset, shape=(40, 20, 20), dtype=np.float32
)
gt_dataset = twelve_class_zarr.create_dataset(
gt.dataset, shape=(40, 20, 20), dtype=np.uint8
)
random_data = np.random.rand(40, 20, 20)
# as intensities increase so does the class
for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]:
gt_dataset[:] += random_data > i
raw_dataset[:] = random_data
raw_dataset.attrs["offset"] = (0, 0, 0)
raw_dataset.attrs["voxel_size"] = (4, 4, 4)
raw_dataset.attrs["axis_names"] = ("z", "y", "x")
gt_dataset.attrs["offset"] = (0, 0, 0)
gt_dataset.attrs["voxel_size"] = (2, 2, 2)
gt_dataset.attrs["axis_names"] = ("z", "y", "x")

crop1 = RawGTDatasetConfig(
name="crop1", raw_config=crop1_raw, gt_config=crop1_distances
)
crop2 = RawGTDatasetConfig(
name="crop2", raw_config=crop2_raw, gt_config=crop2_distances
)
crop3 = RawGTDatasetConfig(
name="crop3", raw_config=crop3_raw, gt_config=crop3_distances
)

six_class_distances_datasplit_config = TrainValidateDataSplitConfig(
name="six_class_distances_datasplit",
train_configs=[crop1, crop2],
validate_configs=[crop3],
)
return six_class_distances_datasplit_config
6 changes: 5 additions & 1 deletion tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def test_stored_architecture(
from dacapo.store.create_store import create_config_store

config_store = create_config_store()
config_store.store_architecture_config(architecture_config)
try:
config_store.store_architecture_config(architecture_config)
except:
config_store.delete_architecture_config(architecture_config.name)
config_store.store_architecture_config(architecture_config)

retrieved_arch_config = config_store.retrieve_architecture_config(
architecture_config.name
Expand Down
73 changes: 66 additions & 7 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ def test_train(
@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("batch_norm", [ False])
@pytest.mark.parametrize("upsample", [False])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
@pytest.mark.parametrize("use_attention", [ False])
@pytest.mark.parametrize("three_d", [ False])
def test_train_unet(
datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
architecture_config = unet_architecture(
architecture_config = unet_architecture_builder(
batch_norm, upsample, use_attention, three_d
)

Expand Down Expand Up @@ -117,7 +117,7 @@ def test_train_unet(
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=30,
num_iterations=2,
)
try:
store.store_run_config(run_config)
Expand All @@ -138,8 +138,67 @@ def test_train_unet(
final_weights = weights_store.retrieve_weights(run.name, run.train_until)

for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).sum()
assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff
weight_diff = (weight - final_weights.model[name]).any()
assert weight_diff != 0, "Weights did not change"

# assert train_stats and validation_scores are available

training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations





@pytest.mark.parametrize("upsample_datasplit", [lf("upsample_six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("upsample", [True])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_upsample_train_unet(
upsample_datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

architecture_config = unet_architecture_builder(
batch_norm, upsample, use_attention, three_d
)

run_config = RunConfig(
name=f"{architecture_config.name}_run",
task_config=task,
architecture_config=architecture_config,
trainer_config=trainer,
datasplit_config=upsample_datasplit,
repetition=0,
num_iterations=2,
)
try:
store.store_run_config(run_config)
except Exception as e:
store.delete_run_config(run_config.name)
store.store_run_config(run_config)

run = Run(run_config)

# -------------------------------------

# train

weights_store.store_weights(run, 0)
train_run(run)

init_weights = weights_store.retrieve_weights(run.name, 0)
final_weights = weights_store.retrieve_weights(run.name, run.train_until)

for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).any()
assert weight_diff != 0, "Weights did not change"

# assert train_stats and validation_scores are available

Expand Down
36 changes: 36 additions & 0 deletions tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo import validate, validate_run

from dacapo.experiments.run_config import RunConfig

import pytest
from pytest_lazy_fixtures import lf

Expand Down Expand Up @@ -97,3 +99,37 @@ def test_validate_run(

if debug:
os.chdir(old_path)




@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("architecture", [lf("unet_architecture")])
def test_validate_unet(
datasplit, task, trainer, architecture
):
store = create_config_store()
weights_store = create_weights_store()

run_config = RunConfig(
name=f"{architecture.name}_run",
task_config=task,
architecture_config=architecture,
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=2,
)
try:
store.store_run_config(run_config)
except Exception as e:
store.delete_run_config(run_config.name)
store.store_run_config(run_config)

run = Run(run_config)

# -------------------------------------
weights_store.store_weights(run, 0)
validate_run(run, 0)

0 comments on commit 2b80cfa

Please sign in to comment.