diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 177545170e..53f6e677e8 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -66,11 +66,13 @@ def _get_assistant_token(self): pass return False - def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): - result: Union[Dict[str, List[Union[bool, int]]], BatchEncoding] + def _tokenize( + self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False + ) -> BatchEncoding: + result: BatchEncoding if not prompt.strip(): LOG.warning("Empty text requested for tokenization.") - result = {"input_ids": [], "attention_mask": []} + result = BatchEncoding(data={"input_ids": [], "attention_mask": []}) else: result = self.tokenizer( prompt,