Skip to content

Commit

Permalink
[config] add gemma-it templates
Browse files Browse the repository at this point in the history
  • Loading branch information
imoneoi committed Mar 9, 2024
1 parent bbc6e51 commit d915544
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions ochat/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
}


_GEMMA_IT_PREFIXES = {
"user": "user",
"assistant": "model"
}


def _v3_2_role_prefix(from_role, condition):
    """Return the V3.2 turn prefix for *from_role*.

    The prefix is the *condition* tag followed by the role's name from
    ``_V3_2_PREFIXES``; when *condition* is empty, stripping removes the
    leading space so only the bare role name remains.
    """
    role_name = _V3_2_PREFIXES[from_role]
    combined = f"{condition} {role_name}"
    return combined.strip()

Expand Down Expand Up @@ -106,9 +112,7 @@ def _v3_2_role_prefix(from_role, condition):
"chatml_mistral": ModelConfig(
# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
legacy=True),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
Expand All @@ -122,9 +126,7 @@ def _v3_2_role_prefix(from_role, condition):
"zephyr_mistral": ModelConfig(
# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
legacy=True),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
Expand All @@ -135,4 +137,18 @@ def _v3_2_role_prefix(from_role, condition):
eot="</s>",
inference_condition="")
),
"gemma_it": ModelConfig(
# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),

# Conversation Template
conversation_template=partial(ConversationTemplate,
role_prefix=lambda from_role, condition: f"<start_of_turn>{_GEMMA_IT_PREFIXES[from_role]}\n",
eot="<end_of_turn>",
inference_condition="")
),
}

0 comments on commit d915544

Please sign in to comment.