From 985819d89bec921e919e7e83042a869f04a25974 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sun, 21 Jul 2024 06:10:42 -0700 Subject: [PATCH] Add a `chat_template` prompt strategy for DPO (#1725) * Implementing a basic chat_template strategy for DPO datasets This mimics the sft chat_template strategy such that users can: * Specify the messages field * Specify the per message role and content fields * speicfy the chosen and rejected fields * Let the tokenizer construct the raw prompt * Ensure the chosen and rejected fields don't have any prefix tokens * Adding additional dpo chat template unittests * Rename test class --- examples/llama-3/instruct-dpo-lora-8b.yml | 81 +++++++++ .../prompt_strategies/dpo/chat_template.py | 78 +++++++++ src/axolotl/utils/data/rl.py | 1 + src/axolotl/utils/tokenization.py | 2 +- .../test_dpo_chat_templates.py | 156 ++++++++++++++++++ 5 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 examples/llama-3/instruct-dpo-lora-8b.yml create mode 100644 src/axolotl/prompt_strategies/dpo/chat_template.py create mode 100644 tests/prompt_strategies/test_dpo_chat_templates.py diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml new file mode 100644 index 0000000000..14febb810a --- /dev/null +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -0,0 +1,81 @@ +base_model: meta-llama/Meta-Llama-3-8B-Instruct +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +chat_template: llama3 +rl: dpo +datasets: + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + chat_template: llama3 + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_field_role: role + message_field_content: content + roles: + system: + - system + user: + - user + assistant: + - assistant + +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py new file mode 100644 index 0000000000..4f2f14098d --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -0,0 +1,78 @@ +""" +DPO prompt strategies for using tokenizer chat templates. +""" + +from axolotl.utils.chat_templates import chat_templates + + +def default( + cfg, dataset_idx=0, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + ds_cfg = cfg["datasets"][dataset_idx] + chat_template_str = chat_templates(cfg.chat_template) + + field_messages = ds_cfg.get("field_messages", "messages") + field_chosen = ds_cfg.get("field_chosen", "chosen") + field_rejected = ds_cfg.get("field_rejected", "rejected") + field_message_role = ds_cfg.get("message_field_role", "role") + field_message_content = ds_cfg.get("message_field_content", "content") + role_map_inv = ds_cfg.get( + "roles", + { + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ) + role_map = {} + for target, sources in role_map_inv.items(): + for source in sources: + role_map[source] = target + + def transform_fn(sample, tokenizer=None): + messages = sample[field_messages] + messages = [ + { + "role": role_map[m[field_message_role]], + "content": m[field_message_content], + } + for m in messages + ] + chosen = { + "role": role_map[sample[field_chosen][field_message_role]], + "content": sample[field_chosen][field_message_content], + } + rejected = { + "role": role_map[sample[field_rejected][field_message_role]], + "content": sample[field_rejected][field_message_content], + } + + result = {} + result["prompt"] = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=False, + ) + + result["chosen"] = tokenizer.apply_chat_template( + [chosen], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + ) + chosen_strip_index = result["chosen"].find(chosen["content"]) + result["chosen"] = result["chosen"][chosen_strip_index:] + + result["rejected"] = tokenizer.apply_chat_template( + [rejected], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + ) + rejected_strip_index = result["rejected"].find(rejected["content"]) + result["rejected"] = result["rejected"][rejected_strip_index:] + + return result + + return transform_fn diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 7416ca28bb..d0324e1ebd 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -1,4 +1,5 @@ """data handling specific to DPO""" + import inspect import logging from functools import partial diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 845296b7a6..f353aebec9 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only): """Helper function to process and color tokens.""" colored_tokens = [ color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only) - for token in tokenizer.encode(tokens) + for token in tokenizer.encode(tokens, add_special_tokens=False) ] return colored_tokens diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py new file mode 100644 index 0000000000..cca48b1cf3 --- /dev/null +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -0,0 +1,156 @@ +""" +tests for chat_template prompt strategy +""" + +import unittest + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + +from axolotl.prompt_strategies.dpo.chat_template import default +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="assistant_dataset") +def fixture_assistant_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "messages": [ + { + "role": "user", + "content": "hello", + }, + { + "role": "assistant", + "content": "hello", + }, + { + "role": "user", + "content": "goodbye", + }, + ], + "chosen": { + "role": "assistant", + "content": "goodbye", + }, + "rejected": { + "role": "assistant", + "content": "party on", + }, + } + ] + ) + + +@pytest.fixture(name="custom_assistant_dataset") +def fixture_custom_assistant_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversation": [ + { + "speaker": "human", + "text": "hello", + }, + { + "speaker": "agent", + "text": "hello", + }, + { + "speaker": "human", + "text": "goodbye", + }, + ], + "better": { + "speaker": "agent", + "text": "goodbye", + }, + "worse": { + "speaker": "agent", + "text": "party on", + }, + } + ] + ) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") + tokenizer.eos_token = "<|eot_id|>" + + return tokenizer + + +class TestAssistantDPOChatTemplateLlama3: + """ + Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. + """ + + def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "llama3", + "datasets": [ + { + "chat_template": "llama3", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer) + assert result["prompt"] == ( + "<|begin_of_text|>" + + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert result["chosen"] == "goodbye<|eot_id|>" + assert result["rejected"] == "party on<|eot_id|>" + + def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "llama3", + "datasets": [ + { + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "better", + "field_rejected": "worse", + "message_field_role": "speaker", + "message_field_content": "text", + "roles": { + "user": ["human"], + "assistant": ["agent"], + "system": ["sys"], + }, + } + ], + } + ) + ) + result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer) + assert result["prompt"] == ( + "<|begin_of_text|>" + + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert result["chosen"] == "goodbye<|eot_id|>" + assert result["rejected"] == "party on<|eot_id|>" + + +if __name__ == "__main__": + unittest.main()