From d915544bdb5055053be989ea525312a8b5b1c505 Mon Sep 17 00:00:00 2001 From: One Date: Sat, 9 Mar 2024 01:53:13 +0000 Subject: [PATCH] [config] add gemma-it templates --- ochat/config/__init__.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/ochat/config/__init__.py b/ochat/config/__init__.py index 6109e50..8c9497a 100644 --- a/ochat/config/__init__.py +++ b/ochat/config/__init__.py @@ -16,6 +16,12 @@ } +_GEMMA_IT_PREFIXES = { + "user": "user", + "assistant": "model" +} + + def _v3_2_role_prefix(from_role, condition): return f"{condition} {_V3_2_PREFIXES[from_role]}".strip() @@ -106,9 +112,7 @@ def _v3_2_role_prefix(from_role, condition): "chatml_mistral": ModelConfig( # Model model_max_context=8192, - model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, - use_fast=False, - legacy=True), + model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False), model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16), @@ -122,9 +126,7 @@ def _v3_2_role_prefix(from_role, condition): "zephyr_mistral": ModelConfig( # Model model_max_context=8192, - model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, - use_fast=False, - legacy=True), + model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False), model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16), @@ -135,4 +137,18 @@ def _v3_2_role_prefix(from_role, condition): eot="", inference_condition="") ), + "gemma_it": ModelConfig( + # Model + model_max_context=8192, + model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False), + model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16), + + # Conversation Template + 
conversation_template=partial(ConversationTemplate, + role_prefix=lambda from_role, condition: f"<start_of_turn>{_GEMMA_IT_PREFIXES[from_role]}\n", + eot="<end_of_turn>", + inference_condition="") + ), }