diff --git a/llmfoundry/optim/__init__.py b/llmfoundry/optim/__init__.py
index 0b55944338..ce93487aef 100644
--- a/llmfoundry/optim/__init__.py
+++ b/llmfoundry/optim/__init__.py
@@ -10,6 +10,7 @@
 
 from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion
 from llmfoundry.optim.lion import DecoupledLionW
+from llmfoundry.optim.no_op import NoOp
 from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
 from llmfoundry.registry import optimizers, schedulers
 
@@ -17,6 +18,7 @@
 optimizers.register('clip_lion', func=DecoupledClipLion)
 optimizers.register('decoupled_lionw', func=DecoupledLionW)
 optimizers.register('decoupled_adamw', func=DecoupledAdamW)
+optimizers.register('no_op', func=NoOp)
 
 schedulers.register('constant_with_warmup', func=ConstantWithWarmupScheduler)
 schedulers.register(
@@ -33,5 +35,6 @@
     'DecoupledLionW',
     'DecoupledClipLion',
     'DecoupledAdaLRLion',
+    'NoOp',
     'InverseSquareRootWithWarmupScheduler',
 ]
diff --git a/llmfoundry/optim/no_op.py b/llmfoundry/optim/no_op.py
new file mode 100644
index 0000000000..f435917b36
--- /dev/null
+++ b/llmfoundry/optim/no_op.py
@@ -0,0 +1,39 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Iterable, Optional
+
+import torch
+
+
+class NoOp(torch.optim.Optimizer):
+    """Optimizer that performs no optimization steps."""
+
+    def __init__(
+        self,
+        params: Iterable[torch.Tensor],
+    ):
+        # LR schedulers expect param groups to have LR. Unused.
+        defaults = {'lr': 0.0}
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state: dict[str, dict[Any, Any]]) -> None:
+        super().__setstate__(state)
+
+    def state_dict(self):
+        return super().state_dict()
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+        """Perform a no-op optimization step where no parameters are updated.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        return loss
diff --git a/tests/optim/test_no_op.py b/tests/optim/test_no_op.py
new file mode 100644
index 0000000000..eb1c2fb704
--- /dev/null
+++ b/tests/optim/test_no_op.py
@@ -0,0 +1,46 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+from typing import Callable
+
+import torch
+from composer.trainer import Trainer
+from torch.utils.data import DataLoader
+
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+from llmfoundry.utils.builders import build_optimizer
+
+
+def test_no_op_does_nothing(
+    build_tiny_mpt: Callable[..., ComposerMPTCausalLM],
+    tiny_ft_dataloader: DataLoader,
+):
+    # Build MPT model
+    model = build_tiny_mpt(
+        loss_fn='torch_crossentropy',
+        attn_config={
+            'attn_impl': 'torch',
+        },
+    )
+
+    # Build NoOp optimizer
+    no_op_optim = build_optimizer(model, 'no_op', optimizer_config={})
+
+    # Keep a copy of the original weights to compare against after training
+    orig_model = copy.deepcopy(model)
+
+    # Build trainer and train for two batches
+    trainer = Trainer(
+        model=model,
+        train_dataloader=tiny_ft_dataloader,
+        max_duration='2ba',
+        optimizers=no_op_optim,
+    )
+    trainer.fit()
+
+    # Check that the model has not changed
+    for (orig_name, orig_param), (new_name, new_param) in zip(
+            orig_model.named_parameters(), model.named_parameters()):
+        print(f'Checking {orig_name} and {new_name}')
+        assert torch.equal(orig_param, new_param)
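
Usage sketch (reviewer note, not part of the diff): the snippet below assumes only what the
diff introduces, namely that NoOp is importable from llmfoundry.optim and that step()
performs no update; the torch.nn.Linear model is a stand-in used purely for illustration.

    import copy

    import torch

    from llmfoundry.optim import NoOp

    # Toy module; any nn.Module works because NoOp never reads gradients.
    model = torch.nn.Linear(8, 8)
    optimizer = NoOp(model.parameters())

    before = copy.deepcopy(model.state_dict())
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()       # intentionally leaves every parameter untouched
    optimizer.zero_grad()

    # All weights are bit-identical to their values before the step.
    for name, param in model.named_parameters():
        assert torch.equal(before[name], param)

Because the optimizer is added to the llmfoundry optimizers registry, it can also be selected
by its registered name, as the test does with build_optimizer(model, 'no_op', optimizer_config={}),
which is presumably the intended way to keep model weights frozen during a training run.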