diff --git a/_modules/graphnet/data/dataconverter.html b/_modules/graphnet/data/dataconverter.html index f2911d05e..62267e4f4 100644 --- a/_modules/graphnet/data/dataconverter.html +++ b/_modules/graphnet/data/dataconverter.html @@ -350,6 +350,7 @@

Source code for graphnet.data.dataconverter

 """Contains `DataConverter`."""
+
 from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
 from abc import abstractmethod, ABC
 
@@ -434,6 +435,7 @@ 

Source code for graphnet self._index = 0 self._output_dir = outdir self._output_files: List[str] = [] + self._extension = self._save_method.file_extension # Set Extractors. Will throw error if extractors are incompatible # with reader. @@ -458,8 +460,7 @@

Source code for graphnet self._output_files = [ os.path.join( self._output_dir, - self._create_file_name(file) - + self._save_method.file_extension, + self._create_file_name(file) + self._extension, ) for file in input_files ] @@ -612,16 +613,12 @@

Source code for graphnet # Get new, unique index and increment value if self._num_workers > 1: with global_index.get_lock(): # type: ignore[name-defined] - starting_index = global_index.value # type: ignore[name-defined] - event_nos = np.arange( - starting_index, starting_index + n_ids, 1 - ).tolist() + start_idx = global_index.value # type: ignore[name-defined] + event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist() global_index.value += n_ids # type: ignore[name-defined] else: - starting_index = self._index - event_nos = np.arange( - starting_index, starting_index + n_ids, 1 - ).tolist() + start_idx = self._index + event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist() self._index += n_ids return event_nos @@ -682,7 +679,7 @@

Source code for graphnet [docs] @final def merge_files( - self, files: Optional[List[str]] = None, **kwargs: Any + self, files: Optional[Union[List[str], str]] = None, **kwargs: Any ) -> None: """Merge converted files. @@ -698,6 +695,10 @@

Source code for graphnet files_to_merge = self._output_files elif files is not None: # Proceed to merge specified by user. + if isinstance(files, str): + # We shouldn't merge a single file? + self.info(f"Got just a single file {files}. Merging skipped.") + return files_to_merge = files else: # Raise error diff --git a/_modules/graphnet/training/labels.html b/_modules/graphnet/training/labels.html index 5440be6b8..52ab68d0a 100644 --- a/_modules/graphnet/training/labels.html +++ b/_modules/graphnet/training/labels.html @@ -461,8 +461,9 @@

Source code for graphnet.tr def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" - label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return label.type(torch.int)

+ is_numu = torch.abs(graph[self._pid_key]) == 14 + is_cc = graph[self._int_key] == 1 + return (is_numu & is_cc).type(torch.int) diff --git a/api/graphnet.data.dataconverter.html b/api/graphnet.data.dataconverter.html index 5e4780ba3..87bedba8c 100644 --- a/api/graphnet.data.dataconverter.html +++ b/api/graphnet.data.dataconverter.html @@ -644,7 +644,7 @@
Parameters:
diff --git a/api/graphnet.data.extractors.icecube.i3genericextractor.html b/api/graphnet.data.extractors.icecube.i3genericextractor.html index 2c4745b32..3f05ed94a 100644 --- a/api/graphnet.data.extractors.icecube.i3genericextractor.html +++ b/api/graphnet.data.extractors.icecube.i3genericextractor.html @@ -693,8 +693,8 @@
Parameters:
    -
  • keys (Union[str, List[str], None], default: None) – List of keys in I3Frame to be parsed. Defaults to all keys.

  • -
  • exclude_keys (Union[str, List[str], None], default: None) – List of keys in I3Frame to exclude while parsing.

  • +
  • keys (Union[List[str], str, None], default: None) – List of keys in I3Frame to be parsed. Defaults to all keys.

  • +
  • exclude_keys (Union[List[str], str, None], default: None) – List of keys in I3Frame to exclude while parsing.

  • extractor_name (str)

diff --git a/api/graphnet.models.gnn.dynedge.html b/api/graphnet.models.gnn.dynedge.html index 0a5943eee..dd745a65e 100644 --- a/api/graphnet.models.gnn.dynedge.html +++ b/api/graphnet.models.gnn.dynedge.html @@ -629,7 +629,7 @@ post-processing _and_ optional global pooling. As this is the last layer(s) in the model, the last layer in the read-out yields the output of the DynEdge model. Defaults to [128,].

-
  • global_pooling_schemes (Union[str, List[str], None], default: None) – The list global pooling schemes to use. +

  • global_pooling_schemes (Union[List[str], str, None], default: None) – The list global pooling schemes to use. Options are: “min”, “max”, “mean”, and “sum”.

  • add_global_variables_after_pooling (bool, default: False) – Whether to add global variables after global pooling. The alternative is to added (distribute) diff --git a/api/graphnet.models.task.task.html b/api/graphnet.models.task.task.html index d998c9ab7..fd571afab 100644 --- a/api/graphnet.models.task.task.html +++ b/api/graphnet.models.task.task.html @@ -804,10 +804,10 @@

    Parameters:
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • -
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +

    • target_labels (Union[List[str], str, None], default: None) – Name(s) of the quantity/-ies being predicted, used to extract the target tensor(s) from the Data object in .compute_loss(…).

    • -
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +

    • prediction_labels (Union[List[str], str, None], default: None) – The name(s) of each column that is predicted by the model during inference. If not given, the name will auto matically be set to target_label + _pred.

    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform