From 54abc7ea10ababa5f0be02066dc22a2c63c517e7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 19 Sep 2023 08:09:56 -0400 Subject: [PATCH] improve handling for empty text on the tokenization step (#502) --- src/axolotl/prompt_tokenizers.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index b1aaeb3503..f30d0e3832 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -6,7 +6,7 @@ import logging from typing import Dict, List, Tuple, Union -from transformers import PreTrainedTokenizer +from transformers import BatchEncoding, PreTrainedTokenizer from axolotl.prompters import IGNORE_TOKEN_ID @@ -66,14 +66,21 @@ def _get_assistant_token(self): pass return False - def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) + def _tokenize( + self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False + ) -> BatchEncoding: + result: BatchEncoding + if not prompt.strip(): + LOG.warning("Empty text requested for tokenization.") + result = BatchEncoding(data={"input_ids": [], "attention_mask": []}) + else: + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.sequence_len, + padding=False, + return_tensors=None, + ) if len(result["input_ids"]) == 0: LOG.warning("Tokenizer result is empty. You may want to audit your dataset") if (