Skip to content

Commit

Permalink
added code to register new system message
Browse files Browse the repository at this point in the history
  • Loading branch information
Mads Henrichsen authored and Mads Henrichsen committed Jan 17, 2024
1 parent 909bd25 commit 15242bd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
7 changes: 7 additions & 0 deletions src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import register_chatml_template

LOG = logging.getLogger("axolotl.cli.preprocess")

Expand All @@ -32,6 +33,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
return_remaining_strings=True
)

if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_chatml_template(parsed_cfg.default_system_message)

if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
Expand Down
7 changes: 7 additions & 0 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.prompt_strategies.sharegpt import register_chatml_template
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
Expand All @@ -32,6 +33,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
return_remaining_strings=True
)

if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_chatml_template(parsed_cfg.default_system_message)

if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
Expand Down
21 changes: 12 additions & 9 deletions src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2

register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",

def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
)


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
Expand Down

0 comments on commit 15242bd

Please sign in to comment.