diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index b4018daaf..ccb964d60 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -64,9 +64,13 @@ def load_state_dict( # DEPRECATION UTILITY: REMOVE AT 2.0 LAUNCH # See https://github.com/graphnet-team/graphnet/issues/647 - state_dict = rename_state_dict_entries( + state_dict, state_dict_altered = rename_state_dict_entries( state_dict=state_dict, old_phrase="_gnn", new_phrase="backbone" ) + if state_dict_altered: + self.warn( + "DeprecationWarning: State dicts with `_gnn` entries will be deprecated in GraphNeT 2.0" + ) return super().load_state_dict(state_dict, **kargs) @classmethod diff --git a/src/graphnet/utilities/deprecation_tools.py b/src/graphnet/utilities/deprecation_tools.py index e0d5b6d41..3ba051aba 100644 --- a/src/graphnet/utilities/deprecation_tools.py +++ b/src/graphnet/utilities/deprecation_tools.py @@ -1,12 +1,12 @@ """Utility functions for handling deprecation transitions.""" -from typing import Dict +from typing import Dict, Tuple from copy import deepcopy from torch import Tensor def rename_state_dict_entries( state_dict: Dict[str, Tensor], old_phrase: str, new_phrase: str -) -> Dict[str, Tensor]: +) -> Tuple[Dict[str, Tensor], bool]: """Replace `old_phrase` in state dict fields with `new_phrase`. Returned state dict is a deepcopy of the input. @@ -23,8 +23,11 @@ def rename_state_dict_entries( new_state_dict = deepcopy(state_dict) # Replace old entries in copy + state_dict_altered = False for key in state_dict.keys(): if old_phrase in key: new_key = key.replace(old_phrase, new_phrase) new_state_dict[new_key] = new_state_dict.pop(key) - return new_state_dict + state_dict_altered = True + + return new_state_dict, state_dict_altered