From 15242bdf1dc5b8311d70c54496a46572e9db9562 Mon Sep 17 00:00:00 2001 From: Mads Henrichsen Date: Wed, 17 Jan 2024 20:46:04 +0100 Subject: [PATCH] added code to register new system message --- src/axolotl/cli/preprocess.py | 7 +++++++ src/axolotl/cli/train.py | 7 +++++++ src/axolotl/prompt_strategies/sharegpt.py | 21 ++++++++++++--------- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 2c27095191..61925e251e 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -17,6 +17,7 @@ ) from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.prompt_strategies.sharegpt import register_chatml_template LOG = logging.getLogger("axolotl.cli.preprocess") @@ -32,6 +33,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs): return_remaining_strings=True ) + if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message: + LOG.info( + f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" + ) + register_chatml_template(parsed_cfg.default_system_message) + if not parsed_cfg.dataset_prepared_path: msg = ( Fore.RED diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index e099a1a6da..df7d6b19eb 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -16,6 +16,7 @@ print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs +from axolotl.prompt_strategies.sharegpt import register_chatml_template from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") @@ -32,6 +33,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs): return_remaining_strings=True ) + if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message: + LOG.info( + f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" + ) + register_chatml_template(parsed_cfg.default_system_message) + if parsed_cfg.rl: dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else: diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index c026889682..15bfee8c47 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -6,16 +6,19 @@ from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompters import ShareGPTPrompterV2 -register_conv_template( - Conversation( - name="chatml", - system_template="<|im_start|>system\n{system_message}", - system_message="You are a helpful assistant.", - roles=["<|im_start|>user", "<|im_start|>assistant"], - sep_style=SeparatorStyle.CHATML, - sep="<|im_end|>", + +def register_chatml_template(system_message=None): + system_message = system_message or "You are a helpful assistant." + register_conv_template( + Conversation( + name="chatml", + system_template="<|im_start|>system\n{system_message}", + system_message=system_message, + roles=["<|im_start|>user", "<|im_start|>assistant"], + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + ) ) -) def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):