diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py
index aea2c43d..4b33f8f0 100644
--- a/functionary/prompt_template/base_template.py
+++ b/functionary/prompt_template/base_template.py
@@ -448,11 +448,19 @@ def get_stop_tokens_for_generation(self) -> List[str]:
         raise NotImplementedError
 
     @abstractmethod
-    def get_assistant_prefixes(self) -> List[str]:
+    def get_assistant_prefixes_for_training_masking(
+        self, code_only: bool = False
+    ) -> List[str]:
         """Return the assistant prefixs in the final prompt, this is used for masking the labels
         in unmasking labels, the system will unmask chunks that start with assistant prefixs and end with stop tokens.
         For example, assistant_prefixes might be: "<|from|>assistant\n<|recipient|>"
         In this case unmasked chunks in labels would be tokens in ... of: <|from|>assistant\n<|recipient|> ... <|stop|>
+        When code_only is set to True, the assistant prefixes also include the call to the python tool.
+        For example, assistant_prefixes might be: "<|from|>assistant\n<|recipient|>python\n"
+        In this case the unmasked chunks in labels would be just the Python code in ... of: <|from|>assistant\n<|recipient|>python\n<|content|> ... <|stop|>
+        Args:
+            code_only (bool): if true, return the prefixes that include the python tool call
+
         Returns:
             List[str]: list of possible assistant prefixs
         """
diff --git a/functionary/prompt_template/prompt_template_v1.py b/functionary/prompt_template/prompt_template_v1.py
index b8133665..34faf1f1 100644
--- a/functionary/prompt_template/prompt_template_v1.py
+++ b/functionary/prompt_template/prompt_template_v1.py
@@ -116,7 +116,9 @@ def convert_message_to_prompt(self, message: Dict) -> str:
     def get_stop_tokens_for_generation(self) -> List[str]:
         return [self.end_assistant, self.end_function_call]
 
-    def get_assistant_prefixes(self) -> List[str]:
+    def get_assistant_prefixes_for_training_masking(
+        self, code_only: bool = False
+    ) -> List[str]:
         result = []
         for item in [self.end_user, self.end_function]:
             prefix = f"{item}\nassistant:"
diff --git a/functionary/prompt_template/prompt_template_v2.py b/functionary/prompt_template/prompt_template_v2.py
index 9e3698de..511c4343 100644
--- a/functionary/prompt_template/prompt_template_v2.py
+++ b/functionary/prompt_template/prompt_template_v2.py
@@ -117,8 +117,14 @@ def convert_message_to_prompt(self, message: Dict) -> str:
     def get_stop_tokens_for_generation(self) -> List[str]:
         return [self.stop_token]
 
-    def get_assistant_prefixes(self) -> List[str]:
-        return [f"{self.from_token}assistant\n{self.recipient_token}"]
+    def get_assistant_prefixes_for_training_masking(
+        self, code_only: bool = False
+    ) -> List[str]:
+        prefix = f"{self.from_token}assistant\n{self.recipient_token}"
+        if code_only:
+            prefix += f"{self.predefined_func_names[PredefinedFuncTypes.code_interpreter]}\n{self.content_token}"
+
+        return [prefix]
 
     def parse_assistant_response(
         self, llm_output: str, tool_choice: Optional[Any] = None
diff --git a/functionary/train/custom_datasets.py b/functionary/train/custom_datasets.py
index 6b4defb1..485eae14 100644
--- a/functionary/train/custom_datasets.py
+++ b/functionary/train/custom_datasets.py
@@ -37,7 +37,7 @@ def get_batch_indices(size: int, batch_size: int) -> List[Tuple[int, int]]:
 
 
 def get_prefix_assistant_token_ids(
-    prompt_template: PromptTemplate, tokenizer: Any
+    prompt_template: PromptTemplate, tokenizer: Any, code_only: bool = False
 ) -> List[List[int]]:
     """Get prefix assistant token_ids for masking labels.
    In message where role=assistant, content of assistant always start with a prefix, such as: "Assistant:" or "<|from|>assistant"
@@ -50,7 +50,9 @@ def get_prefix_assistant_token_ids(
         List[List[int]]: List of token_ids of assistant prefixs
     """
     result = []
-    for prefix in prompt_template.get_assistant_prefixes():
+    for prefix in prompt_template.get_assistant_prefixes_for_training_masking(
+        code_only=code_only
+    ):
         token_ids = tokenizer.encode(prefix, add_special_tokens=False)
         if token_ids[0] == 29871:
             token_ids = token_ids[1:]
@@ -109,7 +111,10 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
         if data_ratio < 1:
             raw_data = raw_data[: int(data_ratio * len(raw_data))]
         ds = LazyPreprocessDataset(
-            raw_data, tokenizer, keep_assistant_prefix=keep_assistant_prefix
+            raw_data,
+            tokenizer,
+            keep_assistant_prefix=keep_assistant_prefix,
+            code_only=training_args.code_only,
         )
         return ds
 
@@ -147,6 +152,7 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
             cached_folder=cached_folder,
             ignore_cached=False,
             keep_assistant_prefix=keep_assistant_prefix,
+            code_only=training_args.code_only,
             use_flash_attention=True,
             pack_length=pack_length,
         )
