From 651b7a31fcd50fcec4be27664adba35a0387b0b5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Jan 2024 09:33:38 -0500 Subject: [PATCH] fix double eos token for chatml (#1054) [skip ci] * fix double eos token for chatml * isolate fix to chatml conversation * fix add special tokens to include rstrip * add test for train_on_inputs for sharegpt * don't use rstrip for chatml --- src/axolotl/prompt_tokenizers.py | 6 +- tests/prompt_strategies/test_sharegpt.py | 153 +++++++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 tests/prompt_strategies/test_sharegpt.py diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index fe4f3b62f7..389ea9a5eb 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -392,9 +392,13 @@ def tokenize_prompt(self, prompt): # this should be the assistant response, should end with an eos token if not content.strip(): LOG.warning(f"assistant turn has empty text: {prompt}") + add_eos_token = not ( + conversation.name == "chatml" + and conversation.sep == self.tokenizer.eos_token + ) res = self._tokenize( turn, - add_eos_token=True, + add_eos_token=add_eos_token, strip_bos_token=True, ) role_res = self._tokenize( diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py new file mode 100644 index 0000000000..ce33a8c400 --- /dev/null +++ b/tests/prompt_strategies/test_sharegpt.py @@ -0,0 +1,153 @@ +""" +Test module for sharegpt integration w chatml +""" +import pytest +from datasets import Dataset +from tokenizers import AddedToken +from transformers import AutoTokenizer + +from axolotl.datasets import TokenizedPromptDataset +from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy +from axolotl.prompters import ShareGPTPrompterV2 + + +@pytest.fixture(name="sharegpt_dataset") +def fixture_sharegpt_dataset(): + return Dataset.from_list( + [ + { + "conversations": [ + { + "from": "system", + "value": "repeat", + }, + { + "from": "human", + "value": "hello", + }, + { + "from": "gpt", + "value": "hello", + }, + { + "from": "human", + "value": "goodbye", + }, + { + "from": "gpt", + "value": "goodbye", + }, + ] + } + ] + ) + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + tokenizer.add_special_tokens( + { + "eos_token": AddedToken( + "<|im_end|>", rstrip=False, lstrip=False, normalized=False + ) + } + ) + tokenizer.add_tokens( + [ + AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False), + ] + ) + + return tokenizer + + +class TestSharegpt: + """ + Test class for sharegpt prompter + """ + + def test_no_double_im_end(self, sharegpt_dataset, tokenizer): + strategy = SimpleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation="chatml", + role_key_model=None, + role_key_human=None, + ), + tokenizer, + False, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, sharegpt_dataset, process_count=1 + ) + + input_ids = dataset_wrapper[0]["input_ids"] + # fmt: off + assert input_ids == [ + # 28705, 13, is " \n" + 1, # bos + 32001, 1587, 13, 25997, 32000, 28705, 13, # system + 32001, 2188, 13, 21558, 32000, 28705, 13, # human + 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt + 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human + 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt + ] + # fmt: on + + def test_w_train_on_input(self, 
sharegpt_dataset, tokenizer): + strategy = SimpleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation="chatml", + role_key_model=None, + role_key_human=None, + ), + tokenizer, + True, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, sharegpt_dataset, process_count=1 + ) + + labels = dataset_wrapper[0]["labels"] + # fmt: off + assert labels == [ + -100, # bos + -100, -100, -100, -100, -100, -100, -100, # system + -100, -100, -100, -100, -100, -100, -100, # human + -100, -100, 13, 21558, 32000, 28705, 13, # gpt + -100, -100, -100, -100, -100, -100, -100, -100, # human + -100, -100, 13, 12684, 17664, 32000, 28705, 13, # gpt + ] + # fmt: on + + # def test_no_train_on_input(self, sharegpt_dataset, tokenizer): + # strategy = SimpleShareGPTPromptTokenizingStrategy( + # ShareGPTPrompterV2( + # conversation="chatml", + # role_key_model=None, + # role_key_human=None, + # ), + # tokenizer, + # False, # train_on_inputs + # 2048, # sequence_len + # ) + # + # dataset_wrapper = TokenizedPromptDataset( + # strategy, sharegpt_dataset, process_count=1 + # ) + # + # labels = dataset_wrapper[0]["labels"] + # # fmt: off + # assert labels == [ + # 1, # bos + # 32001, 1587, 13, 25997, 32000, 28705, 13, # system + # 32001, 2188, 13, 21558, 32000, 28705, 13, # human + # 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt + # 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human + # 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt + # ] + # # fmt: on
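
Below the patch, a minimal, dependency-free sketch of the double-EOS behavior the guard in prompt_tokenizers.py prevents. The vocabulary, token ids, and the fake_tokenize helper are illustrative stand-ins that only mirror the shape of axolotl's _tokenize(..., add_eos_token=...); they are not the real tokenizer or the upstream implementation. It assumes, as in the test fixture above, that the chatml conversation separator equals the tokenizer's eos_token ("<|im_end|>").

# Hypothetical stand-ins: ids chosen to match the expected values in the new test,
# not produced by the real Mistral tokenizer.
EOS_TOKEN = "<|im_end|>"
VOCAB = {
    "<|im_start|>": 32001,
    "<|im_end|>": 32000,
    "assistant": 13892,
    "hello": 21558,
}


def fake_tokenize(text, add_eos_token=True):
    """Stand-in for the strategy's _tokenize(): whitespace-split lookup that
    optionally appends the EOS id, mirroring add_eos_token behavior."""
    ids = [VOCAB[tok] for tok in text.split()]
    if add_eos_token:
        ids.append(VOCAB[EOS_TOKEN])
    return ids


# A chatml assistant turn already ends with the separator "<|im_end|>", which is
# also the EOS token, so unconditionally adding EOS duplicates it:
turn = "<|im_start|> assistant hello <|im_end|>"
assert fake_tokenize(turn, add_eos_token=True) == [32001, 13892, 21558, 32000, 32000]  # doubled EOS

# The patch derives add_eos_token from the conversation instead of hard-coding True:
conversation_name, conversation_sep = "chatml", EOS_TOKEN
add_eos = not (conversation_name == "chatml" and conversation_sep == EOS_TOKEN)
assert fake_tokenize(turn, add_eos_token=add_eos) == [32001, 13892, 21558, 32000]  # single EOS

For non-chatml conversations, or when the separator differs from eos_token, the computed flag stays True and the turn still receives its EOS token as before, which is why the fix is isolated to the chatml case.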