Add a configurable default system message for the ChatML chat template.

Summary of the hunks below:
  * README.md: documents the new `default_system_message` config key.
  * cli/preprocess.py, cli/train.py: when `chat_template` is "chatml" and
    `default_system_message` is set, log it and call
    `register_chatml_template()` with that message.
  * prompt_strategies/sharegpt.py: the module-level `register_conv_template`
    call becomes the function `register_chatml_template(system_message=None)`,
    which falls back to "You are a helpful assistant.".
  * utils/chat_templates.py: the chatml Jinja template emits a system turn,
    taken from the first message when its role is "system", otherwise the
    default text.
  * utils/models.py: substitutes `default_system_message` for the default
    text in the chatml template string before assigning it to the tokenizer.
  * tests: register the chatml template explicitly at import time.

Review note: in cli/train.py the `if cfg.rl:` line is kept as context and the
ChatML registration is inserted before it, so dataset loading is still
selected by `cfg.rl`. The post-image of that file therefore differs from the
blob abbreviated on its `index` line.

diff --git a/README.md b/README.md
index f47beac776..95e5f530cd 100644
--- a/README.md
+++ b/README.md
@@ -613,6 +613,8 @@ rl:
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
 # Currently supports chatml and inst (mistral/mixtral)
 chat_template: chatml
+# Changes the default system message
+default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 745a530c0a..8ea68575db 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -18,6 +18,7 @@
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+from axolotl.prompt_strategies.sharegpt import register_chatml_template
 
 LOG = logging.getLogger("axolotl.cli.preprocess")
 
@@ -34,6 +35,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )
 
+    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+        )
+        register_chatml_template(parsed_cfg.default_system_message)
+
     if not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 8dc786c17d..e18f45c338 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -18,6 +18,7 @@
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
+from axolotl.prompt_strategies.sharegpt import register_chatml_template
 from axolotl.train import train
 
 LOG = logging.getLogger("axolotl.cli.train")
@@ -37,7 +38,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
+    if cfg.chat_template == "chatml" and cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_chatml_template(cfg.default_system_message)
+
     if cfg.rl:
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index c026889682..15bfee8c47 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -6,16 +6,19 @@
 from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
 from axolotl.prompters import ShareGPTPrompterV2
 
-register_conv_template(
-    Conversation(
-        name="chatml",
-        system_template="<|im_start|>system\n{system_message}",
-        system_message="You are a helpful assistant.",
-        roles=["<|im_start|>user", "<|im_start|>assistant"],
-        sep_style=SeparatorStyle.CHATML,
-        sep="<|im_end|>",
+
+def register_chatml_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="chatml",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
     )
-)
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 459da44007..2470809d4d 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -20,7 +20,7 @@ def chat_templates(user_choice: str):
 
     templates = {
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }
 
     if user_choice in templates:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c25aa52ed5..7ca9abbb57 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -219,7 +219,13 @@ def load_tokenizer(cfg):
     LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
     if cfg.chat_template:
-        tokenizer.chat_template = chat_templates(cfg.chat_template)
+        chat_template_string = chat_templates(cfg.chat_template)
+        if cfg.default_system_message and cfg.chat_template == "chatml":
+            chat_template_string = chat_template_string.replace(
+                "You are a helpful assistant.", cfg.default_system_message
+            )
+
+        tokenizer.chat_template = chat_template_string
     else:
         LOG.info(
             "No Chat template selected. Consider adding a chat template for easier inference."
diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py
index ee62ab5d03..19f8217e04 100644
--- a/tests/prompt_strategies/test_sharegpt.py
+++ b/tests/prompt_strategies/test_sharegpt.py
@@ -7,9 +7,14 @@
 from transformers import AutoTokenizer
 
 from axolotl.datasets import TokenizedPromptDataset
-from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
+from axolotl.prompt_strategies.sharegpt import (
+    SimpleShareGPTPromptTokenizingStrategy,
+    register_chatml_template,
+)
 from axolotl.prompters import ShareGPTPrompterV2
 
+register_chatml_template()
+
 
 @pytest.fixture(name="sharegpt_dataset")
 def fixture_sharegpt_dataset():