Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 15, 2023
1 parent 64b6944 commit 9d239ca
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
cache_folder = pytest.global_test_folder / "deepinterpolation"
else:
cache_folder = Path("cache_folder") / "deepinterpolation"
if not cache_folder.is_dir():
cache_folder.mkdir(parents=True)


def recording_and_shape():
Expand Down Expand Up @@ -147,6 +145,7 @@ def test_deepinterpolation_inference(recording_and_shape_fixture):
np.any(np.not_equal(traces_di_last[:-post_frame:], traces_original_last[:-post_frame:]))


@pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation")
@pytest.mark.dependency(depends=["test_deepinterpolation_training"])
def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture):
recording, desired_shape = recording_and_shape_fixture
Expand All @@ -172,7 +171,7 @@ def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture):

if __name__ == "__main__":
recording_shape = recording_and_shape()
# test_deepinterpolation_training(recording_shape)
test_deepinterpolation_training(recording_shape)
# test_deepinterpolation_transfer()
test_deepinterpolation_inference(recording_shape)
# test_deepinterpolation_inference_multi_job()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def train_deepinterpolation_process(
_ = import_tf(use_gpu, disable_tf_logger, memory_gpu=memory_gpu)

trained_model_folder = Path(model_folder)
trained_model_folder.mkdir(exist_ok=True)
trained_model_folder.mkdir(exist_ok=True, parents=True)

# check if roughly zscored
if not isinstance(recordings, list):
Expand Down

0 comments on commit 9d239ca

Please sign in to comment.