diff --git a/_modules/graphnet/data/dataconverter.html b/_modules/graphnet/data/dataconverter.html index f2911d05e..62267e4f4 100644 --- a/_modules/graphnet/data/dataconverter.html +++ b/_modules/graphnet/data/dataconverter.html @@ -350,6 +350,7 @@
"""Contains `DataConverter`."""
+
from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
from abc import abstractmethod, ABC
@@ -434,6 +435,7 @@ Source code for graphnet
self._index = 0
self._output_dir = outdir
self._output_files: List[str] = []
+ self._extension = self._save_method.file_extension
# Set Extractors. Will throw error if extractors are incompatible
# with reader.
@@ -458,8 +460,7 @@ Source code for graphnet
self._output_files = [
os.path.join(
self._output_dir,
- self._create_file_name(file)
- + self._save_method.file_extension,
+ self._create_file_name(file) + self._extension,
)
for file in input_files
]
@@ -612,16 +613,12 @@ Source code for graphnet
# Get new, unique index and increment value
if self._num_workers > 1:
with global_index.get_lock(): # type: ignore[name-defined]
- starting_index = global_index.value # type: ignore[name-defined]
- event_nos = np.arange(
- starting_index, starting_index + n_ids, 1
- ).tolist()
+ start_idx = global_index.value # type: ignore[name-defined]
+ event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
global_index.value += n_ids # type: ignore[name-defined]
else:
- starting_index = self._index
- event_nos = np.arange(
- starting_index, starting_index + n_ids, 1
- ).tolist()
+ start_idx = self._index
+ event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
self._index += n_ids
return event_nos
@@ -682,7 +679,7 @@ Source code for graphnet
[docs]
@final
def merge_files(
- self, files: Optional[List[str]] = None, **kwargs: Any
+ self, files: Optional[Union[List[str], str]] = None, **kwargs: Any
) -> None:
"""Merge converted files.
@@ -698,6 +695,10 @@ Source code for graphnet
files_to_merge = self._output_files
elif files is not None:
# Proceed to merge specified by user.
+ if isinstance(files, str):
+ # We shouldn't merge a single file?
+ self.info(f"Got just a single file {files}. Merging skipped.")
+ return
files_to_merge = files
else:
# Raise error
diff --git a/_modules/graphnet/training/labels.html b/_modules/graphnet/training/labels.html
index 5440be6b8..52ab68d0a 100644
--- a/_modules/graphnet/training/labels.html
+++ b/_modules/graphnet/training/labels.html
@@ -461,8 +461,9 @@ Source code for graphnet.tr
def __call__(self, graph: Data) -> torch.tensor:
"""Compute label for `graph`."""
- label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1)
- return label.type(torch.int)
files (Optional
[List
[str
]], default: None
) – Intermediate files to be merged.
files (Union
[List
[str
], str
, None
], default: None
) – Intermediate files to be merged.
kwargs (Any)
keys (Union
[str
, List
[str
], None
], default: None
) – List of keys in I3Frame to be parsed. Defaults to all keys.
exclude_keys (Union
[str
, List
[str
], None
], default: None
) – List of keys in I3Frame to exclude while parsing.
keys (Union
[List
[str
], str
, None
], default: None
) – List of keys in I3Frame to be parsed. Defaults to all keys.
exclude_keys (Union
[List
[str
], str
, None
], default: None
) – List of keys in I3Frame to exclude while parsing.
extractor_name (str)
global_pooling_schemes (Union
[str
, List
[str
], None
], default: None
) – The list global pooling schemes to use.
+
global_pooling_schemes (Union
[List
[str
], str
, None
], default: None
) – The list global pooling schemes to use.
Options are: “min”, “max”, “mean”, and “sum”.
add_global_variables_after_pooling (bool
, default: False
) – Whether to add global variables
after global pooling. The alternative is to added (distribute)
diff --git a/api/graphnet.models.task.task.html b/api/graphnet.models.task.task.html
index d998c9ab7..fd571afab 100644
--- a/api/graphnet.models.task.task.html
+++ b/api/graphnet.models.task.task.html
@@ -804,10 +804,10 @@
loss_function (LossFunction
) – Loss function appropriate to the task.
target_labels (Union
[str
, List
[str
], None
], default: None
) – Name(s) of the quantity/-ies being predicted, used
+
target_labels (Union
[List
[str
], str
, None
], default: None
) – Name(s) of the quantity/-ies being predicted, used
to extract the target tensor(s) from the Data object in
.compute_loss(…).
prediction_labels (Union
[str
, List
[str
], None
], default: None
) – The name(s) of each column that is predicted by
+
prediction_labels (Union
[List
[str
], str
, None
], default: None
) – The name(s) of each column that is predicted by
the model during inference. If not given, the name will auto
matically be set to target_label + _pred.
transform_prediction_and_target (Optional
[Callable
], default: None
) – Optional function to transform