diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index e056c7f509..16f38218cd 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -18,10 +18,10 @@
 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
 ) -> Dict[str, List]:
     res = tokenizer(
-        examples,
+        examples["text"],
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
diff --git a/tests/test_data.py b/tests/test_data.py
index 16af089a06..9d7f5a0412 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -35,7 +35,7 @@ def test_encode_pretraining(self):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)

         self.assertEqual(len(result["input_ids"]), 3)