Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config files refactor, Examples polish, Perturbation in GraphDefinition #603

Merged
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a3b722e
Create MetaClasses to save Model/Dataset configs
AMHermansen Sep 13, 2023
682213c
Remove usage of save_model_config and save_dataset_config
AMHermansen Sep 13, 2023
e4c06fd
Remove redundant config saving baseclasses.
AMHermansen Sep 13, 2023
63a8715
Remove redundant config saving baseclasses.
AMHermansen Sep 13, 2023
9321927
Reintroduce save_model_config and save_dataset_config but added depre…
AMHermansen Sep 14, 2023
4af2872
Fixed typehints for make_(train_validation)_dataloader
AMHermansen Sep 19, 2023
368b9a8
Fixed typehints for make_(train_validation)_dataloader
AMHermansen Sep 19, 2023
396eb77
Update utils.py
AMHermansen Sep 19, 2023
0b73f6f
Merge branch 'graphnet-team:main' into add-ConfSaverMeta
AMHermansen Sep 20, 2023
b846b40
fix example 02-02
RasmusOrsoe Sep 22, 2023
ae226e0
default arguments, fix 02-01
RasmusOrsoe Sep 22, 2023
8e04af2
tito_example update
RasmusOrsoe Sep 22, 2023
15a14f7
Polish examples
RasmusOrsoe Sep 22, 2023
a11014c
delete shell script example
RasmusOrsoe Sep 22, 2023
c840b44
rename examples, update readme.md
RasmusOrsoe Sep 22, 2023
1fbf534
Move perturbations to graph_definition
RasmusOrsoe Sep 22, 2023
d4e166a
minor adjustments, unit test
RasmusOrsoe Sep 22, 2023
e93f514
Unit tests
RasmusOrsoe Sep 22, 2023
b567581
delete perturbedsqlitedataset
RasmusOrsoe Sep 22, 2023
8c54c77
remove old import statements
RasmusOrsoe Sep 22, 2023
6f014fa
replace np.float with float
RasmusOrsoe Sep 22, 2023
726e653
shorten warning
RasmusOrsoe Sep 22, 2023
822c0d7
shorten doc string
RasmusOrsoe Sep 22, 2023
1758b74
shorten error strings
RasmusOrsoe Sep 22, 2023
afeec3e
Replace GenericExtractor in 01-03 for FeatureExtractor
RasmusOrsoe Sep 22, 2023
999e26d
fix typo in readme.md
RasmusOrsoe Sep 22, 2023
03dc5d2
Update code comment in 04-01
RasmusOrsoe Sep 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Reintroduce save_model_config and save_dataset_config but added depre…
…cation warning. To not break backwards compatibility.
AMHermansen committed Sep 14, 2023
commit 9321927ee658225fb76b556849aac88555764879
2 changes: 2 additions & 0 deletions src/graphnet/utilities/config/__init__.py
Original file line number Diff line number Diff line change
@@ -5,10 +5,12 @@
DatasetConfig,
DatasetConfigSaverMeta,
DatasetConfigSaverABCMeta,
save_dataset_config,
)
from .model_config import (
ModelConfig,
ModelConfigSaverMeta,
ModelConfigSaverABC,
save_model_config,
)
from .training_config import TrainingConfig
46 changes: 45 additions & 1 deletion src/graphnet/utilities/config/dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Config classes for the `graphnet.data.dataset` module."""
import warnings
from abc import ABCMeta
from functools import wraps
from typing import (
@@ -178,11 +179,54 @@ def _parse_torch(self, obj: Any) -> Any:
return obj


def save_dataset_config(init_fn: Callable) -> Callable:
"""Save the arguments to `__init__` functions as member `DatasetConfig`."""
warnings.warn(
"Warning: `save_dataset_config` is deprecated. Config saving "
"is now done automatically, for all classes inheriting from Dataset",
DeprecationWarning,
)

def _replace_model_instance_with_config(
obj: Union["Model", Any]
) -> Union[ModelConfig, Any]:
"""Replace `Model` instances in `obj` with their `ModelConfig`."""
from graphnet.models import Model
import torch

if isinstance(obj, Model):
return obj.config

if isinstance(obj, torch.dtype):
return obj.__str__()

else:
return obj

@wraps(init_fn)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
"""Set `DatasetConfig` after calling `init_fn`."""
# Call wrapped method
ret = init_fn(self, *args, **kwargs)

# Get all argument values, including defaults
cfg = get_all_argument_values(init_fn, *args, **kwargs)

# Handle nested `Model`s, etc.
cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)
# Add `DatasetConfig` as member variables
self._config = DatasetConfig(**cfg)

return ret

return wrapper


class DatasetConfigSaverMeta(type):
"""Metaclass for `DatasetConfig` that saves the config after `__init__`."""

def __call__(cls: Any, *args: Any, **kwargs: Any) -> object:
"""Catch object construction and save config after `__init__`."""
"""Catch object after construction and save config."""

def _replace_model_instance_with_config(
obj: Union["Model", Any]
45 changes: 44 additions & 1 deletion src/graphnet/utilities/config/model_config.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from functools import wraps
import inspect
import re
import warnings
from typing import (
TYPE_CHECKING,
Any,
@@ -249,6 +250,48 @@ def as_dict(self) -> Dict[str, Dict[str, Any]]:
return {self.__class__.__name__: config_dict}


def save_model_config(init_fn: Callable) -> Callable:
"""Save the arguments to `__init__` functions as a member `ModelConfig`."""
warnings.warn(
"Warning: `save_model_config` is deprecated. Config saving is"
"now done automatically for all classes inheriting from Model",
DeprecationWarning,
)

def _replace_model_instance_with_config(
obj: Union["Model", Any]
) -> Union[ModelConfig, Any]:
"""Replace `Model` instances in `obj` with their `ModelConfig`."""
from graphnet.models import Model

if isinstance(obj, Model):
return obj.config
else:
return obj

@wraps(init_fn)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
"""Set `ModelConfig` after calling `init_fn`."""
# Call wrapped method
ret = init_fn(self, *args, **kwargs)

# Get all argument values, including defaults
cfg = get_all_argument_values(init_fn, *args, **kwargs)

# Handle nested `Model`s, etc.
cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)

# Add `ModelConfig` as member variables
self._config = ModelConfig(
class_name=str(self.__class__.__name__),
arguments=dict(**cfg),
)

return ret

return wrapper


class ModelConfigSaverMeta(type):
"""Metaclass for saving `ModelConfig` to `Model` instances."""

@@ -275,7 +318,7 @@ def _replace_model_instance_with_config(

# Store config in
created_obj._config = ModelConfig(
class_name=str(created_obj.__class__.__name__),
class_name=str(cls.__name__),
arguments=dict(**cfg),
)
return created_obj