Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code-interpreter-only training #107

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion functionary/prompt_template/base_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,19 @@ def get_stop_tokens_for_generation(self) -> List[str]:
raise NotImplementedError

@abstractmethod
def get_assistant_prefixes(self) -> List[str]:
def get_assistant_prefixes_for_training_masking(
self, code_only: bool = False
) -> List[str]:
"""Return the assistant prefixs in the final prompt, this is used for masking the labels
in unmasking labels, the system will unmask chunks that start with assistant prefixs and end with stop tokens.
For example, assistant_prefixes might be: "<|from|>assistant\n<|recipient|>"
In this case unmasked chunks in labels would be tokens in ... of: <|from|>assistant\n<|recipient|> ... <|stop|>
When code_only is set to True, the assistant prefixes will include the python tool call.
For example, assistant_prefixes might be: "<|from|>assistant\n<|recipient|>python\n<content>"
In this case unmasked chunks in labels would just be the Python code in ... of: <|from|>assistant\n<|recipient|>python\n<|content|> ... <|stop|>
Args:
code_only (bool): if True, this method will return the prefixes including the python tool call

Returns:
List[str]: list of possible assistant prefixes
"""
Expand Down
4 changes: 3 additions & 1 deletion functionary/prompt_template/prompt_template_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def convert_message_to_prompt(self, message: Dict) -> str:
def get_stop_tokens_for_generation(self) -> List[str]:
return [self.end_assistant, self.end_function_call]

def get_assistant_prefixes(self) -> List[str]:
def get_assistant_prefixes_for_training_masking(
self, code_only: bool = False
) -> List[str]:
result = []
for item in [self.end_user, self.end_function]:
prefix = f"{item}\nassistant:"
Expand Down
10 changes: 8 additions & 2 deletions functionary/prompt_template/prompt_template_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,14 @@ def convert_message_to_prompt(self, message: Dict) -> str:
def get_stop_tokens_for_generation(self) -> List[str]:
return [self.stop_token]

def get_assistant_prefixes(self) -> List[str]:
return [f"{self.from_token}assistant\n{self.recipient_token}"]
def get_assistant_prefixes_for_training_masking(
self, code_only: bool = False
) -> List[str]:
prefix = f"{self.from_token}assistant\n{self.recipient_token}"
if code_only:
prefix += f"{self.predefined_func_names[PredefinedFuncTypes.code_interpreter]}\n{self.content_token}"

return [prefix]

