diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py
index c9355bbfc..4a61cce0d 100644
--- a/src/graphnet/data/dataset/dataset.py
+++ b/src/graphnet/data/dataset/dataset.py
@@ -85,6 +85,18 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition:
     return graph_definition
 
 
+def parse_labels(cfg: dict) -> Dict[str, Label]:
+    """Construct Label from DatasetConfig."""
+    assert cfg["labels"] is not None
+
+    labels = {}
+    for key in cfg["labels"].keys():
+        labels[key] = load_module(cfg["labels"][key]["class_name"])(
+            **cfg["labels"][key]["arguments"]
+        )
+    return labels
+
+
 class Dataset(
     Logger,
     Configurable,
@@ -131,6 +143,8 @@ def from_config(  # type: ignore[override]
         cfg = source.dict()
         if cfg["graph_definition"] is not None:
             cfg["graph_definition"] = parse_graph_definition(cfg)
+        if cfg["labels"] is not None:
+            cfg["labels"] = parse_labels(cfg)
         return source._dataset_class(**cfg)
 
     @classmethod
@@ -213,6 +227,7 @@ def __init__(
         loss_weight_column: Optional[str] = None,
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
+        labels: Optional[Dict[str, Any]] = None,
     ):
         """Construct Dataset.
 
@@ -259,6 +274,7 @@ def __init__(
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
             graph_definition: Method that defines the graph representation.
+            labels: Dictionary of labels to be added to the dataset.
         """
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -283,6 +299,7 @@ def __init__(
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
         self._graph_definition = deepcopy(graph_definition)
+        self._labels = labels
 
         if node_truth is not None:
             assert isinstance(node_truth_table, str)
@@ -332,6 +349,10 @@ def __init__(
             seed=seed,
         )
 
+        if self._labels is not None:
+            for key in self._labels.keys():
+                self.add_label(self._labels[key])
+
         # Implementation-specific initialisation.
         self._init()
 
diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py
index 57739b667..42f68dff1 100644
--- a/src/graphnet/utilities/config/dataset_config.py
+++ b/src/graphnet/utilities/config/dataset_config.py
@@ -55,6 +55,7 @@ class DatasetConfig(BaseConfig):
     loss_weight_default_value: Optional[float] = None
     seed: Optional[int] = None
     graph_definition: Any = None
+    labels: Optional[Dict[str, Any]] = None
 
     def __init__(self, **data: Any) -> None:
         """Construct `DataConfig`.