From 5d6b088997d3992d1ce591087d04ae1f968dbd56 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 10 Dec 2024 01:49:38 +0700
Subject: [PATCH] fix: chat_template masking due to truncation, consolidate turn build and keys within field (#2123) [skip ci]

* fix: chat_template masking due to truncation, consolidate turn build and keys within field

* fix: revert roles change

* fix: handling of training and training_detail

* fix: do not skip setting eos mask even if failed finding turn boundary

* fix: truncate reward modelling outputs
---
 .../bradley_terry/chat_template.py      |  30 +++-
 .../prompt_strategies/chat_template.py  | 158 ++++++++++--------
 2 files changed, 114 insertions(+), 74 deletions(-)

diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
index fa85cdcb26..4f60842c5f 100644
--- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
+++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -28,6 +28,8 @@ def tokenize_prompt(self, prompt):
         :return:
         """
 
+        max_length = self.prompter.max_length
+
         self.messages = "chosen_messages"
         # pylint: disable=duplicate-code
         prompt[self.messages] = []
@@ -39,6 +41,16 @@ def tokenize_prompt(self, prompt):
         prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
         chosen_tokenized = super().tokenize_prompt(prompt)
 
+        if len(chosen_tokenized["input_ids"]) > max_length:
+            LOG.warning(
+                f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
+            )
+
+        chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
+        chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
+            :max_length
+        ]
+
         self.messages = "rejected_messages"
         # pylint: disable=duplicate-code
         prompt[self.messages] = []
@@ -52,6 +64,18 @@ def tokenize_prompt(self, prompt):
         )
         rejected_tokenized = super().tokenize_prompt(prompt)
 
+        if len(rejected_tokenized["input_ids"]) > max_length:
+            LOG.warning(
+                f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
+            )
+
+        rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
+            :max_length
+        ]
+        rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
+            :max_length
+        ]
+
         return {
             "input_ids_chosen": chosen_tokenized["input_ids"],
             "attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         "roles": ds_cfg.get("roles"),
         "drop_system_message": ds_cfg.get("drop_system_message", False),
         # we need to add one for detecting sequences exceeding the `sequence_len` limit.
- "max_length": cfg.sequence_len + 1 - if not cfg.reward_model - else cfg.sequence_len, + "max_length": ( + cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len + ), } strategy_params = { diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 0946a4b8c7..35c9311678 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -42,6 +42,7 @@ def __init__( "gpt": "assistant", "system": "system", } + self.message_field_role = message_field_role self.message_field_content = message_field_content self.message_field_training = message_field_training @@ -53,21 +54,9 @@ def __init__( self.drop_system_message = drop_system_message def build_prompt(self, conversation, add_generation_prompt=False, images=None): - turns = [ - { - "role": self.roles[t[self.message_field_role]], - "content": t[self.message_field_content], - "training": t.get(self.message_field_training, None), - } - for t in conversation - ] - - if self.drop_system_message and turns[0]["role"] == "system": - turns = turns[1:] - if self.processor: text = self.processor.apply_chat_template( - turns, + conversation, chat_template=self.chat_template, tokenize=False, add_generation_prompt=add_generation_prompt, @@ -76,8 +65,6 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): text=text, images=images, return_tensors="pt", - truncation=True, - max_length=self.max_length, ) # workaround since processor works in batches instead of single examples for k, val in batch.items(): @@ -88,9 +75,7 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None): return batch return self.tokenizer.apply_chat_template( - turns, - truncation=True, - max_length=self.max_length, + conversation, add_generation_prompt=add_generation_prompt, chat_template=self.chat_template, ) @@ -215,7 +200,14 @@ def __init__( train_on_eos=None, ): super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) - self.roles_to_train = roles_to_train if roles_to_train is not None else [] + + self.roles_to_train = [] + if roles_to_train: + # map roles if exist in prompter.roles else use the role as is + self.roles_to_train = [ + prompter.roles.get(role, role) for role in roles_to_train + ] + self.train_on_eos = train_on_eos self.images = "images" @@ -262,30 +254,28 @@ def tokenize_prompt(self, prompt): return tokenized_prompt - turns = prompt[self.messages] + turns = self.get_conversation_thread(prompt) input_ids = self.prompter.build_prompt(turns) labels = [IGNORE_TOKEN_ID] * len(input_ids) last_eos_idx = -1 for index, turn in enumerate(turns): - role = turn.get(self.prompter.message_field_role) - content = turn.get(self.prompter.message_field_content) - train_turn = turn.get(self.prompter.message_field_training) - train_detail = turn.get(self.prompter.message_field_training_detail) + role = turn.get("role") + content = turn.get("content") + train_turn = turn.get("training") + train_detail = turn.get("training_detail") LOG.debug( f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}" ) - should_train = ( - train_turn - if train_turn is not None - else ( - bool(train_detail is not None) - if train_detail is not None - else self.train_on_inputs or role in self.roles_to_train - ) - ) + should_train = None + if train_turn is not None: + should_train = train_turn + elif train_detail is not None: + should_train = bool(train_detail) + else: + 
+                should_train = self.train_on_inputs or role in self.roles_to_train
 
             LOG.debug(f"Should train: {should_train}")
 
@@ -293,6 +283,9 @@ def tokenize_prompt(self, prompt):
                 conversation_ids=input_ids, turn=index, turn_content=turn
             )
 
+            if turn_start_idx == -1 or turn_end_idx == -1:
+                LOG.warning(f"Failed to find boundaries for turn {index}")
+
             LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
 
             if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -313,7 +306,9 @@ def tokenize_prompt(self, prompt):
                     labels[turn_start_idx:turn_end_idx] = input_ids[
                         turn_start_idx:turn_end_idx
                     ]
-                LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
+                LOG.debug(
+                    f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
+                )
 
             LOG.debug(f"Labels after processing turn {index}: {labels}")
 
@@ -351,52 +346,73 @@ def find_eos_token(self, input_ids, start_idx):
                 return i
         return -1
 
-    def find_turn(self, conversation_ids, turn, turn_content):
+    def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
         """
         Locate the starting and ending indices of the specified turn in a conversation.
-
-        Args:
-            conversation_ids (list[int]): Token IDs representing the conversation.
-            turn (int): The turn number to locate (based on EOS tokens).
-            turn_content (str): String containing the content of the turn.
-
-        Returns:
-            tuple: (start_idx, end_idx) indices of the start and end of the turn content.
-                   Returns (-1, -1) if the turn content is not found.
         """
-
-        content = turn_content.get(self.prompter.message_field_content, "")
+        content = turn_content.get("content")
         content_ids = self.tokenizer.encode(content, add_special_tokens=False)
-        eos_token_id = self.tokenizer.eos_token_id
-        eos_count = 0
-        start_search_idx = 0
-
-        # Locate the starting index after the specified number of EOS tokens
-        for i, token_id in enumerate(conversation_ids):
-            if token_id == eos_token_id:
-                eos_count += 1
-                if eos_count == turn:
-                    start_search_idx = (
-                        i + 1
-                    )  # Start searching after the specified turn's EOS token
-                    break
-
-        # Find the start index of the content within the conversation
-        start_idx = -1
-        for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
-            if conversation_ids[i : i + len(content_ids)] == content_ids:
-                start_idx = i
-                break
-
-        if start_idx != -1:
-            end_idx = start_idx + len(content_ids)
+        LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
+
+        if not content_ids:
+            LOG.warning(f"Empty content for turn {turn}")
+            return -1, -1
+
+        # For the first turn, start from the beginning
+        if turn == 0:
+            start_search_idx = 0
         else:
-            end_idx = -1
+            # For subsequent turns, find the previous EOS token
+            eos_token_id = self.tokenizer.eos_token_id
+            eos_count = 0
+            start_search_idx = 0
+
+            for i, token_id in enumerate(conversation_ids):
+                if token_id == eos_token_id:
+                    eos_count += 1
+                    if eos_count == turn:  # Find the nth EOS token where n = turn
+                        start_search_idx = i + 1
+                        break
+
+        # We could limit the search to a few tokens after start_search_idx, but that
+        # would risk missing the content if it is not found within those tokens or
+        # if start_search_idx could not be found above.
+        last_index = len(conversation_ids) - len(content_ids) + 1
+
+        if last_index < start_search_idx:
+            LOG.warning(
+                f"last_index to search is less than start_search_idx for turn {turn}"
+            )
+            return -1, -1
+
+        # Search for content starting from start_search_idx
+        first_elem = content_ids[0]
+        for i in range(start_search_idx, last_index):
+            # Quick check of first element before doing full comparison
+            if conversation_ids[i] == first_elem:
+                # Check if the rest of the content matches
+                if conversation_ids[i : i + len(content_ids)] == content_ids:
+                    LOG.debug(f"Found turn {turn} content at position {i}")
+                    return i, i + len(content_ids)
 
-        return start_idx, end_idx
+        return -1, -1
 
     def get_conversation_thread(self, prompt):
-        return prompt[self.messages]
+        turns = [
+            {
+                "role": self.prompter.roles[t[self.prompter.message_field_role]],
+                "content": t[self.prompter.message_field_content],
+                "training": t.get(self.prompter.message_field_training),
+                "training_detail": t.get(self.prompter.message_field_training_detail),
+            }
+            for t in prompt[self.messages]
+        ]
+
+        if self.prompter.drop_system_message and turns[0]["role"] == "system":
+            turns = turns[1:]
+
+        return turns
 
     def get_images(self, prompt):
         return prompt.get(self.images, None)
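
The sketches below are editorial illustrations of the patterns this patch introduces; the helper names (truncate_tokenized, resolve_should_train, find_subsequence_after_nth_eos, normalize_turns) are hypothetical and not part of axolotl's API.

The bradley_terry hunks warn when a tokenized preference sequence exceeds max_length and then truncate input_ids and attention_mask unconditionally, since slicing an in-bounds sequence is a no-op. A minimal sketch of that guard, assuming plain Python lists:

    import logging

    LOG = logging.getLogger(__name__)


    def truncate_tokenized(tokenized: dict, max_length: int, label: str) -> dict:
        # Warn only when the limit is exceeded; the slices below are no-ops otherwise.
        if len(tokenized["input_ids"]) > max_length:
            LOG.warning(
                "%s sequence exceeds max sequence length: %d",
                label,
                len(tokenized["input_ids"]),
            )
        tokenized["input_ids"] = tokenized["input_ids"][:max_length]
        tokenized["attention_mask"] = tokenized["attention_mask"][:max_length]
        return tokenized

Calling it once per preference side mirrors the chosen/rejected handling:

    chosen = truncate_tokenized(
        {"input_ids": [1] * 10, "attention_mask": [1] * 10}, 8, "Chosen"
    )
    assert len(chosen["input_ids"]) == 8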
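
The rewritten should_train branch in tokenize_prompt gives the per-turn flags a clear precedence: an explicit "training" value wins, then a present "training_detail", then the global train_on_inputs/roles_to_train fallback. A sketch of that decision over the normalized turn dicts, with resolve_should_train as a hypothetical stand-in:

    def resolve_should_train(
        turn: dict, train_on_inputs: bool, roles_to_train: list
    ) -> bool:
        # An explicit per-turn "training" flag takes priority.
        if turn.get("training") is not None:
            return turn["training"]
        # A present "training_detail" trains only if it is non-empty.
        if turn.get("training_detail") is not None:
            return bool(turn["training_detail"])
        # Otherwise fall back to the global settings.
        return train_on_inputs or turn.get("role") in roles_to_train

This also removes a subtle bug in the replaced expression: bool(train_detail is not None) was evaluated inside a branch already guarded by train_detail is not None, so it was always True; bool(train_detail) now lets an empty detail list disable training for the turn.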
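
find_turn now anchors its search after the turn-th EOS token (turn 0 searches from the start) and scans for the turn's content as a contiguous token subsequence, returning (-1, -1) when no boundary can be established so the caller can keep those labels masked. A self-contained sketch of the same search, assuming EOS-delimited turns and list-of-int token IDs:

    def find_subsequence_after_nth_eos(ids, content_ids, turn, eos_token_id):
        # Empty content can never be located.
        if not content_ids:
            return -1, -1
        # Turn 0 searches from the start; later turns start after the turn-th EOS.
        start_search_idx = 0
        if turn > 0:
            eos_count = 0
            for i, token_id in enumerate(ids):
                if token_id == eos_token_id:
                    eos_count += 1
                    if eos_count == turn:
                        start_search_idx = i + 1
                        break
        last_index = len(ids) - len(content_ids) + 1
        if last_index < start_search_idx:
            return -1, -1
        first_elem = content_ids[0]
        for i in range(start_search_idx, last_index):
            # Cheap first-token check before the full slice comparison.
            if ids[i] == first_elem and ids[i : i + len(content_ids)] == content_ids:
                return i, i + len(content_ids)
        return -1, -1

For example, with ids = [5, 6, 2, 7, 8, 2], content_ids = [7, 8], turn = 1, and eos_token_id = 2, the search starts at index 3 and returns (3, 5).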
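
Finally, get_conversation_thread now owns the turn normalization that build_prompt previously duplicated: mapping dataset role names through prompter.roles, reading the configurable message-field keys, carrying the training/training_detail flags, and dropping a leading system message when configured. A sketch of that consolidation, with the field keys hardcoded to "role"/"content" for brevity (the patch reads them from the prompter's message_field_* settings and indexes prompter.roles directly rather than falling back):

    def normalize_turns(messages, roles_map, drop_system_message=False):
        # Produce turns with uniform keys so downstream masking only ever
        # looks at "role", "content", "training", and "training_detail".
        turns = [
            {
                "role": roles_map.get(m["role"], m["role"]),
                "content": m["content"],
                "training": m.get("training"),
                "training_detail": m.get("training_detail"),
            }
            for m in messages
        ]
        if drop_system_message and turns and turns[0]["role"] == "system":
            turns = turns[1:]
        return turns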