From f4710e7ebc034be7acd1c5650486e8d4f57c6c1d Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 23 Aug 2024 12:05:16 +0800 Subject: [PATCH] add response_field in data_processing (#68) Signed-off-by: Yu Chin Fabian Lim --- scripts/benchmarks/benchmark.py | 2 ++ scripts/benchmarks/data_processing.py | 28 ++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index a43f34c8..d21b2fbe 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -166,6 +166,7 @@ def __init__( dataset_text_field: str = "output", chat_template: str = None, response_template: str = None, + response_field: str = None, additional_dataset_kwargs: Dict = {}, ) -> None: @@ -180,6 +181,7 @@ def __init__( "tokenize": tokenize, "input_field": input_field, "dataset_text_field": dataset_text_field, + "response_field": response_field, "chat_template": chat_template, } self.training_paths = {} # cache to store the training paths diff --git a/scripts/benchmarks/data_processing.py b/scripts/benchmarks/data_processing.py index 1a860bbe..3125ca44 100644 --- a/scripts/benchmarks/data_processing.py +++ b/scripts/benchmarks/data_processing.py @@ -1,5 +1,6 @@ # Standard from typing import Callable, Dict, List +import warnings # Third Party from transformers import PreTrainedTokenizer @@ -16,6 +17,7 @@ def build_data_formatting_func( dataset_text_field: str = "output", features: List = None, response_template: str = None, + response_field: str = None, chat_template: str = None, ): if tokenizer is None or chat_template is None: @@ -36,6 +38,7 @@ def build_data_formatting_func( dataset_text_field, features, response_template, + response_field, ) @@ -47,6 +50,8 @@ def _build_data_formatting_func( dataset_text_field: str = "output", features: List = None, response_template: str = None, + response_field: str = None, + ignore_index: int = -100, ): tokenizer.chat_template = chat_template @@ -54,12 +59,33 @@ def _build_data_formatting_func( loss_masking = None if tokenize and response_template is not None: loss_masking = instruction_mask_loss(tokenizer, response_template) + elif tokenize and response_template is None: + assert response_field is not None, \ + "response_field must be specified if tokenize=True and response_template=None." def _format(example): formatted_and_maybe_tokenized = tokenizer.apply_chat_template( [example], tokenize=tokenize ) key = "input_ids" if tokenize else dataset_text_field + + if tokenize and response_template is None and response_field: + # in this case we need to use the response field to tokenize + warnings.warn( + "chat_template passed in with tokenize=True and " + "response_template was None. To ensure loss masking is " + f"correct, please do not put reponse_field '{response_field}' " + "in the chat template." + ) + # NOTE: in this case not handling attention mask + response = tokenizer(example[response_field])['input_ids'] + return { + key: formatted_and_maybe_tokenized + response, + 'labels': [ ignore_index ] * len(formatted_and_maybe_tokenized) + response + } + + loss_masking = instruction_mask_loss(tokenizer, response_template) + if not loss_masking: return {key: formatted_and_maybe_tokenized} return loss_masking(formatted_and_maybe_tokenized) @@ -193,4 +219,4 @@ def collate_example(example): # flatten the additional dim return {k: v.view(-1) for k, v in collated_example.items()} - return collate_example + return collate_example \ No newline at end of file