diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py index b3e9b9b50..a53da6d14 100644 --- a/src/open_clip_train/main.py +++ b/src/open_clip_train/main.py @@ -1,3 +1,4 @@ +import copy import glob import logging import os @@ -309,22 +310,56 @@ def main(args): if args.train_data or args.dataset_type == "synthetic": assert not args.trace, 'Cannot train with traced model' - exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n - include = lambda n, p: not exclude(n, p) - - named_parameters = list(model.named_parameters()) - gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] - rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] - - optimizer = optim.AdamW( - [ - {"params": gain_or_bias_params, "weight_decay": 0.}, - {"params": rest_params, "weight_decay": args.wd}, - ], - lr=args.lr, - betas=(args.beta1, args.beta2), - eps=args.eps, - ) + opt = getattr(args, 'opt', 'adamw').lower() + if opt.startswith('timm/'): + from timm.optim import create_optimizer_v2 + timm_opt = opt.split('timm/')[-1] + opt_kwargs = {} + assert (args.beta1 is None) == (args.beta2 is None), \ + 'When using timm optimizer, BOTH beta1 and beta2 must be specified (or not specified).' + if args.beta1 is not None: + opt_kwargs['betas'] = (args.beta1, args.beta2) + if args.momentum is not None: + opt_kwargs['momentum'] = args.momentum + optimizer = create_optimizer_v2( + model, + timm_opt, + lr=args.lr, + weight_decay=args.wd, + eps=args.eps, + **opt_kwargs, + ) + else: + # If some params are not passed, we use the default values based on model name. + exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + if opt == 'adamw': + optimizer = optim.AdamW( + [ + {"params": gain_or_bias_params, "weight_decay": 0.}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + ) + else: + assert False, f'Unknown optimizer {opt}' + + if is_master(args): + if is_master(args): + defaults = copy.deepcopy(optimizer.defaults) + defaults['weight_decay'] = args.wd + defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()]) + logging.info( + f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}' + ) + if args.horovod: optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) hvd.broadcast_parameters(model.state_dict(), root_rank=0) @@ -425,7 +460,7 @@ def main(args): if args.grad_checkpointing and args.distributed: logging.info('Disabling DDP dynamo optimizer when grad checkpointing enabled.') - # As of now (~PyTorch 2.4/2.5), compile + checkpointing but DDP optimizer must be disabled + # As of now (~PyTorch 2.4/2.5), compile + grad checkpointing work, but DDP optimizer must be disabled torch._dynamo.config.optimize_ddp = False model = torch.compile(original_model) diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py index 2d94b7e21..2cf5ded30 100644 --- a/src/open_clip_train/params.py +++ b/src/open_clip_train/params.py @@ -143,9 +143,14 @@ def parse_args(args): parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + parser.add_argument("--momentum", type=float, default=None, help="Momentum (for timm optimizers).") parser.add_argument( "--warmup", type=int, default=10000, help="Number of steps to warmup for." ) + parser.add_argument( + "--opt", type=str, default='adamw', + help="Which optimizer to use. Choices are ['adamw', or any timm optimizer 'timm/{opt_name}']." + ) parser.add_argument( "--use-bn-sync", default=False, @@ -467,10 +472,11 @@ def parse_args(args): args = parser.parse_args(args) - # If some params are not passed, we use the default values based on model name. - default_params = get_default_params(args.model) - for name, val in default_params.items(): - if getattr(args, name) is None: - setattr(args, name, val) + if 'timm' not in args.opt: + # set default opt params based on model name (only if timm optimizer not used) + default_params = get_default_params(args.model) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) return args