-
-
Notifications
You must be signed in to change notification settings - Fork 902
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a
chat_template
prompt strategy for DPO (#1725)
* Implementing a basic chat_template strategy for DPO datasets This mimics the sft chat_template strategy such that users can: * Specify the messages field * Specify the per message role and content fields * Specify the chosen and rejected fields * Let the tokenizer construct the raw prompt * Ensure the chosen and rejected fields don't have any prefix tokens * Adding additional dpo chat template unittests * Rename test class
- Loading branch information
1 parent
fa91b69
commit 985819d
Showing
5 changed files
with
317 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
base_model: meta-llama/Meta-Llama-3-8B-Instruct | ||
model_type: LlamaForCausalLM | ||
tokenizer_type: AutoTokenizer | ||
|
||
load_in_8bit: true | ||
load_in_4bit: false | ||
strict: false | ||
|
||
chat_template: llama3 | ||
rl: dpo | ||
datasets: | ||
- path: fozziethebeat/alpaca_messages_2k_dpo_test | ||
type: chat_template.default | ||
chat_template: llama3 | ||
field_messages: conversation | ||
field_chosen: chosen | ||
field_rejected: rejected | ||
message_field_role: role | ||
message_field_content: content | ||
roles: | ||
system: | ||
- system | ||
user: | ||
- user | ||
assistant: | ||
- assistant | ||
|
||
dataset_prepared_path: | ||
val_set_size: 0.05 | ||
output_dir: ./outputs/lora-out | ||
|
||
sequence_len: 4096 | ||
sample_packing: false | ||
pad_to_sequence_len: true | ||
|
||
adapter: lora | ||
lora_model_dir: | ||
lora_r: 32 | ||
lora_alpha: 16 | ||
lora_dropout: 0.05 | ||
lora_target_linear: true | ||
lora_fan_in_fan_out: | ||
|
||
wandb_project: | ||
wandb_entity: | ||
wandb_watch: | ||
wandb_name: | ||
wandb_log_model: | ||
|
||
gradient_accumulation_steps: 4 | ||
micro_batch_size: 2 | ||
num_epochs: 4 | ||
optimizer: adamw_bnb_8bit | ||
lr_scheduler: cosine | ||
learning_rate: 0.0002 | ||
|
||
train_on_inputs: false | ||
group_by_length: false | ||
bf16: auto | ||
fp16: | ||
tf32: false | ||
|
||
gradient_checkpointing: true | ||
early_stopping_patience: | ||
resume_from_checkpoint: | ||
local_rank: | ||
logging_steps: 1 | ||
xformers_attention: | ||
flash_attention: true | ||
s2_attention: | ||
|
||
warmup_steps: 10 | ||
evals_per_epoch: 4 | ||
eval_table_size: | ||
eval_max_new_tokens: 128 | ||
saves_per_epoch: 1 | ||
debug: | ||
deepspeed: | ||
weight_decay: 0.0 | ||
fsdp: | ||
fsdp_config: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
DPO prompt strategies for using tokenizer chat templates. | ||
""" | ||
|
||
from axolotl.utils.chat_templates import chat_templates | ||
|
||
|
||
def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a DPO sample transform that renders text with a tokenizer chat template.

    Args:
        cfg: Training config. Reads ``cfg["datasets"][dataset_idx]`` for the
            field/role settings and ``cfg.chat_template`` for the template name.
        dataset_idx: Index of the dataset entry whose settings apply.
        **kwargs: Accepted for signature compatibility; unused.

    Returns:
        A callable ``transform_fn(sample, tokenizer=None)`` that returns a dict
        with ``prompt``, ``chosen`` and ``rejected`` strings. ``prompt`` is the
        rendered conversation ending in the generation prompt; ``chosen`` and
        ``rejected`` are the rendered replies with the template text that
        precedes the message content removed.
    """
    ds_cfg = cfg["datasets"][dataset_idx]
    # NOTE(review): only the top-level `chat_template` is consulted here; a
    # `chat_template` key on the dataset entry is not read by this function.
    chat_template_str = chat_templates(cfg.chat_template)

    field_messages = ds_cfg.get("field_messages", "messages")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    field_message_role = ds_cfg.get("message_field_role", "role")
    field_message_content = ds_cfg.get("message_field_content", "content")
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    # Invert {target_role: [source_roles]} into {source_role: target_role}.
    role_map = {
        source: target
        for target, sources in role_map_inv.items()
        for source in sources
    }

    # Stand-in content used to measure what the template emits before a
    # message body, independent of what the real content looks like.
    content_probe = "@@DPO_CONTENT_PROBE@@"

    def normalize(message):
        """Map a raw dataset message onto the `role`/`content` keys the template expects."""
        source_role = message[field_message_role]
        if source_role not in role_map:
            raise KeyError(
                f"role {source_role!r} is not mapped by the dataset `roles` "
                f"config (known roles: {list(role_map)})"
            )
        return {
            "role": role_map[source_role],
            "content": message[field_message_content],
        }

    def transform_fn(sample, tokenizer=None):
        """Render one preference sample into `prompt`/`chosen`/`rejected` strings."""
        if tokenizer is None:
            raise ValueError(
                "the chat_template DPO strategy requires a tokenizer to render prompts"
            )

        def render(messages, add_generation_prompt):
            return tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=add_generation_prompt,
                chat_template=chat_template_str,
                tokenize=False,
            )

        def render_completion(message):
            """Render a single reply and drop the template text before its content."""
            rendered = render([message], False)

            # Measure the prefix with a probe message of the same role. Searching
            # `rendered` for the content itself is fragile: short content can
            # match inside the header, and templates that trim content make the
            # search miss entirely.
            probe = render([{**message, "content": content_probe}], False)
            probe_index = probe.find(content_probe)
            if probe_index >= 0 and rendered.startswith(probe[:probe_index]):
                return rendered[probe_index:]

            # Fallback: locate the content directly, tolerating trimmed content.
            content = message["content"]
            content_index = rendered.find(content)
            if content_index < 0:
                trimmed = content.strip()
                if trimmed:
                    content_index = rendered.find(trimmed)
            if content_index < 0:
                # A -1 index would silently keep only the final character.
                raise ValueError(
                    "could not locate the message content in the chat template "
                    f"output: {rendered!r}"
                )
            return rendered[content_index:]

        messages = [normalize(m) for m in sample[field_messages]]
        chosen = normalize(sample[field_chosen])
        rejected = normalize(sample[field_rejected])

        return {
            "prompt": render(messages, True),
            "chosen": render_completion(chosen),
            "rejected": render_completion(rejected),
        }

    return transform_fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
