Commit
test: add tests to check that completion alignment sampling works as intended
Karl-Johan Alm committed Oct 7, 2023
1 parent 76ce0e5 commit 2ef1f90
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions tests/test_prompt_tokenizers.py
@@ -13,6 +13,10 @@
    InstructionWSystemPromptTokenizingStrategy,
    SystemDataPrompter,
)
from axolotl.prompt_strategies.completion import (
    CompletionPrompter,
    CompletionPromptTokenizingStrategy,
)
from axolotl.prompt_strategies.llama2_chat import (
    Llama2ChatPrompter,
    LLama2ChatTokenizingStrategy,
@@ -45,6 +49,7 @@ def setUp(self) -> None:
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)

@@ -129,6 +134,65 @@ def test_alpaca(self):
assert example["labels"][world_idx] == 6324
assert example["labels"][world_idx - 1] == -100

    def test_completion_strategy(self):
        """
        tests the completion prompt tokenization strategy
        """
        # pylint: disable=duplicate-code
        self.tokenizer.padding_side = "left"
        for text in [
            # ['▁Once', '▁upon', '▁a', '▁time', ',', '▁there', '▁was', '▁a', '▁dog', '.'] [10]
            "Once upon a time, there was a dog.",  # fits in one sample at 12 ctxlen
            # ['▁Once', '▁upon', '▁a', '▁time', ',', '▁there', '▁was', '▁a', '▁dog', '.', '▁The', '▁dog', '▁was', '▁very', '▁happy', '.'] [16]
            "Once upon a time, there was a dog. The dog was very happy.",  # fits in two samples
            # ['▁Once', '▁upon', '▁a', '▁time', ',', '▁there', '▁was', '▁a', '▁dog', '.', '▁The', '▁dog', '▁was', '▁very', '▁happy', '.', '▁It', '▁was', '▁in', '▁fact', '▁so', '▁happy', '▁that', '▁it', '▁emb', 'ark', 'ed', '▁upon', '▁a', '▁cr', 'us', 'ade', '▁to', '▁save', '▁human', 'ity', '▁from', '▁the', '▁ev', 'ils', '▁of', '▁man', 'kind', '.'] [44]
            "Once upon a time, there was a dog. The dog was very happy. It was in fact so happy that it embarked upon a crusade to save humanity from the evils of mankind.",  # requires 4 samples
        ]:
            prompt_sample = {"text": [text]}
            tokenized = self.tokenizer.tokenize(text)

            strat = CompletionPromptTokenizingStrategy(
                CompletionPrompter(),
                self.tokenizer,
                False,  # train_on_inputs
                12,  # sequence_len
                max_length=12 * 64,
                align_samples=self.tokenizer.padding_side == "left",
            )

            example = strat.tokenize_prompt(prompt_sample)
            # The first sample should have 0+ padding followed by the start of the text
            # All padding should also have attention mask 0
            is_padding = True
            did_end = False
            tokenized_idx = 0
            for sample_idx, sample in enumerate(example["input_ids"]):
                attention_mask = example["attention_mask"][sample_idx]
                comp_tokens = self.tokenizer.convert_ids_to_tokens(sample)
                for idx, token in enumerate(sample):
                    if tokenized_idx == len(tokenized):
                        # We must have reached the end of the tokenized text
                        assert token == self.tokenizer.eos_token_id
                        assert idx + 1 == len(sample)
                        assert not did_end
                        did_end = True
                        continue
                    if is_padding:
                        if token != self.tokenizer.pad_token_id:
                            # Must be the BOS token
                            assert token == self.tokenizer.bos_token_id
                            assert attention_mask[idx] == 1
                            is_padding = False
                            continue
                        assert attention_mask[idx] == 0
                    else:
                        comp_token = comp_tokens[idx]
                        assert tokenized[tokenized_idx] == comp_token
                        tokenized_idx += 1
            # We must have reached the end of the tokenized text
            assert tokenized_idx == len(tokenized)
            assert did_end


class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
"""
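For reference, here is a minimal sketch of the sample layout the test above asserts for the 16-token example, assuming left-aligned packing at a context length of 12. The pad count and split points are inferred from the assertions, not from the strategy's actual implementation, and all names (PAD, BOS, EOS, CTX_LEN, and so on) are illustrative:

# Hypothetical illustration of left-aligned sample packing; the special
# token strings match the ones registered in setUp above.
PAD, BOS, EOS = "<pad>", "<s>", "</s>"
CTX_LEN = 12

tokens = ['▁Once', '▁upon', '▁a', '▁time', ',', '▁there', '▁was', '▁a',
          '▁dog', '.', '▁The', '▁dog', '▁was', '▁very', '▁happy', '.']

content = [BOS] + tokens + [EOS]             # 18 tokens of real content
n_samples = -(-len(content) // CTX_LEN)      # ceil(18 / 12) -> 2 samples
n_pad = n_samples * CTX_LEN - len(content)   # 6 pad tokens, all on the left

flat = [PAD] * n_pad + content
samples = [flat[i : i + CTX_LEN] for i in range(0, len(flat), CTX_LEN)]
assert samples[0] == [PAD] * 6 + [BOS] + tokens[:5]
assert samples[1] == tokens[5:] + [EOS]
# The attention mask is 0 over the pads and 1 everywhere else, so the EOS
# lands exactly at the end of the final sample, as the test checks.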
