Add plugin manager's callback hooks to training flow (#2006)
* Add plugin manager's callback hooks to training flow

* Use .values() instead of .items()
chiragjn authored Oct 31, 2024
1 parent 5c7e891 commit d4dbfa0
Showing 2 changed files with 50 additions and 30 deletions.
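In effect, `get_callbacks` and the previously abstract `get_post_trainer_create_callbacks` now collect plugin callbacks in the base builder, and subclasses seed their lists from `super()`. A condensed sketch of the resulting pattern (simplified from the diff below, not verbatim source):

```python
# Condensed sketch of the pattern this commit introduces; simplified from
# the diff below, not a verbatim copy of trainer_builder.py.
from axolotl.integrations.base import PluginManager


class TrainerBuilderBase:
    def get_callbacks(self):
        # Plugin callbacks are collected first; builder-specific ones
        # (wandb, mlflow, etc.) are appended afterwards as before.
        plugin_manager = PluginManager.get_instance()
        callbacks = []
        callbacks.extend(
            plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
        )
        return callbacks

    def get_post_trainer_create_callbacks(self, trainer):
        # No longer @abstractmethod: the base returns plugin callbacks.
        plugin_manager = PluginManager.get_instance()
        callbacks = []
        callbacks.extend(
            plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
        )
        return callbacks


class HFPPOTrainerBuilder(TrainerBuilderBase):
    def get_callbacks(self):
        # Subclasses now seed from super() instead of starting from [].
        return super().get_callbacks()
```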
23 changes: 18 additions & 5 deletions src/axolotl/core/trainer_builder.py
@@ -48,6 +48,7 @@
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length

+from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -1147,6 +1148,12 @@ def build(self, total_num_steps):

def get_callbacks(self) -> List[TrainerCallback]:
        callbacks = []
+
+        plugin_manager = PluginManager.get_instance()
+        callbacks.extend(
+            plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
+        )
+
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1173,11 +1180,17 @@ def get_callbacks(self) -> List[TrainerCallback]:

return callbacks

-    @abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
+        callbacks = []
+
+        plugin_manager = PluginManager.get_instance()
+        callbacks.extend(
+            plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
+        )
+        return callbacks

def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1223,7 +1236,7 @@ def get_callbacks(self):
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1791,7 +1804,7 @@ def get_callbacks(self):
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks

def build_training_arguments(self, total_num_steps):
@@ -2000,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""

def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
-        callbacks = []
+        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks

def build(self, total_num_steps):
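For context, a minimal sketch of a plugin that feeds a callback through the new pre-trainer hook. The plugin and callback names here are hypothetical; only the BasePlugin hook signatures from src/axolotl/integrations/base.py (below) are assumed:

```python
# Hypothetical plugin, not part of this commit; assumes only the BasePlugin
# hook signatures shown in src/axolotl/integrations/base.py below.
from transformers import TrainerCallback

from axolotl.integrations.base import BasePlugin


class StepTimerCallback(TrainerCallback):
    """Toy callback: report the global step as training progresses."""

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"reached step {state.global_step}")
        return control


class StepTimerPlugin(BasePlugin):
    def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
        # Picked up by TrainerBuilderBase.get_callbacks() above.
        return [StepTimerCallback()]
```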
57 changes: 32 additions & 25 deletions src/axolotl/integrations/base.py
@@ -18,9 +18,10 @@
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
+import collections
import importlib
import logging
-from typing import List
+from typing import OrderedDict


class BasePlugin:
@@ -47,7 +48,7 @@ def __init__(self):
Initializes the BasePlugin.
"""

-    def register(self, cfg):
+    def register(self, cfg):  # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
@@ -63,7 +64,7 @@ def get_input_args(self):
Returns a pydantic model for the plugin's input arguments.
"""

-    def pre_model_load(self, cfg):
+    def pre_model_load(self, cfg):  # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
@@ -74,7 +75,7 @@ def pre_model_load(self, cfg):
None
"""

-    def post_model_load(self, cfg, model):
+    def post_model_load(self, cfg, model):  # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -86,7 +87,7 @@ def post_model_load(self, cfg, model):
None
"""

-    def pre_lora_load(self, cfg, model):
+    def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
@@ -98,7 +99,7 @@ def pre_lora_load(self, cfg, model):
None
"""

-    def post_lora_load(self, cfg, model):
+    def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
@@ -110,7 +111,7 @@ def post_lora_load(self, cfg, model):
None
"""

-    def create_optimizer(self, cfg, trainer):
+    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -122,7 +123,9 @@ def create_optimizer(self, cfg, trainer):
object: The created optimizer.
"""

-    def create_lr_scheduler(self, cfg, trainer, optimizer):
+    def create_lr_scheduler(
+        self, cfg, trainer, optimizer
+    ):  # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -135,7 +138,7 @@ def create_lr_scheduler(self, cfg, trainer, optimizer):
object: The created learning rate scheduler.
"""

-    def add_callbacks_pre_trainer(self, cfg, model):
+    def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
"""
Adds callbacks to the trainer before training.
@@ -146,8 +149,11 @@ def add_callbacks_pre_trainer(self, cfg, model):
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
+        return []

-    def add_callbacks_post_trainer(self, cfg, trainer):
+    def add_callbacks_post_trainer(
+        self, cfg, trainer
+    ):  # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after training.
@@ -158,8 +164,9 @@ def add_callbacks_post_trainer(self, cfg, trainer):
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
+        return []

-    def post_train(self, cfg, model):
+    def post_train(self, cfg, model):  # pylint: disable=unused-argument
"""
Performs actions after training is complete.
@@ -171,7 +178,7 @@ def post_train(self, cfg, model):
None
"""

-    def post_train_unload(self, cfg):
+    def post_train_unload(self, cfg):  # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
@@ -227,7 +234,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""

-    plugins: List[BasePlugin] = []
+    plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()

_instance = None

@@ -237,7 +244,7 @@ def __new__(cls):
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
-            cls._instance.plugins: List[BasePlugin] = []
+            cls._instance.plugins = collections.OrderedDict()
return cls._instance

@staticmethod
@@ -265,7 +272,7 @@ def register(self, plugin_name: str):
"""
try:
plugin = load_plugin(plugin_name)
-            self.plugins.append(plugin)
+            self.plugins[plugin_name] = plugin
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")

@@ -277,7 +284,7 @@ def get_input_args(self):
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -293,7 +300,7 @@ def pre_model_load(self, cfg):
Returns:
None
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
plugin.pre_model_load(cfg)

def post_model_load(self, cfg, model):
@@ -307,7 +314,7 @@ def post_model_load(self, cfg, model):
Returns:
None
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)

def pre_lora_load(self, cfg, model):
@@ -321,7 +328,7 @@ def pre_lora_load(self, cfg, model):
Returns:
None
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)

def post_lora_load(self, cfg, model):
@@ -335,7 +342,7 @@ def post_lora_load(self, cfg, model):
Returns:
None
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)

def create_optimizer(self, cfg, trainer):
@@ -349,7 +356,7 @@ def create_optimizer(self, cfg, trainer):
Returns:
object: The created optimizer, or None if none was found.
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -367,7 +374,7 @@ def create_lr_scheduler(self, cfg, trainer, optimizer):
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -385,7 +392,7 @@ def add_callbacks_pre_trainer(self, cfg, model):
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks

@@ -401,7 +408,7 @@ def add_callbacks_post_trainer(self, cfg, trainer):
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks

@@ -416,5 +423,5 @@ def post_train_unload(self, cfg):
Returns:
None
"""
-        for plugin in self.plugins:
+        for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
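Since the registry is now keyed by plugin name, registering the same name twice overwrites the entry rather than appending a duplicate, and OrderedDict preserves registration order for `.values()`. A usage sketch (the dotted path is made up; `register()` logs an error and skips the plugin if the import fails):

```python
# Hypothetical usage; "my_project.plugins.StepTimerPlugin" stands in for a
# real importable BasePlugin subclass.
from axolotl.integrations.base import PluginManager

manager = PluginManager.get_instance()
manager.register("my_project.plugins.StepTimerPlugin")
# Same name again: overwrites the OrderedDict entry, so every hook such as
# pre_model_load() still fires exactly once for this plugin.
manager.register("my_project.plugins.StepTimerPlugin")
assert len(manager.plugins) == 1  # holds once the import above succeeds
```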