@@ -161,6 +167,7 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
             tokenizer,
             cached_folder=cached_folder,
             ignore_cached=False,
+            code_only=training_args.code_only,
             use_flash_attention=True,
             pack_length=pack_length,
         )
@@ -177,6 +184,7 @@ def prepare_training_inputs(
     max_length: Optional[int] = None,
     return_tensor: bool = True,
     keep_assistant_prefix: bool = False,
+    code_only: bool = False,
     verbose=False,
 ) -> Dict[str, Union[str, Dict]]:
     """This function is used to convert a data point into input that is ready for training.
@@ -189,6 +197,7 @@ def prepare_training_inputs(
         max_length (Optional[int], optional): _description_. Defaults to None.
         return_tensor (bool, optional): _description_. Defaults to True.
         keep_assistant_prefix (bool, optional): _description_. Defaults to False.
+        code_only (bool, optional): if true, only assistant turns that call the python tool are unmasked. Defaults to False.
         verbose (bool, optional): _description_. Defaults to False.
 
     Returns:
@@ -201,6 +210,7 @@ def prepare_training_inputs(
         max_length=max_length,
         return_tensor=return_tensor,
         keep_assistant_prefix=keep_assistant_prefix,
+        code_only=code_only,
         verbose=verbose,
     )
     return dict(
@@ -323,6 +333,7 @@ def prepare_training_inputs_batch(
     max_length: Optional[int] = None,
     return_tensor: bool = True,
     keep_assistant_prefix: bool = False,
+    code_only: bool = False,
     verbose=False,
 ) -> List[Dict[str, Union[str, Dict]]]:
     """This function is used for when you want to get a dictionary input for the model.forward.
