diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 69d13be50..57f6005f6 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -89,11 +89,14 @@ def __call__(self, input_dir: Union[str, List[str]]) -> None: # in the directory input_files = self._file_reader.find_files(path=input_dir) self._launch_jobs(input_files=input_files) - self._output_files = glob( + self._output_files = [ os.path.join( - self._output_dir, f"*{self._save_method.file_extension}" + self._output_dir, + self._create_file_name(file) + + self._save_method.file_extension, ) - ) + for file in input_files + ] @final def _launch_jobs( @@ -159,9 +162,10 @@ def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: if isinstance(input_file_path, I3FileSet): input_file_path = input_file_path.i3_file file_name = os.path.basename(input_file_path) - index_of_dot = file_name.index(".") - file_name_without_extension = file_name[:index_of_dot] - return file_name_without_extension + for ext in self._file_reader._accepted_file_extensions: + if file_name.endswith(ext): + file_name_without_extension = file_name.replace(ext, "") + return file_name_without_extension.replace(".i3", "") @final def _assign_event_no( diff --git a/src/graphnet/data/extractors/icecube/i3genericextractor.py b/src/graphnet/data/extractors/icecube/i3genericextractor.py index e907181d0..c79b7329b 100644 --- a/src/graphnet/data/extractors/icecube/i3genericextractor.py +++ b/src/graphnet/data/extractors/icecube/i3genericextractor.py @@ -45,6 +45,7 @@ def __init__( self, keys: Optional[Union[str, List[str]]] = None, exclude_keys: Optional[Union[str, List[str]]] = None, + extractor_name: str = GENERIC_EXTRACTOR_NAME, ): """Construct I3GenericExtractor. @@ -72,7 +73,7 @@ def __init__( self._exclude_keys: Optional[List[str]] = exclude_keys # Base class constructor - super().__init__(GENERIC_EXTRACTOR_NAME) + super().__init__(extractor_name) def _get_keys(self, frame: "icetray.I3Frame") -> List[str]: """Get the list of keys to be queried from `frame`. @@ -170,6 +171,12 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]: # Flatten all other objects else: results[key] = self._flatten_result(result) + if ( + isinstance(results[key], dict) + and "value" in results[key] + and len(results[key]) == 1 + ): + results[key] = results[key]["value"] # Serialise list of iterables to JSON results = {key: serialise(value) for key, value in results.items()} diff --git a/src/graphnet/data/writers/sqlite_writer.py b/src/graphnet/data/writers/sqlite_writer.py index d7cc48297..ab8d95051 100644 --- a/src/graphnet/data/writers/sqlite_writer.py +++ b/src/graphnet/data/writers/sqlite_writer.py @@ -181,6 +181,7 @@ def _merge_databases( database_path=database_path, index_column=primary_key, integer_primary_key=integer_primary_key, + default_type="FLOAT", ) # Update row counts if needed