diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 4eb82be2d6..8dda9136cc 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -300,6 +300,9 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates": sampling_frequency = zarr_group.attrs["sampling_frequency"] nbefore = zarr_group.attrs["nbefore"] + # TODO: Consider eliminating the True and make it required + is_scaled = zarr_group.attrs.get("is_scaled", True) + sparsity_mask = None if "sparsity_mask" in zarr_group: sparsity_mask = zarr_group["sparsity_mask"] @@ -316,6 +319,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates": channel_ids=channel_ids, unit_ids=unit_ids, probe=probe, + is_scaled=is_scaled, ) @staticmethod diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 0ed6bc2e3e..34a89ea5d5 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -7,7 +7,7 @@ from probeinterface import generate_multi_columns_probe -def generate_test_template(template_type): +def generate_test_template(template_type, is_scaled=True) -> Templates: num_units = 2 num_samples = 5 num_channels = 3 @@ -21,7 +21,11 @@ def generate_test_template(template_type): if template_type == "dense": return Templates( - templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + probe=probe, + is_scaled=is_scaled, ) elif template_type == "sparse": # sparse with sparse templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) @@ -42,6 +46,7 @@ def generate_test_template(template_type): sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, + is_scaled=is_scaled, ) elif template_type == "sparse_with_dense_templates": # sparse with dense templates @@ -53,12 +58,14 @@ def generate_test_template(template_type): sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, + is_scaled=is_scaled, ) +@pytest.mark.parametrize("is_scaled", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_pickle_serialization(template_type, tmp_path): - template = generate_test_template(template_type) +def test_pickle_serialization(template_type, is_scaled, tmp_path): + template = generate_test_template(template_type, is_scaled) # Dump to pickle pkl_path = tmp_path / "templates.pkl" @@ -72,9 +79,10 @@ def test_pickle_serialization(template_type, tmp_path): assert template == template_reloaded +@pytest.mark.parametrize("is_scaled", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_json_serialization(template_type): - template = generate_test_template(template_type) +def test_json_serialization(template_type, is_scaled): + template = generate_test_template(template_type, is_scaled) json_str = template.to_json() template_reloaded_from_json = Templates.from_json(json_str) @@ -82,9 +90,10 @@ def test_json_serialization(template_type): assert template == template_reloaded_from_json +@pytest.mark.parametrize("is_scaled", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_get_dense_templates(template_type): - template = generate_test_template(template_type) +def test_get_dense_templates(template_type, is_scaled): + template = generate_test_template(template_type, is_scaled) dense_templates = template.get_dense_templates() assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) @@ -94,9 +103,10 @@ def test_initialization_fail_with_dense_templates(): template = generate_test_template(template_type="sparse_with_dense_templates") +@pytest.mark.parametrize("is_scaled", [True, False]) @pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_save_and_load_zarr(template_type, tmp_path): - original_template = generate_test_template(template_type) +def test_save_and_load_zarr(template_type, is_scaled, tmp_path): + original_template = generate_test_template(template_type, is_scaled) zarr_path = tmp_path / "templates.zarr" original_template.to_zarr(str(zarr_path))