Skip to content

Commit

Permalink
Allows for adding labels in dataset config
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Jan 22, 2024
1 parent d511f1c commit eb6d974
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/graphnet/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition:
return graph_definition


def parse_labels(cfg: dict) -> Dict[str, Label]:
"""Construct Label from DatasetConfig."""
assert cfg["labels"] is not None

labels = {}
for key in cfg["labels"].keys():
labels[key] = load_module(cfg["labels"][key]["class_name"])(
**cfg["labels"][key]["arguments"]
)
return labels


class Dataset(
Logger,
Configurable,
Expand Down Expand Up @@ -147,6 +159,8 @@ def from_config( # type: ignore[override]
cfg = source.dict()
if cfg["graph_definition"] is not None:
cfg["graph_definition"] = parse_graph_definition(cfg)
if cfg["labels"] is not None:
cfg["labels"] = parse_labels(cfg)
return source._dataset_class(**cfg)

@classmethod
Expand Down Expand Up @@ -230,6 +244,7 @@ def __init__(
loss_weight_default_value: Optional[float] = None,
seed: Optional[int] = None,
use_cache: bool = True,
labels: Optional[Dict[str, Any]] = None,
):
"""Construct Dataset.
Expand Down Expand Up @@ -277,6 +292,7 @@ def __init__(
events ~ event_no % 5 > 0"`).
graph_definition: Method that defines the graph representation.
use_cache: Whether or not to save indices and selections to a temporary cache for faster initializing.
labels: Dictionary of labels to be added to the dataset.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
Expand All @@ -301,6 +317,7 @@ def __init__(
self._truth_table = truth_table
self._loss_weight_default_value = loss_weight_default_value
self._graph_definition = deepcopy(graph_definition)
self._labels = labels

if node_truth is not None:
assert isinstance(node_truth_table, str)
Expand Down Expand Up @@ -351,6 +368,9 @@ def __init__(
use_cache=use_cache,
)

if self._labels is not None:
for key in self._labels.keys():
self.add_label(self._labels[key])
# Implementation-specific initialisation.
self._init()

Expand Down
1 change: 1 addition & 0 deletions src/graphnet/utilities/config/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class DatasetConfig(BaseConfig):
seed: Optional[int] = None
graph_definition: Any = None
use_cache: bool = True
labels: Optional[Dict[str, Any]] = None

def __init__(self, **data: Any) -> None:
"""Construct `DataConfig`.
Expand Down

0 comments on commit eb6d974

Please sign in to comment.