From db4794d6e16845e49c92f700974e8d89d57c7bfd Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Tue, 23 Jan 2024 05:30:26 +0100 Subject: [PATCH] Feat(test): Add tests for alpaca chatml prompt tokenizer (#1088) * draft for adding test for tokenizer * clean up * clean up * fix pre commit * fix pylint * Revert "fix pylint" This reverts commit cd2cda3cdae6f31f6d038a0673c2c7abd8e8e46a. * add pylint exception for pytest fixture * update comments * Apply suggestions from code review Co-authored-by: NanoCode012 * update spelling and import promptstyle * reaname, restrucure * clean up * add fmt:on --------- Co-authored-by: NanoCode012 --- tests/prompt_strategies/test_alpaca.py | 116 +++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/prompt_strategies/test_alpaca.py diff --git a/tests/prompt_strategies/test_alpaca.py b/tests/prompt_strategies/test_alpaca.py new file mode 100644 index 0000000000..9c97e40521 --- /dev/null +++ b/tests/prompt_strategies/test_alpaca.py @@ -0,0 +1,116 @@ +""" +Test module for alpaca integration w chatml +""" +import pytest +from datasets import Dataset +from tokenizers import AddedToken +from transformers import AutoTokenizer + +from axolotl.datasets import TokenizedPromptDataset +from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter, PromptStyle + + +@pytest.fixture(name="alpaca_dataset") +def fixture_alpaca_dataset(): + return Dataset.from_list( + [ + { + "instruction": "Evaluate this sentence for spelling and grammar mistakes", + "input": "He finnished his meal and left the resturant", + "output": "He finished his meal and left the restaurant.", + } + ] + ) + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(): + # pylint: disable=all + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + tokenizer.add_special_tokens( + { + "eos_token": AddedToken( + "<|im_end|>", rstrip=False, lstrip=False, normalized=False + ) + } + ) + tokenizer.add_tokens( + [ + AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False), + ] + ) + + return tokenizer + + +class TestAlpacaChatml: + """ + Test class for alpaca prompter + """ + + def test_no_double_im_end(self, alpaca_dataset, tokenizer): + strategy = AlpacaPromptTokenizingStrategy( + AlpacaPrompter(prompt_style=PromptStyle.CHATML.value), + tokenizer, + False, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, alpaca_dataset, process_count=1 + ) + + input_ids = dataset_wrapper[0]["input_ids"] + # fmt: off + assert input_ids == [ + 1, # Bos + 32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13, # instruction + 32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13, # input + 32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000, # output + ] + # fmt: on + + def test_no_train_on_input(self, alpaca_dataset, tokenizer): + strategy = AlpacaPromptTokenizingStrategy( + AlpacaPrompter(prompt_style=PromptStyle.CHATML.value), + tokenizer, + False, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, alpaca_dataset, process_count=1 + ) + + labels = dataset_wrapper[0]["labels"] + # fmt: off + assert labels == [ + -100, # bos + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # instruction + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # input + -100, -100, -100, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000, # Output + ] + # fmt: on + + def test_w_train_on_input(self, alpaca_dataset, tokenizer): + strategy = AlpacaPromptTokenizingStrategy( + AlpacaPrompter(prompt_style=PromptStyle.CHATML.value), + tokenizer, + True, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, alpaca_dataset, process_count=1 + ) + + labels = dataset_wrapper[0]["labels"] + # fmt: off + assert labels == [ + 1, # Bos + 32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13, # instruction + 32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13, # input + 32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000, # output + ] + # fmt: on