fix: chat_template masking due to truncation, consolidate turn build and keys within field (#2123) [skip ci]

* fix: chat_template masking due to truncation, consolidate turn build and keys within field

* fix: revert roles change

* fix: handling of training and training_detail

* fix: do not skip setting eos mask even if failed finding turn boundary

* fix: truncate reward modelling outputs
NanoCode012 authored Dec 9, 2024
1 parent 3862267 commit 5d6b088
Showing 2 changed files with 114 additions and 74 deletions.
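The key fix in the Bradley-Terry strategy below is a truncation guard applied to both the chosen and rejected tokenizations before they are packed into the reward-model fields. As a hedged sketch of that pattern (not part of the diff; the helper name and the `tokenized`/`max_length` parameters stand in for the strategy's actual attributes):

import logging

LOG = logging.getLogger(__name__)

def truncate_tokenized(tokenized: dict, max_length: int, label: str = "sequence") -> dict:
    """Clip input_ids/attention_mask to max_length, warning when truncation occurs.

    Illustrative helper only; the commit inlines this logic for the chosen and
    rejected tokenizations rather than factoring it out like this.
    """
    if len(tokenized["input_ids"]) > max_length:
        LOG.warning(
            f"{label} exceeds max sequence length: {len(tokenized['input_ids'])}"
        )
    tokenized["input_ids"] = tokenized["input_ids"][:max_length]
    tokenized["attention_mask"] = tokenized["attention_mask"][:max_length]
    return tokenized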
30 changes: 27 additions & 3 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -28,6 +28,8 @@ def tokenize_prompt(self, prompt):
:return:
"""

max_length = self.prompter.max_length

self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -39,6 +41,16 @@ def tokenize_prompt(self, prompt):
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)

if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)

chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
:max_length
]

self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -52,6 +64,18 @@ def tokenize_prompt(self, prompt):
)
rejected_tokenized = super().tokenize_prompt(prompt)

if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)

rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
:max_length
]
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
:max_length
]

return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
        # we need to add one for detecting sequences exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
"max_length": (
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
),
}

strategy_params = {
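For context, the `max_length` handed to the prompter above is `cfg.sequence_len + 1` for non-reward-model runs (the extra token lets downstream filtering detect over-length examples) and exactly `cfg.sequence_len` for reward models, whose outputs are hard-truncated as shown earlier. A minimal restatement of that selection, assuming a `cfg` object exposing `sequence_len` and `reward_model` as in the diff:

def select_max_length(cfg) -> int:
    # Sketch only: +1 so later filtering can spot sequences exceeding
    # `sequence_len`; reward-model inputs are clipped to `sequence_len` exactly.
    return cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len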
158 changes: 87 additions & 71 deletions src/axolotl/prompt_strategies/chat_template.py
@@ -42,6 +42,7 @@ def __init__(
"gpt": "assistant",
"system": "system",
}

self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
@@ -53,21 +54,9 @@ def __init__(
self.drop_system_message = drop_system_message

def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]

if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

if self.processor:
text = self.processor.apply_chat_template(
turns,
conversation,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
@@ -76,8 +65,6 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
@@ -88,9 +75,7 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
return batch

return self.tokenizer.apply_chat_template(
turns,
truncation=True,
max_length=self.max_length,
conversation,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
@@ -215,7 +200,14 @@ def __init__(
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []

self.roles_to_train = []
if roles_to_train:
            # map roles if they exist in prompter.roles, else use the role as-is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]

self.train_on_eos = train_on_eos
self.images = "images"

@@ -262,37 +254,38 @@ def tokenize_prompt(self, prompt):

return tokenized_prompt

turns = prompt[self.messages]
turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)

last_eos_idx = -1
for index, turn in enumerate(turns):
role = turn.get(self.prompter.message_field_role)
content = turn.get(self.prompter.message_field_content)
train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
role = turn.get("role")
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")

LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)

should_train = (
train_turn
if train_turn is not None
else (
bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
)
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train

LOG.debug(f"Should train: {should_train}")

turn_start_idx, turn_end_idx = self.find_turn(
conversation_ids=input_ids, turn=index, turn_content=turn
)

if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")

LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -313,7 +306,9 @@ def tokenize_prompt(self, prompt):
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)

LOG.debug(f"Labels after processing turn {index}: {labels}")

@@ -351,52 +346,73 @@ def find_eos_token(self, input_ids, start_idx):
return i
return -1

def find_turn(self, conversation_ids, turn, turn_content):
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
            turn_content (dict): Dict containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get(self.prompter.message_field_content, "")
content = turn_content.get("content")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)

eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0

# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break

# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break

if start_idx != -1:
end_idx = start_idx + len(content_ids)
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")

if not content_ids:
LOG.warning(f"Empty content for turn {turn}")
return -1, -1

# For first turn, start from beginning
if turn == 0:
start_search_idx = 0
else:
end_idx = -1
# For subsequent turns, find the previous EOS token
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0

for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break

# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1

if last_index < start_search_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
)
return -1, -1

# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)

return start_idx, end_idx
return -1, -1

def get_conversation_thread(self, prompt):
return prompt[self.messages]
turns = [
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
]

if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

return turns

def get_images(self, prompt):
return prompt.get(self.images, None)
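Taken together, the rewritten `find_turn` bounds its search by counting EOS tokens (turn 0 searches from the start; turn n searches after the n-th EOS) and then looks for the turn's token IDs as a contiguous subsequence, using a cheap first-token check before the full slice comparison. A condensed, standalone sketch of that search, without the class plumbing and for illustration only:

def find_subsequence_after_nth_eos(
    conversation_ids: list[int],
    content_ids: list[int],
    turn: int,
    eos_token_id: int,
) -> tuple[int, int]:
    """Return (start, end) of content_ids within conversation_ids, or (-1, -1).

    Condensed sketch of the logic in find_turn above.
    """
    if not content_ids:
        return -1, -1

    # Turn 0 starts at the beginning; later turns start after the turn-th EOS.
    start_search_idx = 0
    if turn > 0:
        eos_count = 0
        for i, token_id in enumerate(conversation_ids):
            if token_id == eos_token_id:
                eos_count += 1
                if eos_count == turn:
                    start_search_idx = i + 1
                    break

    last_index = len(conversation_ids) - len(content_ids) + 1
    if last_index < start_search_idx:
        return -1, -1

    first_elem = content_ids[0]
    for i in range(start_search_idx, last_index):
        # Cheap first-token check before the full slice comparison.
        if conversation_ids[i] == first_elem and (
            conversation_ids[i : i + len(content_ids)] == content_ids
        ):
            return i, i + len(content_ids)
    return -1, -1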
