From 74db2a1baeb3021f7b93f3a9581d84c843d11835 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 30 Oct 2024 23:57:00 +0530 Subject: [PATCH] Fix get_chat_template call for trainer builder (#2003) --- src/axolotl/cli/__init__.py | 2 +- src/axolotl/core/trainer_builder.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 52765a9b58..84586ccc37 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -272,7 +272,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template) + chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) model = model.to(cfg.device, dtype=cfg.torch_dtype) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d125f838d3..e47c09d514 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1595,7 +1595,8 @@ def build(self, total_num_steps): training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: training_arguments_kwargs["chat_template"] = get_chat_template( - self.cfg.chat_template + self.cfg.chat_template, + tokenizer=self.tokenizer, ) if self.cfg.rl == "orpo":