Support using timm optimizers as an alternative to the default
rwightman committed Nov 22, 2024
1 parent 49eac2f commit f88ba88
Showing 2 changed files with 63 additions and 22 deletions.
69 changes: 52 additions & 17 deletions src/open_clip_train/main.py
@@ -1,3 +1,4 @@
import copy
import glob
import logging
import os
@@ -309,22 +310,56 @@ def main(args):
    if args.train_data or args.dataset_type == "synthetic":
        assert not args.trace, 'Cannot train with traced model'

        exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
        include = lambda n, p: not exclude(n, p)

        named_parameters = list(model.named_parameters())
        gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
        rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )
        opt = getattr(args, 'opt', 'adamw').lower()
        if opt.startswith('timm/'):
            from timm.optim import create_optimizer_v2
            timm_opt = opt.split('timm/')[-1]
            opt_kwargs = {}
            assert (args.beta1 is None) == (args.beta2 is None), \
                'When using timm optimizer, BOTH beta1 and beta2 must be specified (or not specified).'
            if args.beta1 is not None:
                opt_kwargs['betas'] = (args.beta1, args.beta2)
            if args.momentum is not None:
                opt_kwargs['momentum'] = args.momentum
            optimizer = create_optimizer_v2(
                model,
                timm_opt,
                lr=args.lr,
                weight_decay=args.wd,
                eps=args.eps,
                **opt_kwargs,
            )
        else:
            # Built-in optimizer path: exclude gains, biases, norm params, and logit_scale from weight decay.
            exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
            include = lambda n, p: not exclude(n, p)

            named_parameters = list(model.named_parameters())
            gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
            rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

            if opt == 'adamw':
                optimizer = optim.AdamW(
                    [
                        {"params": gain_or_bias_params, "weight_decay": 0.},
                        {"params": rest_params, "weight_decay": args.wd},
                    ],
                    lr=args.lr,
                    betas=(args.beta1, args.beta2),
                    eps=args.eps,
                )
            else:
                assert False, f'Unknown optimizer {opt}'

        if is_master(args):
            defaults = copy.deepcopy(optimizer.defaults)
            defaults['weight_decay'] = args.wd
            defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
            logging.info(
                f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
            )

        if args.horovod:
            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
@@ -425,7 +460,7 @@ def main(args):

        if args.grad_checkpointing and args.distributed:
            logging.info('Disabling DDP dynamo optimizer when grad checkpointing enabled.')
            # As of now (~PyTorch 2.4/2.5), compile + checkpointing but DDP optimizer must be disabled
            # As of now (~PyTorch 2.4/2.5), compile + grad checkpointing work, but DDP optimizer must be disabled
            torch._dynamo.config.optimize_ddp = False

        model = torch.compile(original_model)
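To make the new path concrete, below is a minimal, self-contained sketch (not part of the commit) of how the --opt value is resolved: anything after the 'timm/' prefix is handed to timm.optim.create_optimizer_v2, while plain 'adamw' keeps the existing torch.optim.AdamW route. The stand-in model and hyperparameter values are illustrative assumptions, not taken from the diff.

import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2  # requires timm to be installed

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))  # stand-in for the CLIP model

opt = 'timm/adamw'  # e.g. passed on the command line as --opt timm/adamw
if opt.startswith('timm/'):
    # timm's factory resolves the optimizer name and builds it with the given kwargs
    optimizer = create_optimizer_v2(
        model,
        opt.split('timm/')[-1],  # 'adamw'
        lr=5e-4,
        weight_decay=0.2,
        betas=(0.9, 0.98),
        eps=1e-6,
    )
else:
    # built-in default path
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.2)

print(type(optimizer).__name__)

Note that create_optimizer_v2 applies its own no-weight-decay filtering for biases and normalization parameters by default (filter_bias_and_bn=True), which plays roughly the same role as the manual gain_or_bias_params split in the built-in path.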
16 changes: 11 additions & 5 deletions src/open_clip_train/params.py
@@ -143,9 +143,14 @@ def parse_args(args):
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument("--momentum", type=float, default=None, help="Momentum (for timm optimizers).")
parser.add_argument(
"--warmup", type=int, default=10000, help="Number of steps to warmup for."
)
parser.add_argument(
"--opt", type=str, default='adamw',
help="Which optimizer to use. Choices are ['adamw', or any timm optimizer 'timm/{opt_name}']."
)
parser.add_argument(
"--use-bn-sync",
default=False,
@@ -467,10 +472,11 @@ def parse_args(args):

    args = parser.parse_args(args)

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.model)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)
    if 'timm' not in args.opt:
        # set default opt params based on model name (only if timm optimizer not used)
        default_params = get_default_params(args.model)
        for name, val in default_params.items():
            if getattr(args, name) is None:
                setattr(args, name, val)

    return args
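The effect of this change can be checked directly through parse_args; the sketch below (not part of the commit) assumes the existing ViT defaults in get_default_params and is purely illustrative. With a timm optimizer selected, beta1/beta2/eps are left as None unless given explicitly, so the chosen timm optimizer's own defaults apply.

from open_clip_train.params import parse_args

# model-based defaults are skipped when a timm optimizer is selected
timm_args = parse_args(['--model', 'ViT-B-32', '--opt', 'timm/lamb'])
print(timm_args.beta1, timm_args.beta2, timm_args.eps)  # None None None

# built-in AdamW path: ViT defaults from get_default_params fill the gaps
default_args = parse_args(['--model', 'ViT-B-32'])
print(default_args.beta1, default_args.beta2, default_args.eps)  # e.g. 0.9 0.98 1e-06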
