diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index f3d8b3df7f..066ab58d8c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations from pathlib import Path, WindowsPath -from typing import Union +from typing import Union, Generator import os import sys import datetime @@ -8,6 +8,7 @@ from copy import deepcopy import importlib from math import prod +from collections import namedtuple import numpy as np @@ -183,6 +184,75 @@ def is_dict_extractor(d: dict) -> bool: return is_extractor +extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"]) + + +def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]: + """ + Iterator for recursive traversal of a dictionary. + This function explores the dictionary recursively and yields the path to each value along with the value itself. + + By path here we mean the keys that lead to the value in the dictionary: + e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b'). + + See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure. + + Parameters + ---------- + extractor_dict : dict + Input dictionary + + Yields + ------ + extractor_dict_element + Named tuple containing the value, the name, and the access_path to the value in the dictionary. + + """ + + def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""): + if isinstance(dict_list_or_value, dict): + for k, v in dict_list_or_value.items(): + yield from _extractor_dict_iterator(v, access_path + (k,), name=k) + elif isinstance(dict_list_or_value, list): + for i, v in enumerate(dict_list_or_value): + yield from _extractor_dict_iterator( + v, access_path + (i,), name=name + ) # Propagate name of list to children + else: + yield extractor_dict_element( + value=dict_list_or_value, + name=name, + access_path=access_path, + ) + + yield from _extractor_dict_iterator(extractor_dict) + + +def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value): + """ + In place modification of a value in a nested dictionary given its access path. + + Parameters + ---------- + extractor_dict : dict + The dictionary to modify + access_path : tuple + The path to the value in the dictionary + new_value : object + The new value to set + + Returns + ------- + dict + The modified dictionary + """ + + current = extractor_dict + for key in access_path[:-1]: + current = current[key] + current[access_path[-1]] = new_value + + def recursive_path_modifier(d, func, target="path", copy=True) -> dict: """ Generic function for recursive modification of paths in an extractor dict. @@ -250,15 +320,17 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict: raise ValueError(f"{k} key for path must be str or list[str]") -def _get_paths_list(d): - # this explore a dict and get all paths flatten in a list - # the trick is to use a closure func called by recursive_path_modifier() - path_list = [] +# This is the current definition that an element in a extractor_dict is a path +# This is shared across a couple of definition so it is here for DNRY +element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path)) + - def append_to_path(p): - path_list.append(p) +def _get_paths_list(d: dict) -> list[str | Path]: + path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)] + + # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks + # path_list = [p for p in path_list if Path(p).exists()] - recursive_path_modifier(d, append_to_path, target="path", copy=True) return path_list @@ -318,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool: return len(not_possible) == 0 -def make_paths_relative(input_dict, relative_folder) -> dict: +def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: """ Recursively transform a dict describing an BaseExtractor to make every path relative to a folder. @@ -334,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict: output_dict: dict A copy of the input dict with modified paths. """ + relative_folder = Path(relative_folder).resolve().absolute() - func = lambda p: _relative_to(p, relative_folder) - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + # Only paths that exist are made relative + path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()] + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + new_value = _relative_to(element.value, relative_folder) + set_value_in_extractor_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict @@ -359,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder): base_folder = Path(base_folder) # use as_posix instead of str to make the path unix like even on window func = lambda p: (base_folder / p).resolve().absolute().as_posix() - output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True) + + path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)] + output_dict = deepcopy(input_dict) + + output_dict = deepcopy(input_dict) + for element in path_elements_in_dict: + absolute_path = (base_folder / element.value).resolve() + if Path(absolute_path).exists(): + new_value = absolute_path.as_posix() # Not so sure about this, Sam + set_value_in_extractor_dict( + extractor_dict=output_dict, + access_path=element.access_path, + new_value=new_value, + ) + return output_dict def recursive_key_finder(d, key): # Find all values for a key on a dictionary, even if nested + # TODO refactor to use extractor_dict_iterator + for k, v in d.items(): if isinstance(v, dict): yield from recursive_key_finder(v, key) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 8e00dcb779..7153991543 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -31,100 +31,148 @@ def test_add_suffix(): assert str(file_path_with_suffix) == expected_path +@pytest.mark.skipif(platform.system() == "Windows", reason="Runs on posix only") def test_path_utils_functions(): - if platform.system() != "Windows": - # posix path - d = { - "kwargs": { - "path": "/yep/sub/path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": "/yep/sub/path2"}, - }, - } + # posix path + d = { + "kwargs": { + "path": "/yep/sub/path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "/yep/sub/path2"}, + }, } - - d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) - assert d2["kwargs"]["path"].startswith("/yop") - assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") - - d3 = make_paths_relative(d, Path("/yep")) - assert d3["kwargs"]["path"] == "sub/path1" - assert d3["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" - - d4 = make_paths_absolute(d3, "/yop") - assert d4["kwargs"]["path"].startswith("/yop") - assert d4["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") - - if platform.system() == "Windows": - # test for windows Path - d = { - "kwargs": { - "path": r"c:\yep\sub\path1", - "recording": { - "module": "mock_module", - "class": "mock_class", - "version": "1.2", - "annotations": {}, - "kwargs": {"path": r"c:\yep\sub\path2"}, - }, - } + } + + d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop")) + assert d2["kwargs"]["path"].startswith("/yop") + assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop") + + +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_relative_path_on_windows(): + + d = { + "kwargs": { + "path": r"c:\yep\sub\path1", + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": r"c:\yep\sub\path2"}, + }, } + } - d2 = make_paths_relative(d, "c:\\yep") - # the str be must unix like path even on windows for more portability - assert d2["kwargs"]["path"] == "sub/path1" - assert d2["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" + # same drive + assert check_paths_relative(d, r"c:\yep") + # not the same drive + assert not check_paths_relative(d, r"d:\yep") - # same drive - assert check_paths_relative(d, r"c:\yep") - # not the same drive - assert not check_paths_relative(d, r"d:\yep") - d = { - "kwargs": { - "path": r"\\host\share\yep\sub\path1", - } +@pytest.mark.skipif(platform.system() != "Windows", reason="Runs only on Windows") +def test_universal_naming_convention(): + d = { + "kwargs": { + "path": r"\\host\share\yep\sub\path1", + } + } + # UNC cannot be relative to d: drive + assert not check_paths_relative(d, r"d:\yep") + + # UNC can be relative to the same UNC + assert check_paths_relative(d, r"\\host\share") + + +def test_make_paths_relative(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + # Create the objects in the path + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + extractor_dict = { + "kwargs": { + "path": str(path_1), # Note this is different in windows and posix + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": str(path_2)}, + }, + } + } + modified_extractor_dict = make_paths_relative(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"] == "sub/path1" + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"] == "sub/path2" + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_make_paths_absolute(tmp_path): + + path_1 = tmp_path / "sub" / "path1" + path_2 = tmp_path / "sub" / "path2" + + path_1.mkdir(parents=True, exist_ok=True) + path_2.mkdir(parents=True, exist_ok=True) + + extractor_dict = { + "kwargs": { + "path": "sub/path1", + "electrical_series_path": "/acquisition/timeseries", # non-existent path-like objects should not be modified + "recording": { + "module": "mock_module", + "class": "mock_class", + "version": "1.2", + "annotations": {}, + "kwargs": {"path": "sub/path2"}, + }, } - # UNC cannot be relative to d: drive - assert not check_paths_relative(d, r"d:\yep") - - # UNC can be relative to the same UNC - assert check_paths_relative(d, r"\\host\share") - - def test_convert_string_to_bytes(): - # Test SI prefixes - assert convert_string_to_bytes("1k") == 1000 - assert convert_string_to_bytes("1M") == 1000000 - assert convert_string_to_bytes("1G") == 1000000000 - assert convert_string_to_bytes("1T") == 1000000000000 - assert convert_string_to_bytes("1P") == 1000000000000000 - # Test IEC prefixes - assert convert_string_to_bytes("1Ki") == 1024 - assert convert_string_to_bytes("1Mi") == 1048576 - assert convert_string_to_bytes("1Gi") == 1073741824 - assert convert_string_to_bytes("1Ti") == 1099511627776 - assert convert_string_to_bytes("1Pi") == 1125899906842624 - # Test mixed values - assert convert_string_to_bytes("1.5k") == 1500 - assert convert_string_to_bytes("2.5M") == 2500000 - assert convert_string_to_bytes("0.5G") == 500000000 - assert convert_string_to_bytes("1.2T") == 1200000000000 - assert convert_string_to_bytes("1.5Pi") == 1688849860263936 - # Test zero values - assert convert_string_to_bytes("0k") == 0 - assert convert_string_to_bytes("0Ki") == 0 - # Test invalid inputs (should raise assertion error) - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Z") - assert str(e.value) == "Unknown suffix: Z" - - with pytest.raises(AssertionError) as e: - convert_string_to_bytes("1Xi") - assert str(e.value) == "Unknown suffix: Xi" + } + + modified_extractor_dict = make_paths_absolute(extractor_dict, tmp_path) + assert modified_extractor_dict["kwargs"]["path"].startswith(str(tmp_path.as_posix())) + assert modified_extractor_dict["kwargs"]["recording"]["kwargs"]["path"].startswith(str(tmp_path.as_posix())) + assert modified_extractor_dict["kwargs"]["electrical_series_path"] == "/acquisition/timeseries" + + +def test_convert_string_to_bytes(): + # Test SI prefixes + assert convert_string_to_bytes("1k") == 1000 + assert convert_string_to_bytes("1M") == 1000000 + assert convert_string_to_bytes("1G") == 1000000000 + assert convert_string_to_bytes("1T") == 1000000000000 + assert convert_string_to_bytes("1P") == 1000000000000000 + # Test IEC prefixes + assert convert_string_to_bytes("1Ki") == 1024 + assert convert_string_to_bytes("1Mi") == 1048576 + assert convert_string_to_bytes("1Gi") == 1073741824 + assert convert_string_to_bytes("1Ti") == 1099511627776 + assert convert_string_to_bytes("1Pi") == 1125899906842624 + # Test mixed values + assert convert_string_to_bytes("1.5k") == 1500 + assert convert_string_to_bytes("2.5M") == 2500000 + assert convert_string_to_bytes("0.5G") == 500000000 + assert convert_string_to_bytes("1.2T") == 1200000000000 + assert convert_string_to_bytes("1.5Pi") == 1688849860263936 + # Test zero values + assert convert_string_to_bytes("0k") == 0 + assert convert_string_to_bytes("0Ki") == 0 + # Test invalid inputs (should raise assertion error) + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Z") + assert str(e.value) == "Unknown suffix: Z" + + with pytest.raises(AssertionError) as e: + convert_string_to_bytes("1Xi") + assert str(e.value) == "Unknown suffix: Xi" def test_normal_pdf() -> None: diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 60eb080ae5..8e03090eaf 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -9,19 +9,14 @@ # TODO move this inside functions -from spikeinterface.core.core_tools import recursive_path_modifier +from spikeinterface.core.core_tools import recursive_path_modifier, _get_paths_list def find_recording_folders(d): """Finds all recording folders 'paths' in a dict""" - folders_to_mount = [] - def append_parent_folder(p): - p = Path(p) - folders_to_mount.append(p.resolve().absolute().parent) - return p - - _ = recursive_path_modifier(d, append_parent_folder, target="path", copy=True) + path_list = _get_paths_list(d=d) + folders_to_mount = [Path(p).resolve().parent for p in path_list] try: # this will fail if on different drives (Windows) base_folders_to_mount = [Path(os.path.commonpath(folders_to_mount))]