def parse_assistant_response(
self, llm_output: str, tool_choice: Optional[Any] = None
Expand Down
31 changes: 27 additions & 4 deletions functionary/train/custom_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_batch_indices(size: int, batch_size: int) -> List[Tuple[int, int]]:


def get_prefix_assistant_token_ids(
prompt_template: PromptTemplate, tokenizer: Any
prompt_template: PromptTemplate, tokenizer: Any, code_only: bool = False
) -> List[List[int]]:
"""Get prefix assistant token_ids for masking labels.
In messages where role=assistant, the assistant's content always starts with a prefix, such as: "Assistant:" or "<|from|>assistant"
Expand All @@ -50,7 +50,9 @@ def get_prefix_assistant_token_ids(
List[List[int]]: List of token_ids of assistant prefixes
"""
result = []
for prefix in prompt_template.get_assistant_prefixes():
for prefix in prompt_template.get_assistant_prefixes_for_training_masking(
code_only=code_only
):
token_ids = tokenizer.encode(prefix, add_special_tokens=False)
if token_ids[0] == 29871:
token_ids = token_ids[1:]
Expand Down Expand Up @@ -109,7 +111,10 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
if data_ratio < 1:
raw_data = raw_data[: int(data_ratio * len(raw_data))]
ds = LazyPreprocessDataset(
raw_data, tokenizer, keep_assistant_prefix=keep_assistant_prefix
raw_data,
tokenizer,
keep_assistant_prefix=keep_assistant_prefix,
code_only=training_args.code_only,
)
return ds

Expand Down Expand Up @@ -147,6 +152,7 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
cached_folder=cached_folder,
ignore_cached=False,
keep_assistant_prefix=keep_assistant_prefix,
code_only=training_args.code_only,
use_flash_attention=True,
pack_length=pack_length,
)
Expand All @@ -161,6 +167,7 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
tokenizer,
cached_folder=cached_folder,
ignore_cached=False,
code_only=training_args.code_only,
use_flash_attention=True,
pack_length=pack_length,
)
Expand All @@ -177,6 +184,7 @@ def prepare_training_inputs(
max_length: Optional[int] = None,
return_tensor: bool = True,
keep_assistant_prefix: bool = False,
code_only: bool = False,
verbose=False,
) -> Dict[str, Union[str, Dict]]:
"""This function is used to convert a data point into input that is ready for training.
Expand All @@ -189,6 +197,7 @@ def prepare_training_inputs(
max_length (Optional[int], optional): _description_. Defaults to None.
return_tensor (bool, optional): _description_. Defaults to True.
keep_assistant_prefix (bool, optional): _description_. Defaults to False.
code_only (bool, optional): _description_. Defaults to False.
verbose (bool, optional): _description_. Defaults to False.

Returns:
Expand All @@ -201,6 +210,7 @@ def prepare_training_inputs(
max_length=max_length,
return_tensor=return_tensor,
keep_assistant_prefix=keep_assistant_prefix,
code_only=code_only,
verbose=verbose,
)
return dict(
Expand Down Expand Up @@ -323,6 +333,7 @@ def prepare_training_inputs_batch(
max_length: Optional[int] = None,
return_tensor: bool = True,
keep_assistant_prefix: bool = False,
code_only: bool = False,
verbose=False,
) -> List[Dict[str, Union[str, Dict]]]:
"""This function is used for when you want to get a dictionary input for the model.forward.
Expand All @@ -335,6 +346,8 @@ def prepare_training_inputs_batch(
padding (str, optional): type of padding (longest, max_length), this is passed to tokenizer(). Defaults to "max_length".
max_length (Optional[int], optional): maximum number of tokens allowed in prompt. Defaults to None.
return_tensor (bool, optional): if true, the input_dic will be dictionary[str, Tensor] else dictionary[str, List[int]]. Defaults to True.
keep_assistant_prefix (bool, optional): if True, the assistant prefix tokens will not be masked in the labels. Defaults to False.
code_only (bool, optional): this is set to true if we want to train a code-only model. Defaults to False.
verbose (bool, optional): to print some useful information or not. Defaults to False.

Returns:
Expand All @@ -345,7 +358,9 @@ def prepare_training_inputs_batch(
# a dictionary mapping from end_token_ --> end_token
prompt_template = get_prompt_template_from_tokenizer(tokenizer)
assistant_stop_token_ids = get_assistant_stop_token_ids(prompt_template, tokenizer)
assistant_prefix_tokens = get_prefix_assistant_token_ids(prompt_template, tokenizer)
assistant_prefix_tokens = get_prefix_assistant_token_ids(
prompt_template, tokenizer, code_only
)

prompt_str_list = []
for messages in batch_messages:
Expand Down Expand Up @@ -405,6 +420,7 @@ def map_raw_data_to_input_dic(
padding: str,
batch_size: int = 5000,
keep_assistant_prefix: bool = False,
code_only: bool = False,
) -> List[Dict]:
"""This function is used to map list of raw_data to list of processed data points for packing
Args:
Expand All @@ -413,6 +429,7 @@ def map_raw_data_to_input_dic(
padding (str): _description_
batch_size (int, optional): _description_. Defaults to 5000.
keep_assistant_prefix (bool, optional): if we unmask assistant prefix in computing loss. Defaults to False.
code_only (bool, optional): if True, we unmask only assistant turns that call the python tool. Defaults to False.

Returns:
List[Dict]: _description_
Expand All @@ -428,6 +445,7 @@ def map_raw_data_to_input_dic(
padding=padding,
return_tensor=False,
keep_assistant_prefix=keep_assistant_prefix,
code_only=code_only,
)

assert len(batch_result["batch_inputs"]) == len(raw_data[start:end])
Expand Down Expand Up @@ -776,13 +794,15 @@ def __init__(
raw_data,
tokenizer: transformers.PreTrainedTokenizer,
keep_assistant_prefix: bool = False,
code_only: bool = False,
):
super().__init__()
self.tokenizer = tokenizer

self.raw_data = raw_data
self.cached_data_dict = {}
self.keep_assistant_prefix = keep_assistant_prefix
self.code_only = code_only

def __len__(self):
return len(self.raw_data)
Expand All @@ -795,6 +815,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
messages=self.raw_data[i],
tokenizer=self.tokenizer,
keep_assistant_prefix=self.keep_assistant_prefix,
code_only=self.code_only,
)
ret = {
"input_ids": ret["inputs"]["input_ids"],
Expand All @@ -816,6 +837,7 @@ def __init__(
ignore_cached: bool = False,
batch_size: int = 5000,
keep_assistant_prefix: bool = False,
code_only: bool = False,
use_flash_attention: bool = True,
pack_length: Optional[int] = None,
):
Expand All @@ -830,6 +852,7 @@ def __init__(
padding="do_not_pad",
batch_size=batch_size,
keep_assistant_prefix=keep_assistant_prefix,
code_only=code_only,
)
self.update_packing_info()
if cached_folder is not None:
Expand Down
7 changes: 6 additions & 1 deletion functionary/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,15 @@ class TrainingArguments(transformers.TrainingArguments):
"help": "Whether to mask the assistant prefix `<|from|>assistant\n<|recipient|>` during training"
},
)

prompt_template_version: str = field(
default="v2", metadata={"help": "choose prompt template to use for training"}
)
code_only: bool = field(
default=False,
metadata={
"help": "Whether the training session is code-only. If so, only assistant turns with tool calls to python tool will be unmasked"
},
)


def trainer_save_model_safe(trainer: transformers.Trainer):
Expand Down
98 changes: 98 additions & 0 deletions tests/test_case_v2_code_only.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
{
"tools": [
{
"type": "function",
"function": {
"name": "get_car_price",
"description": "Get the price of a particular car model",
"parameters": {
"type": "object",
"properties": {
"car_name": {
"type": "string",
"description": "The name of the car model"
}
},
"required": [
"car_name"
]
}
}
},
{
"type": "code_interpreter"
}

],
"messages": [
{
"role": "user",
"content": "who is the president of US"
},
{
"role": "assistant",
"content": "Biden is the president of US"
},
{
"role": "user",
"content": "is the car Song more expensive than car Tang?"
},
{
"role": "assistant",
"content": "I will get the price of 2 cars and compare",
"tool_calls": [
{
"function": {
"name": "get_car_price",
"arguments": "{\"car_name\": \"Song\"}"
}
},
{
"function": {
"name": "get_car_price",
"arguments": "{\"car_name\": \"Tang\"}"
}
}
]
},
{
"role": "tool",
"content": "{\"price\": {\"price\": \"$25000\"}}",
"name": "get_car_price"
},
{
"role": "tool",
"content": "{\"price\": {\"price\": \"$20000\"}}",
"name": "get_car_price"
},
{
"role": "assistant",
"content": "No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000."
},
{
"role": "user",
"content": "Can you provide me with the numpy code to perform elementwise multiplication of [1, 2] and [3, 4]?"
},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"function": {
"name": "python",
"arguments": "import numpy as np\n\n# Define the arrays\na = np.array([1, 2])\nb = np.array([3, 4])\n\n# Perform element-wise multiplication\nresult = np.multiply(a, b) # or simply a * b\n\nprint(result)"
}
}
]
},
{
"role": "tool",
"content": "{\"result\": \"[3, 8]\"}",
"name": "python"
},
{
"role": "assistant",
"content": "The numpy code is above and the result of the code is [3, 8]."
}
]
}
Loading