diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 92333f4cab..5fd1d8214f 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -38,29 +38,6 @@ IGNORE_INDEX = -100 -class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods - """Callback to save the PEFT adapter""" - - def on_save( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - checkpoint_folder = os.path.join( - args.output_dir, - f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", - ) - - peft_model_path = os.path.join(checkpoint_folder, "adapter_model") - kwargs["model"].save_pretrained( - peft_model_path, save_safetensors=args.save_safetensors - ) - - return control - - class SaveBetterTransformerModelCallback( TrainerCallback ): # pylint: disable=too-few-public-methods diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 0aceee5190..2978073873 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -28,7 +28,6 @@ from axolotl.utils.callbacks import ( GPUStatsCallback, SaveBetterTransformerModelCallback, - SavePeftModelCallback, bench_eval_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq @@ -637,12 +636,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ ) callbacks.append(early_stop_cb) - if cfg.local_rank == 0 and cfg.adapter in [ - "lora", - "qlora", - ]: # only save in rank 0 - callbacks.append(SavePeftModelCallback) - if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: callbacks.append(SaveBetterTransformerModelCallback)