@@ -335,6 +346,8 @@ def prepare_training_inputs_batch(
         padding (str, optional): type of padding (longest, max_length), this is passed to tokenizer(). Defaults to "max_length".
         max_length (Optional[int], optional): maximum number of tokens allowed in prompt. Defaults to None.
         return_tensor (bool, optional): if true, the input_dic will be dictionary[str, Tensor] else dictionary[str, List[int]]. Defaults to True.
+        keep_assistant_prefix (bool, optional): if true, the assistant prefix tokens are not masked in the labels. Defaults to False.
+        code_only (bool, optional): set to true when training a code-only model; only assistant turns that call the python tool are unmasked. Defaults to False.
         verbose (bool, optional): to print some useful information or not. Defaults to False.
 
     Returns:
@@ -345,7 +358,9 @@ def prepare_training_inputs_batch(
     # a dictionary mapping from end_token_ --> end_token
     prompt_template = get_prompt_template_from_tokenizer(tokenizer)
     assistant_stop_token_ids = get_assistant_stop_token_ids(prompt_template, tokenizer)
-    assistant_prefix_tokens = get_prefix_assistant_token_ids(prompt_template, tokenizer)
+    assistant_prefix_tokens = get_prefix_assistant_token_ids(
+        prompt_template, tokenizer, code_only
+    )
 
     prompt_str_list = []
     for messages in batch_messages:
@@ -405,6 +420,7 @@ def map_raw_data_to_input_dic(
     padding: str,
     batch_size: int = 5000,
     keep_assistant_prefix: bool = False,
+    code_only: bool = False,
 ) -> List[Dict]:
     """This function is used to map list of raw_data to list of processed data points for packing
     Args:
@@ -413,6 +429,7 @@ def map_raw_data_to_input_dic(
         padding (str): _description_
         batch_size (int, optional): _description_. Defaults to 5000.
         keep_assistant_prefix (bool, optional): if we unmask assistant prefix in computing loss. Defaults to False.
+        code_only (bool, optional): if true, we unmask only assistant turns that contain a python tool call. Defaults to False.
 
     Returns:
         List[Dict]: _description_
@@ -428,6 +445,7 @@ def map_raw_data_to_input_dic(
                 padding=padding,
                 return_tensor=False,
                 keep_assistant_prefix=keep_assistant_prefix,
+                code_only=code_only,
             )
 
             assert len(batch_result["batch_inputs"]) == len(raw_data[start:end])
@@ -776,6 +794,7 @@ def __init__(
         raw_data,
         tokenizer: transformers.PreTrainedTokenizer,
         keep_assistant_prefix: bool = False,
+        code_only: bool = False,
     ):
         super().__init__()
         self.tokenizer = tokenizer
@@ -783,6 +802,7 @@ def __init__(
         self.raw_data = raw_data
         self.cached_data_dict = {}
         self.keep_assistant_prefix = keep_assistant_prefix
+        self.code_only = code_only
 
     def __len__(self):
         return len(self.raw_data)
@@ -795,6 +815,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
             messages=self.raw_data[i],
             tokenizer=self.tokenizer,
             keep_assistant_prefix=self.keep_assistant_prefix,
+            code_only=self.code_only,
         )
         ret = {
             "input_ids": ret["inputs"]["input_ids"],
@@ -816,6 +837,7 @@ def __init__(
         ignore_cached: bool = False,
         batch_size: int = 5000,
         keep_assistant_prefix: bool = False,
+        code_only: bool = False,
         use_flash_attention: bool = True,
         pack_length: Optional[int] = None,
     ):
@@ -830,6 +852,7 @@ def __init__(
             padding="do_not_pad",
             batch_size=batch_size,
             keep_assistant_prefix=keep_assistant_prefix,
+            code_only=code_only,
         )
         self.update_packing_info()
         if cached_folder is not None:
diff --git a/functionary/train/train.py b/functionary/train/train.py
index 00b36138..7b905c53 100644
--- a/functionary/train/train.py
+++ b/functionary/train/train.py
@@ -72,10 +72,15 @@ class TrainingArguments(transformers.TrainingArguments):
             "help": "Whether to mask the assistant prefix `<|from|>assistant\n<|recipient|>` during training"
         },
     )
-
     prompt_template_version: str = field(
         default="v2", metadata={"help": "choose prompt template to use for training"}
     )
+    code_only: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether the training session is code-only. If so, only assistant turns that call the python tool will be unmasked"
+        },
+    )
 
 
 def trainer_save_model_safe(trainer: transformers.Trainer):
diff --git a/tests/test_case_v2_code_only.json b/tests/test_case_v2_code_only.json
new file mode 100644
index 00000000..c3426e95
--- /dev/null
+++ b/tests/test_case_v2_code_only.json
@@ -0,0 +1,98 @@
+{
+    "tools": [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_car_price",
+                "description": "Get the price of a particular car model",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "car_name": {
+                            "type": "string",
+                            "description": "The name of the car model"
+                        }
+                    },
+                    "required": [
+                        "car_name"
+                    ]
+                }
+            }
+        },
+        {
+            "type": "code_interpreter"
+        }
+
+    ],
+    "messages": [
+        {
+            "role": "user",
+            "content": "who is the president of US"
+        },
+        {
+            "role": "assistant",
+            "content": "Biden is the president of US"
+        },
+        {
+            "role": "user",
+            "content": "is the car Song more expensive than car Tang?"
+        },
+        {
+            "role": "assistant",
+            "content": "I will get the price of 2 cars and compare",
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "get_car_price",
+                        "arguments": "{\"car_name\": \"Song\"}"
+                    }
+                },
+                {
+                    "function": {
+                        "name": "get_car_price",
+                        "arguments": "{\"car_name\": \"Tang\"}"
+                    }
+                }
+            ]
+        },
+        {
+            "role": "tool",
+            "content": "{\"price\": {\"price\": \"$25000\"}}",
+            "name": "get_car_price"
+        },
+        {
+            "role": "tool",
+            "content": "{\"price\": {\"price\": \"$20000\"}}",
+            "name": "get_car_price"
+        },
+        {
+            "role": "assistant",
+            "content": "No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000."
+        },
+        {
+            "role": "user",
+            "content": "Can you provide me with the numpy code to perform elementwise multiplication of [1, 2] and [3, 4]?"
+        },
+        {
+            "role": "assistant",
+            "content": null,
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "python",
+                        "arguments": "import numpy as np\n\n# Define the arrays\na = np.array([1, 2])\nb = np.array([3, 4])\n\n# Perform element-wise multiplication\nresult = np.multiply(a, b) # or simply a * b\n\nprint(result)"
+                    }
+                }
+            ]
+        },
+        {
+            "role": "tool",
+            "content": "{\"result\": \"[3, 8]\"}",
+            "name": "python"
+        },
+        {
+            "role": "assistant",
+            "content": "The numpy code is above and the result of the code is [3, 8]."
+        }
+    ]
+}
diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py
index c6ba6606..2d5d8d26 100644
--- a/tests/test_prompt_creation.py
+++ b/tests/test_prompt_creation.py
@@ -1,3 +1,4 @@
+import json
 import os
 import re
 import unittest
@@ -6,19 +7,18 @@
 from transformers import LlamaTokenizer, LlamaTokenizerFast
 
 from functionary.prompt_template import (
-    get_prompt_template_by_version,
+    SYSTEM_MESSAGE,
     PromptTemplateV1,
     PromptTemplateV2,
-    SYSTEM_MESSAGE,
+    get_prompt_template_by_version,
 )
 from functionary.schema import generate_schema_from_functions
 from functionary.train.custom_datasets import prepare_training_inputs
-import json
 
 
 def extract_unmasked_chunks(labels: List[int]) -> List[List[int]]:
     """This function is used to extract unmasked chunks of integer
-    For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
+    For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
 
     Args:
         labels (List[int]): list of integer containing token_id and -100
@@ -47,10 +47,21 @@ def __init__(self, *args, **kwargs):
         self.prompt_template = get_prompt_template_by_version(self.template_version)
         current_folder = os.path.dirname(os.path.abspath(__file__))
 
-        with open(os.path.join(current_folder, f"test_case_{self.template_version}.json")) as f:
+        with open(
+            os.path.join(current_folder, f"test_case_{self.template_version}.json")
+        ) as f:
             self.test_case = json.loads(f.read())
 
-        with open(os.path.join(current_folder, f"prompt_test_{self.template_version}.txt")) as f:
+        with open(
+            os.path.join(
+                current_folder, f"test_case_{self.template_version}_code_only.json"
+            )
+        ) as f:
+            self.test_case_code_only = json.loads(f.read())
+
+        with open(
+            os.path.join(current_folder, f"prompt_test_{self.template_version}.txt")
+        ) as f:
             self.final_prompt = f.read()
             self.final_prompt = self.final_prompt.replace("\n\n<|from|>", "\n<|from|>")
 
@@ -90,8 +101,22 @@ def test_prepare_training_inputs_normal_tokenizer(self):
                 keep_assistant_prefix=keep_assistant_prefix,
             )
 
+    def test_prepare_training_inputs_code_only(self):
+        print("start testing code-only")
+        for keep_assistant_prefix in [False]:
+            self.run_prepare_training_inputs(
+                use_fast=True,
+                pretrained="mistralai/Mistral-7B-v0.1",
+                keep_assistant_prefix=keep_assistant_prefix,
+                code_only=True,
+            )
+
     def run_prepare_training_inputs(
-        self, use_fast: bool, pretrained: str, keep_assistant_prefix: bool = False
+        self,
+        use_fast: bool,
+        pretrained: str,
+        keep_assistant_prefix: bool = False,
+        code_only: bool = False,
     ):
         """this function is used to test function: prepare_training_inputs"""
         # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176
@@ -109,14 +134,20 @@ def run_prepare_training_inputs(
         # check if tokenizer added new stop tokens successfully
         self.assertEqual(length_before + len(added_tokens), length_after)
 
+        if code_only is False:
+            messages = self.test_case
+        else:
+            messages = self.test_case_code_only
+
         inputs = prepare_training_inputs(
-            messages=self.test_case,
+            messages=messages,
             tokenizer=tokenizer,
             padding="longest",
             max_length=1024,
             return_tensor=False,
             verbose=False,
             keep_assistant_prefix=keep_assistant_prefix,
+            code_only=code_only,
         )
         input_ids = inputs["inputs"]["input_ids"]
         labels = inputs["inputs"]["labels"]
@@ -132,9 +163,20 @@ def run_prepare_training_inputs(
         )
 
         # Check if only messages where role=assistant are remained, others will be masked as -100
-        assistant_message = [
-            item for item in self.test_case["messages"] if item["role"] == "assistant"
-        ]
+        if code_only is False:
+            assistant_message = [
+                item
+                for item in self.test_case["messages"]
+                if item["role"] == "assistant"
+            ]
+        else:
+            assistant_message = [
+                item
+                for item in self.test_case_code_only["messages"]
+                if item["role"] == "assistant"
+                and "tool_calls" in item
+                and item["tool_calls"][0]["function"]["name"] == "python"
+            ]
         # find unmasked chunks in labels (chunk[i] != -100), there chunks are associated with assistant messages
         # for example: labels=[-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
         chunks = extract_unmasked_chunks(labels)
@@ -148,7 +190,11 @@ def run_prepare_training_inputs(
             if keep_assistant_prefix:
                 prefix = ""
             else:
-                prefix = prompt_template.convert_message_to_prompt({"role": "assistant"})
+                prefix = prompt_template.convert_message_to_prompt(
+                    {"role": "assistant"}
+                )
+            if code_only is True:
+                prefix += "python\n<|content|>"
             decoded_content = prefix + tokenizer.decode(
                 chunk
             )  # note that need to add: "\nassistant" because we mask this, see line 194 in prompt_utils.py