From 68a26f1005e237712cad2e1e970486b6d63aca3e Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Thu, 21 Nov 2024 00:36:08 +0530 Subject: [PATCH] Fix duplication of plugin callbacks (#2090) --- src/axolotl/core/trainer_builder.py | 30 +++++++++++++---------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 14690580d7..75219a2749 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1212,11 +1212,17 @@ def get_post_trainer_create_callbacks(self, trainer): Callbacks added after the trainer is created, usually b/c these need access to the trainer """ callbacks = [] - - plugin_manager = PluginManager.get_instance() - callbacks.extend( - plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer) - ) + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + callbacks.extend( + [ + cb + for cb in plugin_manager.add_callbacks_post_trainer( + self.cfg, trainer + ) + if cb + ] + ) return callbacks def hook_pre_create_training_args(self, training_arguments_kwargs): @@ -1263,7 +1269,7 @@ def get_callbacks(self): return callbacks def get_post_trainer_create_callbacks(self, trainer): - callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) + callbacks = [] if self.cfg.use_wandb and self.cfg.eval_table_size > 0: LogPredictionCallback = log_prediction_callback_factory( trainer, self.tokenizer, "wandb" @@ -1301,17 +1307,7 @@ def get_post_trainer_create_callbacks(self, trainer): if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: callbacks.append(lisa_callback_factory(trainer)) - if self.cfg.plugins: - plugin_manager = PluginManager.get_instance() - callbacks.extend( - [ - cb - for cb in plugin_manager.add_callbacks_post_trainer( - self.cfg, trainer - ) - if cb - ] - ) + callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer)) return callbacks def _get_trainer_cls(self):