Skip to content

Commit

Permalink
split dataset parsing into its own component
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Mar 18, 2024
1 parent 34da274 commit c4af0c2
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions src/axolotl/prompt_strategies/orpo/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,12 @@ def load(
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
dataset_parser=ORPODatasetParsingStrategy(),
)


class ORPOTokenizingStrategy(PromptTokenizingStrategy):
"""
rejected_input_ids
input_ids
rejected_attention_mask
attention_mask
rejected_labels
labels
"""

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
class ORPODatasetParsingStrategy:
"""Strategy to parse chosen rejected dataset into messagelist"""

def get_chosen_conversation_thread(self, prompt) -> MessageList:
"""Dataset structure mappings"""
Expand Down Expand Up @@ -90,10 +77,32 @@ def get_rejected_conversation_thread(self, prompt) -> MessageList:
)
return MessageList(messages=messages)


class ORPOTokenizingStrategy(PromptTokenizingStrategy):
"""
rejected_input_ids
input_ids
rejected_attention_mask
attention_mask
rejected_labels
labels
"""

def __init__(
    self,
    *args,
    dataset_parser=None,
    **kwargs,
):
    """Initialise the tokenizing strategy and attach a dataset parser.

    All positional and remaining keyword arguments are forwarded unchanged
    to the parent class initialiser.

    Args:
        dataset_parser: Object used by ``tokenize_prompt`` to turn a dataset
            row into message lists; it must provide
            ``get_chosen_conversation_thread(prompt)`` and
            ``get_rejected_conversation_thread(prompt)``. When ``None``
            (the default), an ``ORPODatasetParsingStrategy`` instance is
            used.
    """
    super().__init__(*args, **kwargs)
    # tokenize_prompt calls methods on the parser unconditionally, so a
    # missing parser would only surface later as an AttributeError on None.
    # Fall back to the default parser here instead.
    if dataset_parser is None:
        dataset_parser = ORPODatasetParsingStrategy()
    self.dataset_parser = dataset_parser

def tokenize_prompt(self, prompt):
# pass the rejected prompt/row to the Prompter to get the formatted prompt
prompt_len = 0
rejected_message_list = self.get_rejected_conversation_thread(prompt)
rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
prompt
)
input_ids = []
labels = []
for _, (part, label) in enumerate(
Expand All @@ -113,7 +122,7 @@ def tokenize_prompt(self, prompt):
rejected_input_ids = input_ids
rejected_labels = labels
# pass the chosen prompt/row to the Prompter to get the formatted prompt
chosen_message_list = self.get_chosen_conversation_thread(prompt)
chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
input_ids = []
labels = []
for _, (part, label) in enumerate(
Expand Down

0 comments on commit c4af0c2

Please sign in to comment.