diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index eee5f45b1..3727df1dc 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -21,8 +21,13 @@ from .sgdp import SGDP from .sgdw import SGDW -# bring torch optim into timm.optim namespace for consistency -from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD +# bring common torch.optim Optimizers into timm.optim namespace for consistency +from torch.optim import Adadelta, Adagrad, Adamax, Adam, AdamW, RMSprop, SGD +try: + # in case any very old torch versions being used + from torch.optim import NAdam, RAdam +except ImportError: + pass from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \ create_optimizer_v2, create_optimizer, optimizer_kwargs