diff --git a/trojanzoo/models.py b/trojanzoo/models.py index 9f5d6acc..97afc2ae 100644 --- a/trojanzoo/models.py +++ b/trojanzoo/models.py @@ -957,7 +957,7 @@ def get_official_weights(self, weights: WeightsEnum | None = None, OrderedDict[str, torch.Tensor]: The model weights OrderedDict. """ if weights is None: - weights = getattr(self.weights, self.name) + weights = getattr(self.weights[self.name], 'DEFAULT') return weights.get_state_dict(progress=progress, check_hash=True, map_location=map_location, **kwargs) diff --git a/trojanzoo/utils/module/param.py b/trojanzoo/utils/module/param.py index 04050e4e..725afafb 100644 --- a/trojanzoo/utils/module/param.py +++ b/trojanzoo/utils/module/param.py @@ -99,7 +99,7 @@ def items(self): return self.__data.items() def __getattr__(self, name: str) -> _VT: - if '__data' in name: + if name.startswith('__'): return super().__getattr__(name) return self.__data[name]