-
-
Notifications
You must be signed in to change notification settings - Fork 899
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
382 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
base_model: google/gemma-2-2b | ||
model_type: AutoModelForSequenceClassification | ||
tokenizer_type: AutoTokenizer | ||
|
||
load_in_8bit: false | ||
load_in_4bit: false | ||
strict: false | ||
|
||
reward_model: true | ||
chat_template: gemma | ||
datasets: | ||
- path: argilla/distilabel-intel-orca-dpo-pairs | ||
type: bradley_terry.chat_template | ||
val_set_size: 0.0 | ||
output_dir: ./outputs/out | ||
remove_unused_columns: false | ||
|
||
sequence_len: 2048 | ||
sample_packing: false | ||
eval_sample_packing: false | ||
pad_to_sequence_len: true | ||
|
||
wandb_project: | ||
wandb_entity: | ||
wandb_watch: | ||
wandb_name: | ||
wandb_log_model: | ||
|
||
|
||
gradient_accumulation_steps: 4 | ||
micro_batch_size: 2 | ||
num_epochs: 4 | ||
optimizer: adamw_bnb_8bit | ||
lr_scheduler: cosine | ||
learning_rate: 0.0002 | ||
|
||
train_on_inputs: false | ||
group_by_length: false | ||
bf16: true | ||
fp16: | ||
tf32: true | ||
|
||
gradient_checkpointing: true | ||
gradient_checkpointing_kwargs: | ||
use_reentrant: false | ||
early_stopping_patience: | ||
resume_from_checkpoint: | ||
local_rank: | ||
logging_steps: 1 | ||
xformers_attention: | ||
flash_attention: true | ||
|
||
warmup_ratio: 0.1 | ||
evals_per_epoch: | ||
eval_table_size: | ||
eval_max_new_tokens: 128 | ||
saves_per_epoch: 1 | ||
debug: | ||
deepspeed: | ||
weight_decay: 0.0 | ||
fsdp: | ||
fsdp_config: | ||
special_tokens: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
### example yaml | ||
|
||
```yaml | ||
chat_template: gemma | ||
datasets: | ||
- path: argilla/distilabel-intel-orca-dpo-pairs | ||
type: bradley_terry.chat_template | ||
val_set_size: 0.0 | ||
output_dir: ./outputs/out | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Module to load prompt strategies.""" | ||
|
||
import importlib | ||
import inspect | ||
import logging | ||
|
||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig | ||
|
||
LOG = logging.getLogger("axolotl.prompt_strategies") | ||
|
||
|
||
def load(strategy, tokenizer, cfg, ds_cfg): | ||
# pylint: disable=duplicate-code | ||
try: | ||
load_fn = "load" | ||
if strategy.split(".")[-1].startswith("load_"): | ||
load_fn = strategy.split(".")[-1] | ||
strategy = ".".join(strategy.split(".")[:-1]) | ||
mod = importlib.import_module( | ||
f".{strategy}", "axolotl.prompt_strategies.bradley_terry" | ||
) | ||
func = getattr(mod, load_fn) | ||
load_kwargs = {} | ||
if strategy == "user_defined": | ||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) | ||
else: | ||
sig = inspect.signature(func) | ||
if "ds_cfg" in sig.parameters: | ||
load_kwargs["ds_cfg"] = ds_cfg | ||
return func(tokenizer, cfg, **load_kwargs) | ||
except ModuleNotFoundError: | ||
return None | ||
except Exception as exc: # pylint: disable=broad-exception-caught | ||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") | ||
return None |
88 changes: 88 additions & 0 deletions
88
src/axolotl/prompt_strategies/bradley_terry/chat_template.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
""" | ||
Bradley-Terry model with chat template prompt strategy. | ||
""" | ||
|
||
from typing import Any, Dict, Optional | ||
|
||
from axolotl.prompt_strategies.chat_template import ( | ||
ChatTemplatePrompter, | ||
ChatTemplateStrategy, | ||
) | ||
from axolotl.utils.chat_templates import chat_templates | ||
|
||
|
||
class BTChatTemplateStrategy(ChatTemplateStrategy): | ||
""" | ||
Bradley-Terry reward model pairwise chat template prompt strategy. | ||
""" | ||
|
||
def tokenize_prompt(self, prompt): | ||
""" | ||
:param prompt: the actual row of data from the underlying dataset | ||
:return: | ||
""" | ||
|
||
self.messages = "chosen_messages" | ||
# pylint: disable=duplicate-code | ||
prompt[self.messages] = [] | ||
if prompt["system"]: | ||
prompt[self.messages].append({"from": "system", "value": prompt["system"]}) | ||
prompt[self.messages].append({"from": "user", "value": prompt["input"]}) | ||
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]}) | ||
chosen_tokenized = super().tokenize_prompt(prompt) | ||
|
||
self.messages = "rejected_messages" | ||
# pylint: disable=duplicate-code | ||
prompt[self.messages] = [] | ||
if prompt["system"]: | ||
prompt[self.messages].append({"from": "system", "value": prompt["system"]}) | ||
prompt[self.messages].append({"from": "user", "value": prompt["input"]}) | ||
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]}) | ||
rejected_tokenized = super().tokenize_prompt(prompt) | ||
|
||
return { | ||
"input_ids_chosen": chosen_tokenized["input_ids"], | ||
"attention_mask_chosen": chosen_tokenized["attention_mask"], | ||
"labels_chosen": 1.0, | ||
"input_ids_rejected": rejected_tokenized["input_ids"], | ||
"attention_mask_rejected": rejected_tokenized["attention_mask"], | ||
"labels_rejected": 0.0, | ||
} | ||
|
||
|
||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): | ||
ds_cfg = ds_cfg or {} | ||
|
||
prompter_params = { | ||
"tokenizer": tokenizer, | ||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), | ||
"message_field_role": ds_cfg.get("message_field_role", "from"), | ||
"message_field_content": ds_cfg.get("message_field_content", "value"), | ||
"message_field_training": ds_cfg.get("message_field_training", "training"), | ||
"message_field_training_detail": ds_cfg.get( | ||
"message_field_training_detail", "train_detail" | ||
), | ||
"roles": ds_cfg.get("roles"), | ||
"drop_system_message": ds_cfg.get("drop_system_message", False), | ||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit. | ||
"max_length": cfg.sequence_len + 1 | ||
if not cfg.reward_model | ||
else cfg.sequence_len, | ||
} | ||
|
||
strategy_params = { | ||
"train_on_inputs": cfg.train_on_inputs, | ||
"sequence_len": cfg.sequence_len, | ||
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), | ||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"), | ||
} | ||
|
||
strategy = BTChatTemplateStrategy( | ||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params | ||
) | ||
|
||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"): | ||
strategy.messages = ds_cfg["field_messages"] | ||
|
||
return strategy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" | ||
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template | ||
""" | ||
|
||
|
||
def icr( | ||
cfg, | ||
**kwargs, | ||
): # pylint: disable=possibly-unused-variable,unused-argument | ||
""" | ||
chatml transforms for datasets with system, input, chosen, rejected | ||
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs | ||
""" | ||
|
||
def transform_fn(sample): | ||
if "system" in sample and sample["system"]: | ||
prompt = ( | ||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" | ||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" | ||
) | ||
else: | ||
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" | ||
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>" | ||
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>" | ||
return sample | ||
|
||
return transform_fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.