Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extractor_dict_iterator for solving path detection in object kwargs #3089

Merged
merged 9 commits into from
Jun 28, 2024
125 changes: 113 additions & 12 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations
from pathlib import Path, WindowsPath
from typing import Union
from typing import Union, Generator
import os
import sys
import datetime
import json
from copy import deepcopy
import importlib
from math import prod
from collections import namedtuple

import numpy as np

Expand Down Expand Up @@ -183,6 +184,75 @@ def is_dict_extractor(d: dict) -> bool:
return is_extractor


extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very good idea!



def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]:
"""
Iterator for recursive traversal of a dictionary.
This function explores the dictionary recursively and yields the path to each value along with the value itself.

By path here we mean the keys that lead to the value in the dictionary:
e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b').

See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure.

Parameters
----------
extractor_dict : dict
Input dictionary

Yields
------
extractor_dict_element
Named tuple containing the value, the name, and the access_path to the value in the dictionary.

"""

def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""):
if isinstance(dict_list_or_value, dict):
for k, v in dict_list_or_value.items():
yield from _extractor_dict_iterator(v, access_path + (k,), name=k)
elif isinstance(dict_list_or_value, list):
for i, v in enumerate(dict_list_or_value):
yield from _extractor_dict_iterator(
v, access_path + (i,), name=name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really smart to have the path both with key and index for list!!

) # Propagate name of list to children
else:
yield extractor_dict_element(
value=dict_list_or_value,
name=name,
access_path=access_path,
)

yield from _extractor_dict_iterator(extractor_dict)


def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value):
"""
In place modification of a value in a nested dictionary given its access path.

Parameters
----------
extractor_dict : dict
The dictionary to modify
access_path : tuple
The path to the value in the dictionary
new_value : object
The new value to set

Returns
-------
dict
The modified dictionary
"""

current = extractor_dict
for key in access_path[:-1]:
current = current[key]
current[access_path[-1]] = new_value


def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
"""
Generic function for recursive modification of paths in an extractor dict.
Expand Down Expand Up @@ -250,15 +320,17 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
raise ValueError(f"{k} key for path must be str or list[str]")


def _get_paths_list(d):
# this explore a dict and get all paths flatten in a list
# the trick is to use a closure func called by recursive_path_modifier()
path_list = []
# This is the current definition that an element in a extractor_dict is a path
# This is shared across a couple of definition so it is here for DNRY
element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path))


def append_to_path(p):
path_list.append(p)
def _get_paths_list(d: dict) -> list[str | Path]:
path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)]

# if check_if_exists: TODO: Enable this once container_tools test uses proper mocks
# path_list = [p for p in path_list if Path(p).exists()]

recursive_path_modifier(d, append_to_path, target="path", copy=True)
return path_list


Expand Down Expand Up @@ -318,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool:
return len(not_possible) == 0


def make_paths_relative(input_dict, relative_folder) -> dict:
def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict:
"""
Recursively transform a dict describing an BaseExtractor to make every path relative to a folder.

Expand All @@ -334,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict:
output_dict: dict
A copy of the input dict with modified paths.
"""

relative_folder = Path(relative_folder).resolve().absolute()
func = lambda p: _relative_to(p, relative_folder)
output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True)

path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
# Only paths that exist are made relative
path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()]

output_dict = deepcopy(input_dict)
for element in path_elements_in_dict:
new_value = _relative_to(element.value, relative_folder)
set_value_in_extractor_dict(
extractor_dict=output_dict,
access_path=element.access_path,
new_value=new_value,
)

return output_dict


Expand All @@ -359,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder):
base_folder = Path(base_folder)
# use as_posix instead of str to make the path unix like even on window
func = lambda p: (base_folder / p).resolve().absolute().as_posix()
output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True)

path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
output_dict = deepcopy(input_dict)

output_dict = deepcopy(input_dict)
for element in path_elements_in_dict:
absolute_path = (base_folder / element.value).resolve()
if Path(absolute_path).exists():
new_value = absolute_path.as_posix() # Not so sure about this, Sam
set_value_in_extractor_dict(
extractor_dict=output_dict,
access_path=element.access_path,
new_value=new_value,
)

return output_dict


def recursive_key_finder(d, key):
# Find all values for a key on a dictionary, even if nested
# TODO refactor to use extractor_dict_iterator

for k, v in d.items():
if isinstance(v, dict):
yield from recursive_key_finder(v, key)
Expand Down
Loading
Loading