"""data handling specific to DPO""" | ||
|
||
import inspect | ||
import logging | ||
from functools import partial | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
""" | ||
tests for chat_template prompt strategy | ||
""" | ||
|
||
import unittest | ||
|
||
import pytest | ||
from datasets import Dataset | ||
from transformers import AutoTokenizer | ||
|
||
from axolotl.prompt_strategies.dpo.chat_template import default | ||
from axolotl.utils.dict import DictDefault | ||
|
||
|
||
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
    """Single-row preference dataset using the default field and role names."""
    # pylint: disable=duplicate-code
    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "goodbye"},
    ]
    row = {
        "messages": history,
        "chosen": {"role": "assistant", "content": "goodbye"},
        "rejected": {"role": "assistant", "content": "party on"},
    }
    return Dataset.from_list([row])
|
||
|
||
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
    """Single-row preference dataset using non-default field and role names."""
    # pylint: disable=duplicate-code
    history = [
        {"speaker": "human", "text": "hello"},
        {"speaker": "agent", "text": "hello"},
        {"speaker": "human", "text": "goodbye"},
    ]
    row = {
        "conversation": history,
        "better": {"speaker": "agent", "text": "goodbye"},
        "worse": {"speaker": "agent", "text": "party on"},
    }
    return Dataset.from_list([row])
|
||
|
||
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    """Llama-3 tokenizer with its EOS token overridden to `<|eot_id|>`."""
    llama3 = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    # The expected strings in these tests end every turn with <|eot_id|>.
    llama3.eos_token = "<|eot_id|>"
    return llama3
|
||
|
||
class TestAssistantDPOChatTemplateLlama3:
    """
    Exercises the DPO chat_template strategy on assistant-style datasets rendered with llama-3 prompts.
    """

    def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
        """Default field and role names yield the llama-3 prompt and bare replies."""
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [
                    {
                        "chat_template": "llama3",
                    }
                ],
            }
        )
        transform_fn = default(cfg)
        sample = assistant_dataset[0]
        result = transform_fn(sample, tokenizer=llama3_tokenizer)

        expected_prompt = "".join(
            [
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>",
                "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
            ]
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
        """Custom field names and role aliases are mapped before rendering."""
        # pylint: disable=duplicate-code
        dataset_cfg = {
            "chat_template": "llama3",
            "field_messages": "conversation",
            "field_chosen": "better",
            "field_rejected": "worse",
            "message_field_role": "speaker",
            "message_field_content": "text",
            "roles": {
                "user": ["human"],
                "assistant": ["agent"],
                "system": ["sys"],
            },
        }
        cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [dataset_cfg],
            }
        )
        transform_fn = default(cfg)
        sample = custom_assistant_dataset[0]
        result = transform_fn(sample, tokenizer=llama3_tokenizer)

        expected_prompt = "".join(
            [
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>",
                "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
            ]
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"
|
||
|
||
if __name__ == "__main__":
    # The tests in this module are pytest-style (fixtures, plain class), which
    # unittest.main() does not collect; run them through pytest and propagate
    # its exit code.
    raise SystemExit(pytest.main([__file__]))