Skip to content

Commit

Permalink
add parsing of labels
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Feb 5, 2024
1 parent b686dd7 commit 5779255
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/graphnet/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition:
return graph_definition


def parse_labels(cfg: dict) -> Dict[str, Label]:
"""Construct Label from DatasetConfig."""
assert cfg["labels"] is not None

labels = {}
for key in cfg["labels"].keys():
labels[key] = load_module(cfg["labels"][key]["class_name"])(
**cfg["labels"][key]["arguments"]
)
return labels


class Dataset(
Logger,
Configurable,
Expand Down Expand Up @@ -131,6 +143,8 @@ def from_config( # type: ignore[override]
cfg = source.dict()
if cfg["graph_definition"] is not None:
cfg["graph_definition"] = parse_graph_definition(cfg)
if cfg["labels"] is not None:
cfg["labels"] = parse_labels(cfg)
return source._dataset_class(**cfg)

@classmethod
Expand Down Expand Up @@ -213,6 +227,7 @@ def __init__(
loss_weight_column: Optional[str] = None,
loss_weight_default_value: Optional[float] = None,
seed: Optional[int] = None,
labels: Optional[Dict[str, Any]] = None,
):
"""Construct Dataset.
Expand Down Expand Up @@ -259,6 +274,7 @@ def __init__(
`"10000 random events ~ event_no % 5 > 0"` or `"20% random
events ~ event_no % 5 > 0"`).
graph_definition: Method that defines the graph representation.
labels: Dictionary of labels to be added to the dataset.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
Expand All @@ -283,6 +299,7 @@ def __init__(
self._truth_table = truth_table
self._loss_weight_default_value = loss_weight_default_value
self._graph_definition = deepcopy(graph_definition)
self._labels = labels

if node_truth is not None:
assert isinstance(node_truth_table, str)
Expand Down Expand Up @@ -332,6 +349,10 @@ def __init__(
seed=seed,
)

if self._labels is not None:
for key in self._labels.keys():
self.add_label(self._labels[key])

# Implementation-specific initialisation.
self._init()

Expand Down
1 change: 1 addition & 0 deletions src/graphnet/utilities/config/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DatasetConfig(BaseConfig):
loss_weight_default_value: Optional[float] = None
seed: Optional[int] = None
graph_definition: Any = None
labels: Optional[Dict[str, Any]] = None

def __init__(self, **data: Any) -> None:
"""Construct `DataConfig`.
Expand Down

0 comments on commit 5779255

Please sign in to comment.