Skip to content

Commit

Permalink
Reintroduce save_model_config and save_dataset_config but added depre…
Browse files Browse the repository at this point in the history
…cation warning. To not break backwards compatibility.
  • Loading branch information
AMHermansen committed Sep 14, 2023
1 parent 63a8715 commit 9321927
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/graphnet/utilities/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
DatasetConfig,
DatasetConfigSaverMeta,
DatasetConfigSaverABCMeta,
save_dataset_config,
)
from .model_config import (
ModelConfig,
ModelConfigSaverMeta,
ModelConfigSaverABC,
save_model_config,
)
from .training_config import TrainingConfig
46 changes: 45 additions & 1 deletion src/graphnet/utilities/config/dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Config classes for the `graphnet.data.dataset` module."""
import warnings
from abc import ABCMeta
from functools import wraps
from typing import (
Expand Down Expand Up @@ -178,11 +179,54 @@ def _parse_torch(self, obj: Any) -> Any:
return obj


def save_dataset_config(init_fn: Callable) -> Callable:
"""Save the arguments to `__init__` functions as member `DatasetConfig`."""
warnings.warn(
"Warning: `save_dataset_config` is deprecated. Config saving "
"is now done automatically, for all classes inheriting from Dataset",
DeprecationWarning,
)

def _replace_model_instance_with_config(
obj: Union["Model", Any]
) -> Union[ModelConfig, Any]:
"""Replace `Model` instances in `obj` with their `ModelConfig`."""
from graphnet.models import Model
import torch

if isinstance(obj, Model):
return obj.config

if isinstance(obj, torch.dtype):
return obj.__str__()

else:
return obj

@wraps(init_fn)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
"""Set `DatasetConfig` after calling `init_fn`."""
# Call wrapped method
ret = init_fn(self, *args, **kwargs)

# Get all argument values, including defaults
cfg = get_all_argument_values(init_fn, *args, **kwargs)

# Handle nested `Model`s, etc.
cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)
# Add `DatasetConfig` as member variables
self._config = DatasetConfig(**cfg)

return ret

return wrapper


class DatasetConfigSaverMeta(type):
"""Metaclass for `DatasetConfig` that saves the config after `__init__`."""

def __call__(cls: Any, *args: Any, **kwargs: Any) -> object:
"""Catch object construction and save config after `__init__`."""
"""Catch object after construction and save config."""

def _replace_model_instance_with_config(
obj: Union["Model", Any]
Expand Down
45 changes: 44 additions & 1 deletion src/graphnet/utilities/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import wraps
import inspect
import re
import warnings
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -249,6 +250,48 @@ def as_dict(self) -> Dict[str, Dict[str, Any]]:
return {self.__class__.__name__: config_dict}


def save_model_config(init_fn: Callable) -> Callable:
"""Save the arguments to `__init__` functions as a member `ModelConfig`."""
warnings.warn(
"Warning: `save_model_config` is deprecated. Config saving is"
"now done automatically for all classes inheriting from Model",
DeprecationWarning,
)

def _replace_model_instance_with_config(
obj: Union["Model", Any]
) -> Union[ModelConfig, Any]:
"""Replace `Model` instances in `obj` with their `ModelConfig`."""
from graphnet.models import Model

if isinstance(obj, Model):
return obj.config
else:
return obj

@wraps(init_fn)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
"""Set `ModelConfig` after calling `init_fn`."""
# Call wrapped method
ret = init_fn(self, *args, **kwargs)

# Get all argument values, including defaults
cfg = get_all_argument_values(init_fn, *args, **kwargs)

# Handle nested `Model`s, etc.
cfg = traverse_and_apply(cfg, _replace_model_instance_with_config)

# Add `ModelConfig` as member variables
self._config = ModelConfig(
class_name=str(self.__class__.__name__),
arguments=dict(**cfg),
)

return ret

return wrapper


class ModelConfigSaverMeta(type):
"""Metaclass for saving `ModelConfig` to `Model` instances."""

Expand All @@ -275,7 +318,7 @@ def _replace_model_instance_with_config(

# Store config in
created_obj._config = ModelConfig(
class_name=str(created_obj.__class__.__name__),
class_name=str(cls.__name__),
arguments=dict(**cfg),
)
return created_obj
Expand Down

0 comments on commit 9321927

Please sign in to comment.