diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 4e30b81a71..23ea38da0f 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -45,6 +45,8 @@ def __init__( self.prompter = prompter self.tokenizer: PreTrainedTokenizer = tokenizer self.train_on_inputs = train_on_inputs + # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy. + # TODO: Document how they are different. self.sequence_len = sequence_len self.max_length = sequence_len @@ -59,34 +61,31 @@ def supports_batched(self): def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: - result: BatchEncoding + empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) if not prompt: LOG.warning("Empty text requested for tokenization.") - result = BatchEncoding(data={"input_ids": [], "attention_mask": []}) - else: - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.max_length, - padding=False, - return_tensors=None, - ) + return empty + + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.max_length, + padding=False, + return_tensors=None, + ) if len(result["input_ids"]) == 0: LOG.warning("Tokenizer result is empty. You may want to audit your dataset") + return empty + if ( - len(result["input_ids"]) > 0 - and result["input_ids"][-1] != self.tokenizer.eos_token_id + result["input_ids"][-1] != self.tokenizer.eos_token_id and len(result["input_ids"]) < self.max_length and add_eos_token ): result["input_ids"].append(self.tokenizer.eos_token_id) result["attention_mask"].append(1) - if ( - len(result["input_ids"]) > 0 - and result["input_ids"][0] == self.tokenizer.bos_token_id - and strip_bos_token - ): + if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: result["input_ids"] = result["input_ids"][1:] result["attention_mask"] = result["attention_mask"][1:] @@ -122,7 +121,7 @@ def tokenize_prompt(self, prompt): if not self.train_on_inputs: user_prompt_len = len(tokenized_prompt["input_ids"]) # TODO this could be sped up using numpy array slicing - tokenized_prompt["labels"] = [-100] * user_prompt_len + tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len tokenized_res_prompt = self._tokenize( response, strip_bos_token=True, add_eos_token=True ) @@ -270,7 +269,7 @@ def tokenize_prompt(self, prompt): user_prompt_len = len(tokenized_user_prompt["input_ids"]) # TODO this could be sped up using numpy array slicing tokenized_full_prompt["labels"] = [ - -100 + IGNORE_INDEX ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] return tokenized_full_prompt @@ -334,6 +333,7 @@ def get_conversation_thread(self, prompt): return prompt["conversations"] def tokenize_prompt(self, prompt): + # Initial values. We will append to these as we go through the conversation. result, current_len = tokenize_prompt_default() conversation: Conversation = ( self.prompter._conversation.copy() # pylint: disable=protected-access @@ -355,62 +355,67 @@ def tokenize_prompt(self, prompt): for _, part in enumerate( self.prompter.build_prompt(self.get_conversation_thread(prompt)) ): - if isinstance(part, tuple): - if conversation.roles[0] in part[0]: - role = ( - part[0].replace(role_remap[0]["from"], role_remap[0]["to"]) - if role_remap - else part[0] - ) - turn = role + part[1] - # this is still the user query, we should - if not part[1].strip(): - LOG.warning(f"user turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=False, - strip_bos_token=True, - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif conversation.roles[1] in part[0]: - # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID - role = ( - part[0].replace(role_remap[1]["from"], role_remap[1]["to"]) - if role_remap - else part[0] - ) - turn = role + part[1] - # this should be the assistant response, should end with an eos token - if not part[1].strip(): - LOG.warning(f"assistant turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=True, - strip_bos_token=True, - ) - role_res = self._tokenize( - role.rstrip(), - add_eos_token=False, - strip_bos_token=True, - ) - # not masked out from labels - labels = copy.deepcopy(res["input_ids"]) - len_role = len(role_res["input_ids"]) - labels[:len_role] = [IGNORE_TOKEN_ID] * min( - len_role, len(labels) - ) - elif part[0] == "": - turn = part[1] - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - turn, add_eos_token=False, strip_bos_token=False - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - else: - LOG.warning(f"unhandled role: {part[0]}") - continue + if not isinstance(part, tuple): + LOG.warning(f"expected tuple, got {part}") + continue + + user, assistant = conversation.roles + role, content = part + + # Uses "in" because role contains extra characters + if user in role: + role = ( + role.replace(role_remap[0]["from"], role_remap[0]["to"]) + if role_remap + else role + ) + turn = role + content + # this is still the user query, we should + if not content.strip(): + LOG.warning(f"user turn has empty text: {prompt}") + res = self._tokenize( + turn, + add_eos_token=False, + strip_bos_token=True, + ) + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + elif assistant in role: + # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID + role = ( + role.replace(role_remap[1]["from"], role_remap[1]["to"]) + if role_remap + else role + ) + turn = role + content + # this should be the assistant response, should end with an eos token + if not content.strip(): + LOG.warning(f"assistant turn has empty text: {prompt}") + res = self._tokenize( + turn, + add_eos_token=True, + strip_bos_token=True, + ) + role_res = self._tokenize( + role.rstrip(), + add_eos_token=False, + strip_bos_token=True, + ) + # not masked out from labels + labels = copy.deepcopy(res["input_ids"]) + len_role = len(role_res["input_ids"]) + labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels)) + elif role == "": + turn = content + # this is only ever the first part, should include the bos token and the user query + res = self._tokenize( + turn, add_eos_token=False, strip_bos_token=False + ) + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + else: + LOG.warning(f"unhandled role: {role}") + continue # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( @@ -424,38 +429,6 @@ def tokenize_prompt(self, prompt): except (KeyError, AssertionError, IndexError) as err: raise InvalidDataException(str(err)) from err - def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): - if not prompt.strip(): - LOG.warning("Empty text requested for tokenization.") - result = BatchEncoding(data={"input_ids": [], "attention_mask": []}) - else: - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) - if ( - len(result["input_ids"]) > 0 - and result["input_ids"][-1] != self.tokenizer.eos_token_id - and len(result["input_ids"]) < self.sequence_len - and add_eos_token - ): - result["input_ids"].append(self.tokenizer.eos_token_id) - result["attention_mask"].append(1) - - if ( - len(result["input_ids"]) > 0 - and result["input_ids"][0] == self.tokenizer.bos_token_id - and strip_bos_token - ): - result["input_ids"] = result["input_ids"][1:] - result["attention_mask"] = result["attention_mask"][1:] - - result["labels"] = result["input_ids"].copy() - return result - def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: """