Skip to content

Commit

Permalink
Feat(test): Add tests for alpaca chatml prompt tokenizer (#1088)
Browse files Browse the repository at this point in the history
* draft for adding test for tokenizer

* clean up

* clean up

* fix pre commit

* fix pylint

* Revert "fix pylint"

This reverts commit cd2cda3.

* add pylint exception for pytest fixture

* update comments

* Apply suggestions from code review

Co-authored-by: NanoCode012 <[email protected]>

* update spelling and import promptstyle

* reaname, restrucure

* clean up

* add fmt:on

---------

Co-authored-by: NanoCode012 <[email protected]>
  • Loading branch information
JohanWork and NanoCode012 authored Jan 23, 2024
1 parent 6f6aa9d commit db4794d
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions tests/prompt_strategies/test_alpaca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Test module for alpaca integration w chatml
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


@pytest.fixture(name="alpaca_dataset")
def fixture_alpaca_dataset():
return Dataset.from_list(
[
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant.",
}
]
)


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
# pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)

return tokenizer


class TestAlpacaChatml:
"""
Test class for alpaca prompter
"""

def test_no_double_im_end(self, alpaca_dataset, tokenizer):
strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)

dataset_wrapper = TokenizedPromptDataset(
strategy, alpaca_dataset, process_count=1
)

input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
1, # Bos
32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13, # instruction
32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13, # input
32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000, # output
]
# fmt: on

def test_no_train_on_input(self, alpaca_dataset, tokenizer):
strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)

dataset_wrapper = TokenizedPromptDataset(
strategy, alpaca_dataset, process_count=1
)

labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # instruction
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # input
-100, -100, -100, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000, # Output
]
# fmt: on

def test_w_train_on_input(self, alpaca_dataset, tokenizer):
strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
tokenizer,
True, # train_on_inputs
2048, # sequence_len
)

dataset_wrapper = TokenizedPromptDataset(
strategy, alpaca_dataset, process_count=1
)

labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
1, # Bos
32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13, # instruction
32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13, # input
32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000, # output
]
# fmt: on

0 comments on commit db4794d

Please sign in to comment.