diff --git a/tests/test_optim.py b/tests/test_optim.py
index 1ec227248..0d540f66a 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -300,6 +300,8 @@ def test_optim_factory(optimizer):
     lr = (1e-2,) * 4
     if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
         lr = (1e-3,) * 4
+    elif optimizer in ('cmars',):
+        lr = (1e-4,) * 4
 
     try:
         if not opt_info.second_order:  # basic tests don't support second order right now
diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index ccf2b34cb..1384224b7 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -572,6 +572,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults = {'caution': True}
         ),
+        OptimInfo(
+            name='cmars',
+            opt_class=Mars,
+            description='Cautious MARS',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
         OptimInfo(
             name='cnadamw',
             opt_class=NAdamW,
diff --git a/timm/optim/mars.py b/timm/optim/mars.py
index ef1dc4e81..11b1cf206 100644
--- a/timm/optim/mars.py
+++ b/timm/optim/mars.py
@@ -14,38 +14,50 @@
 # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: Apache-2.0
 import math
+from typing import Optional, Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
-
-def mars_single_tensor(
-        p,
-        grad,
-        exp_avg,
-        exp_avg_sq,
-        lr,
-        weight_decay,
-        beta1,
-        beta2,
-        last_grad,
-        eps,
-        step,
-        gamma,
-        mars_type,
-        is_grad_2d,
-        optimize_1d,
-        lr_1d_factor,
-        betas_1d,
+from ._types import ParamsT
+
+
+def _mars_single_tensor_step(
+        p: torch.Tensor,
+        grad: torch.Tensor,
+        exp_avg: torch.Tensor,
+        exp_avg_sq: torch.Tensor,
+        lr: float,
+        weight_decay: float,
+        beta1: float,
+        beta2: float,
+        last_grad: torch.Tensor,
+        eps: float,
+        step: int,
+        gamma: float,
+        mars_type: str,
+        is_grad_2d: bool,
+        optimize_1d: bool,
+        lr_1d_factor: bool,
+        betas_1d: Tuple[float, float],
+        caution: bool,
 ):
-    # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
+    # optimize_1d ==> use MARS for 1d param, else use AdamW
     if optimize_1d or is_grad_2d:
         one_minus_beta1 = 1. - beta1
-        c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
-        c_t_norm = torch.norm(c_t)
-        if c_t_norm > 1.:
-            c_t = c_t / c_t_norm
+        if step == 1:
+            # this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
+            c_t = grad
+        else:
+            c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
+            c_t_norm = torch.norm(c_t)
+            if c_t_norm > 1.:
+                c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
         bias_correction1 = 1.0 - beta1_1d ** step
         bias_correction2 = 1.0 - beta2_1d ** step
         denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
         p.add_(update, alpha=-(lr * lr_1d_factor))
     return exp_avg, exp_avg_sq
@@ -78,16 +94,17 @@ class Mars(Optimizer):
     """
     def __init__(
            self,
-            params,
-            lr=3e-3,
-            betas=(0.9, 0.99),
-            eps=1e-8,
-            weight_decay=0.,
-            gamma=0.025,
-            mars_type="adamw",
-            optimize_1d=False,
-            lr_1d_factor=1.0,
-            betas_1d=None,
+            params: ParamsT,
+            lr: float = 3e-3,
+            betas: Tuple[float, float] = (0.9, 0.99),
+            eps: float = 1e-8,
+            weight_decay: float = 0.,
+            gamma: float = 0.025,
+            mars_type: str = "adamw",
+            optimize_1d: bool = False,
+            lr_1d_factor: float = 1.0,
+            betas_1d: Optional[Tuple[float, float]] = None,
+            caution: bool = False
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -109,9 +126,15 @@ def __init__(
             optimize_1d=optimize_1d,
             lr_1d_factor=lr_1d_factor,
             betas_1d=betas_1d or betas,
+            caution=caution,
         )
         super(Mars, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super(Mars, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -134,7 +157,6 @@ def step(self, closure=None):
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
 
                 state = self.state[p]
-                # ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
                 # State initialization
                 if len(state) <= 1:
                     state['step'] = 0
@@ -155,7 +177,8 @@ def step(self, closure=None):
                 beta1, beta2 = group['betas']
                 is_grad_2d = grad.ndim >= 2
 
-                mars_single_tensor(
+                # FIXME add multi-tensor (if usage warrants), make more standard
+                _mars_single_tensor_step(
                     p,
                     grad,
                     exp_avg,
@@ -173,6 +196,7 @@ def step(self, closure=None):
                     optimize_1d=group['optimize_1d'],
                     lr_1d_factor=group['lr_1d_factor'],
                     betas_1d=group['betas_1d'],
+                    caution=group['caution'],
                 )
 
                 state['last_grad'] = grad
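
Usage note (not part of the diff above): a minimal sketch of how the new caution flag and the 'cmars' factory name could be exercised once these changes land. The toy model, tensor shapes, and hyperparameter values are placeholders chosen for illustration, not taken from the patch.

import torch
import torch.nn as nn

from timm.optim import create_optimizer_v2
from timm.optim.mars import Mars

# placeholder model and data purely for demonstration
model = nn.Linear(10, 2)
x, y = torch.randn(8, 10), torch.randn(8, 2)

# direct construction with the cautious update masking enabled
opt_direct = Mars(model.parameters(), lr=1e-4, caution=True)

# equivalent construction via the optimizer factory using the new 'cmars' registration
opt_factory = create_optimizer_v2(model, opt='cmars', lr=1e-4, weight_decay=0.01)

# one optimization step with the factory-built optimizer
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt_factory.step()
opt_factory.zero_grad()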