Skip to content

Commit

Permalink
Merge pull request #726 from RasmusOrsoe/fix_str_to_dataconverter
Browse files Browse the repository at this point in the history
Cast single path to list of paths in `DataConverter.merge_files`
  • Loading branch information
RasmusOrsoe authored May 28, 2024
2 parents e5b9450 + cac79db commit 1d942a4
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains `DataConverter`."""

from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
from abc import abstractmethod, ABC

Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self._index = 0
self._output_dir = outdir
self._output_files: List[str] = []
self._extension = self._save_method.file_extension

# Set Extractors. Will throw error if extractors are incompatible
# with reader.
Expand All @@ -102,8 +104,7 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None:
self._output_files = [
os.path.join(
self._output_dir,
self._create_file_name(file)
+ self._save_method.file_extension,
self._create_file_name(file) + self._extension,
)
for file in input_files
]
Expand Down Expand Up @@ -256,16 +257,12 @@ def _request_event_nos(self, n_ids: int) -> List[int]:
# Get new, unique index and increment value
if self._num_workers > 1:
with global_index.get_lock(): # type: ignore[name-defined]
starting_index = global_index.value # type: ignore[name-defined]
event_nos = np.arange(
starting_index, starting_index + n_ids, 1
).tolist()
start_idx = global_index.value # type: ignore[name-defined]
event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
global_index.value += n_ids # type: ignore[name-defined]
else:
starting_index = self._index
event_nos = np.arange(
starting_index, starting_index + n_ids, 1
).tolist()
start_idx = self._index
event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
self._index += n_ids

return event_nos
Expand Down Expand Up @@ -321,7 +318,7 @@ def _update_shared_variables(

@final
def merge_files(
self, files: Optional[List[str]] = None, **kwargs: Any
self, files: Optional[Union[List[str], str]] = None, **kwargs: Any
) -> None:
"""Merge converted files.
Expand All @@ -337,6 +334,10 @@ def merge_files(
files_to_merge = self._output_files
elif files is not None:
# Proceed to merge specified by user.
if isinstance(files, str):
# We shouldn't merge a single file?
self.info(f"Got just a single file {files}. Merging skipped.")
return
files_to_merge = files
else:
# Raise error
Expand Down

0 comments on commit 1d942a4

Please sign in to comment.