Skip to content

Commit

Permalink
overwrite previous changes to DataConverter
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed Sep 13, 2024
1 parent 40aee1f commit 2c1d202
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains `DataConverter`."""

from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
from abc import ABC

Expand Down Expand Up @@ -260,8 +261,8 @@ def _request_event_nos(self, n_ids: int) -> List[int]:
event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
global_index.value += n_ids # type: ignore[name-defined]
else:
starting_index = self._index
event_nos = np.arange(starting_index, starting_index + n_ids, 1).tolist()
start_idx = self._index
event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
self._index += n_ids

return event_nos
Expand Down Expand Up @@ -316,7 +317,9 @@ def _update_shared_variables(
self._output_files.extend(list(sorted(output_files[:])))

@final
def merge_files(self, files: Optional[List[str]] = None, **kwargs: Any) -> None:
def merge_files(
self, files: Optional[Union[List[str], str]] = None, **kwargs: Any
) -> None:
"""Merge converted files.
`DataConverter` will call the `.merge_files` method in the
Expand All @@ -332,7 +335,9 @@ def merge_files(self, files: Optional[List[str]] = None, **kwargs: Any) -> None:
elif files is not None:
# Proceed to merge specified by user.
if isinstance(files, str):
files = [files] # Cast to list if user forgot
# We shouldn't merge a single file?
self.info(f"Got just a single file {files}. Merging skipped.")
return
files_to_merge = files
else:
# Raise error
Expand Down

0 comments on commit 2c1d202

Please sign in to comment.