diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 972c291406..af6eaabf2f 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1287,6 +1287,18 @@ def get_post_trainer_create_callbacks(self, trainer):
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             callbacks.append(lisa_callback_factory(trainer))
+
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            callbacks.extend(
+                [
+                    cb
+                    for cb in plugin_manager.add_callbacks_post_trainer(
+                        self.cfg, trainer
+                    )
+                    if cb
+                ]
+            )
         return callbacks
 
     def _get_trainer_cls(self):
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 43afa431a0..a271c59d10 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -140,7 +140,7 @@ def create_lr_scheduler(
     def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
         """
-        Adds callbacks to the trainer before training.
+        Sets up callbacks before creating the trainer.
 
         Parameters:
         cfg (dict): The configuration for the plugin.
@@ -155,14 +155,15 @@ def add_callbacks_post_trainer(
         self, cfg, trainer
     ):  # pylint: disable=unused-argument
         """
-        Adds callbacks to the trainer after training.
+        Adds callbacks to the trainer after creating the trainer.
+        This is useful for callbacks that require access to the model or trainer.
 
         Parameters:
         cfg (dict): The configuration for the plugin.
         trainer (object): The trainer object for training.
 
         Returns:
-        List[callable]: A list of callback functions to be added to the TrainingArgs
+        List[callable]: A list of callback functions to be added
         """
         return []
 
@@ -393,7 +394,9 @@ def add_callbacks_pre_trainer(self, cfg, model):
         """
         callbacks = []
         for plugin in self.plugins.values():
-            callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
+            plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
+            if plugin_callbacks:  # if the plugin returned a list of callbacks
+                callbacks.extend(plugin_callbacks)
         return callbacks
 
     def add_callbacks_post_trainer(self, cfg, trainer):
@@ -409,7 +412,9 @@ def add_callbacks_post_trainer(self, cfg, trainer):
         """
         callbacks = []
         for plugin in self.plugins.values():
-            callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
+            plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
+            if plugin_callbacks:
+                callbacks.extend(plugin_callbacks)
         return callbacks
 
     def post_train_unload(self, cfg):
diff --git a/src/axolotl/integrations/grokfast/LICENSE b/src/axolotl/integrations/grokfast/LICENSE
new file mode 100644
index 0000000000..21e35ee968
--- /dev/null
+++ b/src/axolotl/integrations/grokfast/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/axolotl/integrations/grokfast/README.md b/src/axolotl/integrations/grokfast/README.md
new file mode 100644
index 0000000000..4950dde87a
--- /dev/null
+++ b/src/axolotl/integrations/grokfast/README.md
@@ -0,0 +1,13 @@
+# Grokfast Optimizer
+
+See https://github.com/ironjr/grokfast
+
+### Usage
+
+```yaml
+plugins:
+  - axolotl.integrations.grokfast.GrokfastPlugin
+
+grokfast_alpha: 0.98
+grokfast_lamb: 2.0
+```
diff --git a/src/axolotl/integrations/grokfast/__init__.py b/src/axolotl/integrations/grokfast/__init__.py
new file mode 100644
index 0000000000..3889e927c2
--- /dev/null
+++ b/src/axolotl/integrations/grokfast/__init__.py
@@ -0,0 +1,50 @@
+"""
+Grokfast plugin for Axolotl
+"""
+import logging
+
+from transformers.trainer_callback import TrainerCallback
+
+from ..base import BasePlugin
+from .args import GrokfastArgs  # pylint: disable=unused-import  # noqa: F401
+from .optimizer import gradfilter_ema
+
+LOG = logging.getLogger("axolotl.integrations.grokfast")
+
+
+class GrokfastCallbackHandler(TrainerCallback):
+    """
+    Transformer trainer callbacks for Grokfast
+    """
+
+    def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
+        super().__init__(*args_, **kwargs)
+        self.grads = None
+        self.alpha = alpha
+        self.lamb = lamb
+
+    def on_train_begin(self, *args_, **kwargs):  # pylint: disable=unused-argument
+        self.grads = None
+
+    def on_pre_optimizer_step(
+        self, args_, state, control, **kwargs
+    ):  # pylint: disable=unused-argument
+        model = kwargs.pop("model")
+        self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
+        return control
+
+
+class GrokfastPlugin(BasePlugin):
+    """
+    Plugin for Grokfast optimizer integration with Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.grokfast.GrokfastArgs"
+
+    def add_callbacks_post_trainer(self, cfg, trainer):
+        LOG.info("Adding Grokfast callback to the trainer")
+        callback = GrokfastCallbackHandler(
+            alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
+        )
+        return [callback]
diff --git a/src/axolotl/integrations/grokfast/args.py b/src/axolotl/integrations/grokfast/args.py
new file mode 100644
index 0000000000..4776ae60ca
--- /dev/null
+++ b/src/axolotl/integrations/grokfast/args.py
@@ -0,0 +1,15 @@
+"""
+config args for grokfast plugin
+"""
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class GrokfastArgs(BaseModel):
+    """
+    Input args for Grokfast optimizer.
+ """ + + grokfast_alpha: Optional[float] = 0.98 + grokfast_lamb: Optional[float] = 2.0 diff --git a/src/axolotl/integrations/grokfast/optimizer.py b/src/axolotl/integrations/grokfast/optimizer.py new file mode 100644 index 0000000000..38cda2c934 --- /dev/null +++ b/src/axolotl/integrations/grokfast/optimizer.py @@ -0,0 +1,63 @@ +# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee +# Reference: https://github.com/ironjr/grokfast + +# pylint: skip-file +from collections import deque +from typing import Dict, Literal, Optional + +import torch +import torch.nn as nn + + +def gradfilter_ma( + m: nn.Module, + grads: Optional[Dict[str, deque]] = None, + window_size: int = 100, + lamb: float = 5.0, + filter_type: Literal["mean", "sum"] = "mean", + warmup: bool = True, + trigger: bool = False, # For ablation study. +) -> Dict[str, deque]: + if grads is None: + grads = { + n: deque(maxlen=window_size) + for n, p in m.named_parameters() + if p.requires_grad and p.grad is not None + } + + for n, p in m.named_parameters(): + if p.requires_grad and p.grad is not None: + grads[n].append(p.grad.data.detach()) # .cpu()) + + # Modify the gradients. + if not warmup or len(grads[n]) == window_size and not trigger: + if filter_type == "mean": + avg = sum(grads[n]) / len(grads[n]) + elif filter_type == "sum": + avg = sum(grads[n]) + else: + raise ValueError(f"Unrecognized filter_type {filter_type}") + p.grad.data = p.grad.data + avg * lamb + + return grads + + +def gradfilter_ema( + m: nn.Module, + grads: Optional[Dict[str, torch.Tensor]] = None, + alpha: float = 0.98, + lamb: float = 2.0, +) -> Dict[str, torch.Tensor]: + if grads is None: + grads = { + n: p.grad.data.detach() + for n, p in m.named_parameters() + if p.requires_grad and p.grad is not None + } + + for n, p in m.named_parameters(): + if p.requires_grad and p.grad is not None: + grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha) + p.grad.data = p.grad.data + grads[n] * lamb + + return grads diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 5066231d99..1feb8aae86 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -783,6 +783,8 @@ class Config: is_mistral_derived_model: Optional[bool] = Field(default=None) is_qwen_derived_model: Optional[bool] = Field(default=None) + plugins: Optional[List[str]] = Field(default=None) + @field_validator("datasets", mode="before") @classmethod def deprecate_sharegpt_datasets(cls, datasets):