Skip to content

Commit

Permalink
unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Apr 4, 2024
1 parent c8029a0 commit fbfff32
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .db import options
from .architectures import dummy_architecture
from .architectures import dummy_architecture, cellpose_architecture
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
from .evaluators import binary_3_channel_evaluator
from .losses import dummy_loss
from .post_processors import argmax, threshold
from .predictors import distance_predictor, onehot_predictor
from .runs import dummy_run, distance_run, onehot_run
from .runs import dummy_run, distance_run, onehot_run, cellpose_run
from .tasks import dummy_task, distance_task, onehot_task
from .trainers import dummy_trainer, gunpowder_trainer
7 changes: 4 additions & 3 deletions tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ def dummy_architecture():
name="dummy_architecture", num_in_channels=1, num_out_channels=12
)


@pytest.fixture()
@pytest.fixture
def cellpose_architecture():
yield CellposUNetConfig(
name="cellpose_architecture", nbase=[1, 32, 64, 128, 256], sz=3
name="cellpose_architecture", input_shape=(216, 216, 216),
nbase=[1, 12, 24, 48, 96], nout = 12, conv_3D = True
# nbase=[1, 32, 64, 128, 256], nout = 32, conv_3D = True
)
17 changes: 17 additions & 0 deletions tests/fixtures/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,20 @@ def onehot_run(
repetition=0,
num_iterations=100,
)

@pytest.fixture()
def cellpose_run(
dummy_datasplit,
cellpose_architecture,
dummy_task,
dummy_trainer,
):
yield RunConfig(
name="cellpose_run",
task_config=dummy_task,
architecture_config=cellpose_architecture,
trainer_config=dummy_trainer,
datasplit_config=dummy_datasplit,
repetition=0,
num_iterations=100,
)
61 changes: 61 additions & 0 deletions tests/operations/test_cellpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np
from dacapo.store.create_store import create_stats_store
from ..fixtures import *

from dacapo.experiments import Run
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.train import train_run
from pytest_lazy_fixtures import lf
import pytest

import logging

logging.basicConfig(level=logging.INFO)


# skip the test for the Apple Paravirtual device
# that does not support Metal 2.0
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
@pytest.mark.parametrize(
"run_config",
[
lf("cellpose_run"),
],
)
def test_train(
run_config,
):
print("Test train")
# create a store

store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

# store the configs

store.store_run_config(run_config)
run = Run(run_config)
print("Run created ")
print(run.model)

# # -------------------------------------

# # train

# weights_store.store_weights(run, 0)
# print("Weights stored")
# train_run(run)

# init_weights = weights_store.retrieve_weights(run.name, 0)
# final_weights = weights_store.retrieve_weights(run.name, run.train_until)

# for name, weight in init_weights.model.items():
# weight_diff = (weight - final_weights.model[name]).sum()
# assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff

# # assert train_stats and validation_scores are available

# training_stats = stats_store.retrieve_training_stats(run_config.name)

# assert training_stats.trained_until() == run_config.num_iterations

0 comments on commit fbfff32

Please sign in to comment.