From 95e141d29558bfc29dacc66b1ce7b74dd4a9e86b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=8A=B9=EB=8D=95/Infrastructure=EA=B7=B8?=
 =?UTF-8?q?=EB=A3=B9=28YA=29?=
Date: Mon, 9 Oct 2023 16:21:36 +0900
Subject: [PATCH 1/4] Enhance code readability of prompt_tokenizers.py

---
 src/axolotl/prompt_tokenizers.py | 193 +++++++++++++------------------
 1 file changed, 83 insertions(+), 110 deletions(-)

diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 4e30b81a71..22a66f876f 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -45,7 +45,6 @@ def __init__(
         self.prompter = prompter
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.train_on_inputs = train_on_inputs
-        self.sequence_len = sequence_len
         self.max_length = sequence_len
 
     @abc.abstractmethod
@@ -59,34 +58,31 @@ def supports_batched(self):
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:
-        result: BatchEncoding
+        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
         if not prompt:
             LOG.warning("Empty text requested for tokenization.")
-            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
-        else:
-            result = self.tokenizer(
-                prompt,
-                truncation=True,
-                max_length=self.max_length,
-                padding=False,
-                return_tensors=None,
-            )
+            return empty
+
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length,
+            padding=False,
+            return_tensors=None,
+        )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
+            return empty
+
         if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][-1] != self.tokenizer.eos_token_id
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
             and len(result["input_ids"]) < self.max_length
             and add_eos_token
         ):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
@@ -122,7 +118,7 @@ def tokenize_prompt(self, prompt):
         if not self.train_on_inputs:
             user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-            tokenized_prompt["labels"] = [-100] * user_prompt_len
+            tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
         tokenized_res_prompt = self._tokenize(
             response, strip_bos_token=True, add_eos_token=True
         )
@@ -270,7 +266,7 @@ def tokenize_prompt(self, prompt):
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
             tokenized_full_prompt["labels"] = [
-                -100
+                IGNORE_INDEX
             ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
 
         return tokenized_full_prompt
@@ -294,13 +290,13 @@ def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
             result = self.tokenizer(
                 prompt,
                 truncation=True,
-                max_length=self.sequence_len,
+                max_length=self.max_length,
                 padding=False,
                 return_tensors=None,
             )
             if (
                 result["input_ids"][-1] != self.tokenizer.eos_token_id
-                and len(result["input_ids"]) < self.sequence_len
+                and len(result["input_ids"]) < self.max_length
                 and add_eos_token
             ):
                 result["input_ids"].append(self.tokenizer.eos_token_id)
@@ -334,6 +330,7 @@ def get_conversation_thread(self, prompt):
         return prompt["conversations"]
 
     def tokenize_prompt(self, prompt):
+        # Initial values. We will append to these as we go through the conversation.
         result, current_len = tokenize_prompt_default()
         conversation: Conversation = (
             self.prompter._conversation.copy()  # pylint: disable=protected-access
@@ -355,62 +352,70 @@ def tokenize_prompt(self, prompt):
             for _, part in enumerate(
                 self.prompter.build_prompt(self.get_conversation_thread(prompt))
             ):
-                if isinstance(part, tuple):
-                    if conversation.roles[0] in part[0]:
-                        role = (
-                            part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
-                            if role_remap
-                            else part[0]
-                        )
-                        turn = role + part[1]
-                        # this is still the user query, we should
-                        if not part[1].strip():
-                            LOG.warning(f"user turn has empty text: {prompt}")
-                        res = self._tokenize(
-                            turn,
-                            add_eos_token=False,
-                            strip_bos_token=True,
-                        )
-                        # everything from this is masked out from the labels
-                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                    elif conversation.roles[1] in part[0]:
-                        # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
-                        role = (
-                            part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
-                            if role_remap
-                            else part[0]
-                        )
-                        turn = role + part[1]
-                        # this should be the assistant response, should end with an eos token
-                        if not part[1].strip():
-                            LOG.warning(f"assistant turn has empty text: {prompt}")
-                        res = self._tokenize(
-                            turn,
-                            add_eos_token=True,
-                            strip_bos_token=True,
-                        )
-                        role_res = self._tokenize(
-                            role.rstrip(),
-                            add_eos_token=False,
-                            strip_bos_token=True,
-                        )
-                        # not masked out from labels
-                        labels = copy.deepcopy(res["input_ids"])
-                        len_role = len(role_res["input_ids"])
-                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
-                            len_role, len(labels)
-                        )
-                    elif part[0] == "":
-                        turn = part[1]
-                        # this is only ever the first part, should include the bos token and the user query
-                        res = self._tokenize(
-                            turn, add_eos_token=False, strip_bos_token=False
-                        )
-                        # everything from this is masked out from the labels
-                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                    else:
-                        LOG.warning(f"unhandled role: {part[0]}")
-                        continue
+
+                if not isinstance(part, tuple):
+                    LOG.warning(f"expected tuple, got {part}")
+                    continue
+
+                user, assistant = conversation.roles
+                role, content = part
+
+                # Uses "in" because role contains extra characters
+                if user in role:
+                    role = (
+                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this is still the user query, we should
+                    if not content.strip():
+                        LOG.warning(f"user turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                elif assistant in role:
+                    # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                    role = (
+                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this should be the assistant response, should end with an eos token
+                    if not content.strip():
+                        LOG.warning(f"assistant turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=True,
+                        strip_bos_token=True,
+                    )
+                    role_res = self._tokenize(
+                        role.rstrip(),
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    # not masked out from labels
+                    labels = copy.deepcopy(res["input_ids"])
+                    len_role = len(role_res["input_ids"])
+                    labels[:len_role] = [IGNORE_TOKEN_ID] * min(
+                        len_role, len(labels)
+                    )
+                elif role == "":
+                    turn = content
+                    # this is only ever the first part, should include the bos token and the user query
+                    res = self._tokenize(
+                        turn, add_eos_token=False, strip_bos_token=False
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                else:
+                    LOG.warning(f"unhandled role: {role}")
+                    continue
 
                 # pylint: disable=duplicate-code
                 result, current_len = parse_tokenized_to_result(
@@ -424,38 +429,6 @@ def tokenize_prompt(self, prompt):
         except (KeyError, AssertionError, IndexError) as err:
             raise InvalidDataException(str(err)) from err
 
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        if not prompt.strip():
-            LOG.warning("Empty text requested for tokenization.")
-            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
-        else:
-            result = self.tokenizer(
-                prompt,
-                truncation=True,
-                max_length=self.sequence_len,
-                padding=False,
-                return_tensors=None,
-            )
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
-
-        result["labels"] = result["input_ids"].copy()
-        return result
-
 
 def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
     """

From 21fe299431e8323fd27f06c44f0fae5b7342a006 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=8A=B9=EB=8D=95/Infrastructure=EA=B7=B8?=
 =?UTF-8?q?=EB=A3=B9=28YA=29?=
Date: Mon, 9 Oct 2023 16:33:49 +0900
Subject: [PATCH 2/4] Use max_length instead of sequence_len since they are
 the same values

---
 src/axolotl/prompt_strategies/completion.py | 4 ++--
 src/axolotl/prompt_strategies/metharme.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py
index 3285e667cb..bb30c27101 100644
--- a/src/axolotl/prompt_strategies/completion.py
+++ b/src/axolotl/prompt_strategies/completion.py
@@ -53,8 +53,8 @@ def tokenize_prompt(self, prompt):
         tokenized_full_prompt = self._tokenize(full_prompt)
 
         for key, val in tokenized_full_prompt.items():
-            for i in range(0, len(val), self.sequence_len):
-                res[key].append(val[i : i + self.sequence_len])
+            for i in range(0, len(val), self.max_length):
+                res[key].append(val[i : i + self.max_length])
 
         return dict(res)
 
diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py
index 52d77c00cf..62c5349bd0 100644
--- a/src/axolotl/prompt_strategies/metharme.py
+++ b/src/axolotl/prompt_strategies/metharme.py
@@ -31,7 +31,7 @@ def _tokenize(
         result = self.tokenizer(
             prompt,
             truncation=True,
-            max_length=self.sequence_len,
+            max_length=self.max_length,
             padding=False,
             return_tensors=None,
         )
@@ -43,7 +43,7 @@ def _tokenize(
 
         if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
             for _ in range(num_eos_tokens):
-                if len(result["input_ids"]) < self.sequence_len:
+                if len(result["input_ids"]) < self.max_length:
                     result["input_ids"].append(self.tokenizer.eos_token_id)
                     result["attention_mask"].append(1)
 

From fbde27c26f8897ab6c1801436c5c0e878fdda18d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=8A=B9=EB=8D=95/Infrastructure=EA=B7=B8?=
 =?UTF-8?q?=EB=A3=B9=28YA=29?=
Date: Wed, 11 Oct 2023 22:18:18 +0900
Subject: [PATCH 3/4] revert removing sequence_len

---
 src/axolotl/prompt_strategies/completion.py | 4 ++--
 src/axolotl/prompt_strategies/metharme.py   | 4 ++--
 src/axolotl/prompt_tokenizers.py            | 7 +++++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py
index bb30c27101..3285e667cb 100644
--- a/src/axolotl/prompt_strategies/completion.py
+++ b/src/axolotl/prompt_strategies/completion.py
@@ -53,8 +53,8 @@ def tokenize_prompt(self, prompt):
         tokenized_full_prompt = self._tokenize(full_prompt)
 
         for key, val in tokenized_full_prompt.items():
-            for i in range(0, len(val), self.max_length):
-                res[key].append(val[i : i + self.max_length])
+            for i in range(0, len(val), self.sequence_len):
+                res[key].append(val[i : i + self.sequence_len])
 
         return dict(res)
 
diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py
index 62c5349bd0..52d77c00cf 100644
--- a/src/axolotl/prompt_strategies/metharme.py
+++ b/src/axolotl/prompt_strategies/metharme.py
@@ -31,7 +31,7 @@ def _tokenize(
         result = self.tokenizer(
             prompt,
             truncation=True,
-            max_length=self.max_length,
+            max_length=self.sequence_len,
             padding=False,
             return_tensors=None,
         )
@@ -43,7 +43,7 @@ def _tokenize(
 
         if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
             for _ in range(num_eos_tokens):
-                if len(result["input_ids"]) < self.max_length:
+                if len(result["input_ids"]) < self.sequence_len:
                     result["input_ids"].append(self.tokenizer.eos_token_id)
                     result["attention_mask"].append(1)
 
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 22a66f876f..918514b198 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -45,6 +45,9 @@ def __init__(
         self.prompter = prompter
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.train_on_inputs = train_on_inputs
+        # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
+        # TODO: Document how they are different.
+        self.sequence_len = sequence_len
        self.max_length = sequence_len
 
     @abc.abstractmethod
@@ -290,13 +293,13 @@ def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
             result = self.tokenizer(
                 prompt,
                 truncation=True,
-                max_length=self.max_length,
+                max_length=self.sequence_len,
                 padding=False,
                 return_tensors=None,
             )
             if (
                 result["input_ids"][-1] != self.tokenizer.eos_token_id
-                and len(result["input_ids"]) < self.max_length
+                and len(result["input_ids"]) < self.sequence_len
                 and add_eos_token
             ):
                 result["input_ids"].append(self.tokenizer.eos_token_id)
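An aside for reviewers on why PATCH 3/4 restores sequence_len in completion.py: the completion strategy does not truncate to a single window, it tokenizes the whole document once and then slices the result into sequence_len-sized chunks, so sequence_len acts as a stride while max_length bounds the initial tokenization pass and, per the TODO above, can be set differently for this strategy. The sketch below is not part of the patch series; the toy values are invented for illustration, and only the slicing loop mirrors the patched code.

    from collections import defaultdict

    sequence_len = 4  # toy value for illustration; real configs use e.g. 2048
    tokenized_full_prompt = {
        "input_ids": list(range(10)),  # stand-in token ids 0..9
        "attention_mask": [1] * 10,
    }

    # Slice every field into sequence_len-sized windows, keeping the remainder,
    # as completion.py's tokenize_prompt does after _tokenize().
    res = defaultdict(list)
    for key, val in tokenized_full_prompt.items():
        for i in range(0, len(val), sequence_len):
            res[key].append(val[i : i + sequence_len])

    print(dict(res)["input_ids"])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

Because the two attributes only happen to share a value in the base class, collapsing them again would silently change this chunking behavior, which is what the revert guards against.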
From af04eb01d2e425a8be16dc72276f82ffaad7bf1a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 18 Oct 2023 21:33:37 -0400
Subject: [PATCH 4/4] chore: lint

---
 src/axolotl/prompt_tokenizers.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 918514b198..23ea38da0f 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -355,7 +355,6 @@ def tokenize_prompt(self, prompt):
             for _, part in enumerate(
                 self.prompter.build_prompt(self.get_conversation_thread(prompt))
             ):
-
                 if not isinstance(part, tuple):
                     LOG.warning(f"expected tuple, got {part}")
                     continue
@@ -405,9 +404,7 @@ def tokenize_prompt(self, prompt):
                     # not masked out from labels
                     labels = copy.deepcopy(res["input_ids"])
                     len_role = len(role_res["input_ids"])
-                    labels[:len_role] = [IGNORE_TOKEN_ID] * min(
-                        len_role, len(labels)
-                    )
+                    labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
                 elif role == "":
                     turn = content
                     # this is only ever the first part, should include the bos token and the user query
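For intuition about the label masking that PATCH 1/4 refactors and PATCH 4/4 reformats above: an assistant turn is tokenized twice, once as the full role-plus-response turn and once as the bare role prefix, and the labels for the prefix positions are overwritten with IGNORE_TOKEN_ID so the loss only covers the response. A minimal standalone sketch follows; toy_tokenize and its token ids are hypothetical stand-ins for the real tokenizer, and only the last three statements mirror the patched logic.

    import copy

    IGNORE_TOKEN_ID = -100  # sentinel value skipped by the loss function

    def toy_tokenize(text: str) -> list:
        # Hypothetical tokenizer: one fake token id per whitespace-separated word.
        return [hash(word) % 1000 for word in text.split()]

    role = "ASSISTANT: "
    content = "The answer is 42."
    res_input_ids = toy_tokenize(role + content)
    role_ids = toy_tokenize(role.rstrip())

    # Train on the assistant's words but not on the role marker itself:
    # copy the input ids, then mask the role-prefix positions.
    labels = copy.deepcopy(res_input_ids)
    len_role = len(role_ids)
    labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))

    print(labels)  # the leading role position is -100; response ids remain

The min() guard matters when truncation leaves fewer labels than role tokens: without it, the slice assignment would silently grow the label list past the length of the input ids.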