diff --git a/src/axolotl/train.py b/src/axolotl/train.py index ebd020061b..32bcbc1d0a 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -212,6 +212,10 @@ def terminate_handler(_, __, model_weakref): if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) + if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: + trainer.model.save_pretrained( + cfg.output_dir, safe_serialization=safe_serialization + ) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if not cfg.hub_model_id: