
Commit 311f92b

yo

snarayan21 committed Sep 30, 2024
1 parent 36cc16a commit 311f92b
Showing 3 changed files with 78 additions and 0 deletions.
3 changes: 3 additions & 0 deletions llmfoundry/optim/__init__.py
@@ -10,13 +10,15 @@

from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion
from llmfoundry.optim.lion import DecoupledLionW
from llmfoundry.optim.no_op import NoOp
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
from llmfoundry.registry import optimizers, schedulers

optimizers.register('adalr_lion', func=DecoupledAdaLRLion)
optimizers.register('clip_lion', func=DecoupledClipLion)
optimizers.register('decoupled_lionw', func=DecoupledLionW)
optimizers.register('decoupled_adamw', func=DecoupledAdamW)
optimizers.register('no_op', func=NoOp)

schedulers.register('constant_with_warmup', func=ConstantWithWarmupScheduler)
schedulers.register(
@@ -33,5 +35,6 @@
'DecoupledLionW',
'DecoupledClipLion',
'DecoupledAdaLRLion',
'NoOp',
'InverseSquareRootWithWarmupScheduler',
]
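With the 'no_op' entry registered above, the optimizer can be built by name through llmfoundry's builder, mirroring the call in the new test below. A minimal sketch (not part of this commit), assuming build_optimizer accepts any torch.nn.Module; the torch.nn.Linear model is purely a stand-in:

import torch

from llmfoundry.utils.builders import build_optimizer

# Any module with parameters stands in for the real model here.
model = torch.nn.Linear(8, 8)

# The 'no_op' name resolves to the NoOp class through the optimizers registry.
optimizer = build_optimizer(model, 'no_op', optimizer_config={})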
32 changes: 32 additions & 0 deletions llmfoundry/optim/no_op.py
@@ -0,0 +1,32 @@
from typing import Any, Callable, Iterable, Optional

import torch


class NoOp(torch.optim.Optimizer):
    """Optimizer that performs no optimization: parameters are never updated."""

    def __init__(
        self,
        params: Iterable[torch.Tensor],
    ):
        # LR schedulers expect param groups to have an LR, so set a dummy value. Unused.
        defaults = {'lr': 0.0}
        super().__init__(params, defaults)

    def __setstate__(self, state: dict[str, dict[Any, Any]]) -> None:
        super().__setstate__(state)

    def state_dict(self):
        return super().state_dict()

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Perform a no-op optimization step where no parameters are updated.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        return loss
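As a quick behavioral check (not part of this commit), constructing NoOp directly on a small module and calling step() after a backward pass leaves every parameter untouched. A minimal sketch assuming only the class above and plain PyTorch:

import copy

import torch

linear = torch.nn.Linear(2, 2)
opt = NoOp(linear.parameters())
before = copy.deepcopy(list(linear.parameters()))

# Populate gradients, then take a "step" that intentionally does nothing.
linear(torch.randn(4, 2)).sum().backward()
opt.step()

for p_before, p_after in zip(before, linear.parameters()):
    assert torch.equal(p_before, p_after)  # parameters are unchanged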
43 changes: 43 additions & 0 deletions tests/optim/test_no_op.py
@@ -0,0 +1,43 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import copy
from typing import Callable

import torch
from composer.trainer import Trainer
from torch.utils.data import DataLoader

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import build_optimizer


def test_no_op_does_nothing(
    build_tiny_mpt: Callable[..., ComposerMPTCausalLM],
    tiny_ft_dataloader: DataLoader,
):
    # Build MPT model
    model = build_tiny_mpt(
        loss_fn='torch_crossentropy',
        attn_config={
            'attn_impl': 'torch',
        },
    )

    # Build NoOp optimizer
    no_op_optim = build_optimizer(model, 'no_op', optimizer_config={})

    orig_model = copy.deepcopy(model)

    # Build trainer and train for two batches
    trainer = Trainer(
        model=model,
        train_dataloader=tiny_ft_dataloader,
        max_duration='2ba',
        optimizers=no_op_optim,
    )
    trainer.fit()

    # Check that no parameter has changed after training
    for (orig_name, orig_param), (new_name, new_param) in zip(
        orig_model.named_parameters(),
        model.named_parameters(),
    ):
        print(f'Checking {orig_name} and {new_name}')
        assert torch.equal(orig_param, new_param)

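Assuming the repository's standard pytest setup, with the build_tiny_mpt and tiny_ft_dataloader fixtures provided by the shared test conftest, the new test can be run on its own with:

pytest tests/optim/test_no_op.py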