-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
84 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
from .db import options | ||
from .architectures import dummy_architecture | ||
from .architectures import dummy_architecture, cellpose_architecture | ||
from .arrays import dummy_array, zarr_array, cellmap_array | ||
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit | ||
from .evaluators import binary_3_channel_evaluator | ||
from .losses import dummy_loss | ||
from .post_processors import argmax, threshold | ||
from .predictors import distance_predictor, onehot_predictor | ||
from .runs import dummy_run, distance_run, onehot_run | ||
from .runs import dummy_run, distance_run, onehot_run, cellpose_run | ||
from .tasks import dummy_task, distance_task, onehot_task | ||
from .trainers import dummy_trainer, gunpowder_trainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
from dacapo.store.create_store import create_stats_store | ||
from ..fixtures import * | ||
|
||
from dacapo.experiments import Run | ||
from dacapo.store.create_store import create_config_store, create_weights_store | ||
from dacapo.train import train_run | ||
from pytest_lazy_fixtures import lf | ||
import pytest | ||
|
||
import logging | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
# skip the test for the Apple Paravirtual device | ||
# that does not support Metal 2.0 | ||
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning") | ||
@pytest.mark.parametrize( | ||
"run_config", | ||
[ | ||
lf("cellpose_run"), | ||
], | ||
) | ||
def test_train( | ||
run_config, | ||
): | ||
print("Test train") | ||
# create a store | ||
|
||
store = create_config_store() | ||
stats_store = create_stats_store() | ||
weights_store = create_weights_store() | ||
|
||
# store the configs | ||
|
||
store.store_run_config(run_config) | ||
run = Run(run_config) | ||
print("Run created ") | ||
print(run.model) | ||
|
||
# # ------------------------------------- | ||
|
||
# # train | ||
|
||
# weights_store.store_weights(run, 0) | ||
# print("Weights stored") | ||
# train_run(run) | ||
|
||
# init_weights = weights_store.retrieve_weights(run.name, 0) | ||
# final_weights = weights_store.retrieve_weights(run.name, run.train_until) | ||
|
||
# for name, weight in init_weights.model.items(): | ||
# weight_diff = (weight - final_weights.model[name]).sum() | ||
# assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff | ||
|
||
# # assert train_stats and validation_scores are available | ||
|
||
# training_stats = stats_store.retrieve_training_stats(run_config.name) | ||
|
||
# assert training_stats.trained_until() == run_config.num_iterations |