From 796a085b2f688f4a5efe249d95f53ff6833bf009 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 8 May 2024 10:39:33 -0400
Subject: [PATCH] make sure to save the lora adapter at the end of RL/dpo
 training (#1573)

---
 src/axolotl/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index ebd020061b..32bcbc1d0a 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -212,6 +212,10 @@ def terminate_handler(_, __, model_weakref):
     if cfg.flash_optimum and BetterTransformer:
         model = BetterTransformer.reverse(model)
 
+    if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+        trainer.model.save_pretrained(
+            cfg.output_dir, safe_serialization=safe_serialization
+        )
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
     if not cfg.hub_model_id:
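
Context note (not part of the patch): the added branch calls save_pretrained
on trainer.model rather than on the unwrapped model, because after adapter
(LoRA) training the trainer holds the PEFT-wrapped model, and calling
save_pretrained on a PEFT model writes only the adapter weights and
adapter_config.json, not the full base model. The following is a minimal
sketch of that behavior, assuming the peft and transformers libraries; the
base model name, LoRA hyperparameters, and output path are illustrative,
not taken from axolotl.

    # Illustrative sketch: saving only the LoRA adapter from a
    # PEFT-wrapped model, analogous to what the patched branch does.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # Hypothetical base model and LoRA config for demonstration.
    base = AutoModelForCausalLM.from_pretrained("gpt2")
    lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
    peft_model = get_peft_model(base, lora_cfg)

    # Writes adapter_config.json and the adapter weights to the output
    # directory; safe_serialization=True produces a .safetensors file.
    peft_model.save_pretrained("outputs/lora-adapter", safe_serialization=True)

Without the explicit save on trainer.model, an RL/DPO run configured with an
adapter could finish without the trained adapter weights being written to
cfg.output_dir.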