Skip to content

Commit

Permalink
Merge branch 'graphnet-team:main' into RNN_Attention
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted authored Feb 22, 2024
2 parents a0b7262 + 9735f51 commit 5d0f4b5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
16 changes: 10 additions & 6 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,14 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None:
# in the directory
input_files = self._file_reader.find_files(path=input_dir)
self._launch_jobs(input_files=input_files)
self._output_files = glob(
self._output_files = [
os.path.join(
self._output_dir, f"*{self._save_method.file_extension}"
self._output_dir,
self._create_file_name(file)
+ self._save_method.file_extension,
)
)
for file in input_files
]

@final
def _launch_jobs(
Expand Down Expand Up @@ -159,9 +162,10 @@ def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str:
if isinstance(input_file_path, I3FileSet):
input_file_path = input_file_path.i3_file
file_name = os.path.basename(input_file_path)
index_of_dot = file_name.index(".")
file_name_without_extension = file_name[:index_of_dot]
return file_name_without_extension
for ext in self._file_reader._accepted_file_extensions:
if file_name.endswith(ext):
file_name_without_extension = file_name.replace(ext, "")
return file_name_without_extension.replace(".i3", "")

@final
def _assign_event_no(
Expand Down
9 changes: 8 additions & 1 deletion src/graphnet/data/extractors/icecube/i3genericextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
self,
keys: Optional[Union[str, List[str]]] = None,
exclude_keys: Optional[Union[str, List[str]]] = None,
extractor_name: str = GENERIC_EXTRACTOR_NAME,
):
"""Construct I3GenericExtractor.
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(
self._exclude_keys: Optional[List[str]] = exclude_keys

# Base class constructor
super().__init__(GENERIC_EXTRACTOR_NAME)
super().__init__(extractor_name)

def _get_keys(self, frame: "icetray.I3Frame") -> List[str]:
"""Get the list of keys to be queried from `frame`.
Expand Down Expand Up @@ -170,6 +171,12 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
# Flatten all other objects
else:
results[key] = self._flatten_result(result)
if (
isinstance(results[key], dict)
and "value" in results[key]
and len(results[key]) == 1
):
results[key] = results[key]["value"]

# Serialise list of iterables to JSON
results = {key: serialise(value) for key, value in results.items()}
Expand Down
1 change: 1 addition & 0 deletions src/graphnet/data/writers/sqlite_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _merge_databases(
database_path=database_path,
index_column=primary_key,
integer_primary_key=integer_primary_key,
default_type="FLOAT",
)

# Update row counts if needed
Expand Down

0 comments on commit 5d0f4b5

Please sign in to comment.