From 993663a3edd6c646da06d1972a04e2f9953a9230 Mon Sep 17 00:00:00 2001 From: Chuck Tang Date: Mon, 25 Mar 2024 22:57:13 +0000 Subject: [PATCH 1/4] upgrade hf chat --- scripts/inference/hf_chat.py | 92 +++++++++++++++++------------------- 1 file changed, 43 insertions(+), 49 deletions(-) diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index 4f938f999e..cebc9e15f5 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -13,26 +13,23 @@ StoppingCriteria, StoppingCriteriaList, TextStreamer) -class ChatFormatter: - """A class for formatting the chat history. +DEFAULT_SYSTEM_PROMPT = "You are a friendly chatbot who aims to be helpful and honest." +class ChatMessage: + """ A class that holds a chat message. - Args: - system: The system prompt. If None, a default ChatML-formatted prompt is used. - user: The user prompt. If None, a default ChatML value is used. - assistant: The assistant prompt. If None, a default ChatML value is used. - - Attributes: - system: The system prompt. - user: The user prompt. - assistant: The assistant prompt. - response_prefix: The response prefix (anything before {} in the assistant format string) + Please see ChatML format for more information: + https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates """ - def __init__(self, system: str, user: str, assistant: str) -> None: - self.system = system if system else '<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n' - self.user = user if user else '<|im_start|>user\n{}<|im_end|>\n' - self.assistant = assistant if assistant else '<|im_start|>assistant\n{}<|im_end|>\n' - self.response_prefix = self.assistant.split('{}')[0] + def __init__(self, role: str, content: str) -> None: + self.role = role + self.content = content + + def to_dict(self,) -> Dict[str, str]: + return {'role': self.role, 'content': self.content} + + def __repr__(self) -> str: + return f"{{ 'role': {self.role}, 'content': {self.content} }}" class Conversation: @@ -41,6 +38,7 @@ class Conversation: Args: model: The model to use for inference. tokenizer: The tokenizer to use for inference. + system_prompt: The system prompt to use for the conversation. chat_format: The chat format to use for the conversation. generate_kwargs: The keyword arguments to pass to `model.generate`. stop_tokens: The tokens to stop generation on. @@ -58,14 +56,13 @@ class Conversation: def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, - chat_format: ChatFormatter, generate_kwargs: Dict[str, Any], + system_prompt: str, stop_tokens: Optional[List[str]] = None) -> None: if stop_tokens is None: stop_tokens = ['<|endoftext|>', '<|im_end|>'] self.model = model self.tokenizer = tokenizer - self.chat_format = chat_format stop_token_ids = self.tokenizer.convert_tokens_to_ids(stop_tokens) if len(stop_token_ids) != len(stop_tokens): @@ -94,6 +91,9 @@ def __call__(self, input_ids: torch.LongTensor, self.streamer, } self.history = [] + system_prompt_msg = ChatMessage('system', system_prompt) + self.history.append(system_prompt_msg) + self.cli_instructions = ( 'Enter your message below.\n- Hit return twice to send input to the model\n' + @@ -102,36 +102,35 @@ def __call__(self, input_ids: torch.LongTensor, "- Type 'quit' to end\n- Type 'system' to change the system prompt\n" ) + def _history_to_chat(self) -> List[Dict[str, str]]: + msg_history = [] + for chat_msg in self.history: + msg_history.append(chat_msg.to_dict()) + return msg_history + def _history_as_formatted_str(self) -> str: - text = self.chat_format.system + ''.join([ - '\n'.join([ - self.chat_format.user.format(item[0]), - self.chat_format.assistant.format(item[1]), - ]) for item in self.history[:-1] - ]) - text += self.chat_format.user.format(self.history[-1][0]) - text += self.chat_format.response_prefix - return text + chat_conversation = self._history_to_chat() + tokenized_chat = self.tokenizer.apply_chat_template(chat_conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt") + tokenized_chat = tokenized_chat.to(self.model.device) + return self.tokenizer.decode(tokenized_chat[0]) def turn(self, user_inp: str) -> None: - self.history.append([user_inp, '']) - conversation = self._history_as_formatted_str() - input_ids = self.tokenizer(conversation, return_tensors='pt').input_ids - input_ids = input_ids.to(self.model.device) + self.history.append(ChatMessage('user', user_inp)) + chat_conversation = self._history_to_chat() + tokenized_chat = self.tokenizer.apply_chat_template(chat_conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt") + tokenized_chat = tokenized_chat.to(self.model.device) # also stream to stdout maybe_synchronize() start = time.time() - print('Assistant:') - gkwargs = {**self.generate_kwargs, 'input_ids': input_ids} - # this will stream to stdout, but we need to keep track of the output_ids for saving history - output_ids = self.model.generate(**gkwargs) + print(f'Assistant {self.model.name_or_path}:') + output_ids = self.model.generate(tokenized_chat, **self.generate_kwargs) maybe_synchronize() end = time.time() - print(f'took {end - start:.2f} seconds') - new_tokens = output_ids[0, len(input_ids[0]):] + print(f'\nTook {end - start:.2f} seconds') + new_tokens = output_ids[0, len(tokenized_chat[0]):] assistant_response = self.tokenizer.decode(new_tokens, skip_special_tokens=True) - self.history[-1][-1] = assistant_response + self.history.append(ChatMessage('assistant', assistant_response)) def __call__(self) -> None: print(self.cli_instructions) @@ -147,7 +146,7 @@ def __call__(self) -> None: if user_inp.lower() == 'quit': break elif user_inp.lower() == 'clear': - self.history = [] + self.history = self.history[:1] # keep system prompt continue elif user_inp == 'history': print(f'history: {self.history}') @@ -158,8 +157,7 @@ def __call__(self) -> None: elif user_inp == 'system': print('Enter a new system prompt:') new_system = input() - sys = f'<|im_start|>system\n{new_system.strip()}.<|im_end|>\n' - self.chat_format.system = sys + self.history[0].content = new_system continue self.turn(user_inp) @@ -249,7 +247,7 @@ def parse_args() -> Namespace: parser.add_argument('--device_map', type=str, default=None) parser.add_argument('--attn_impl', type=str, default=None) parser.add_argument('--seed', type=int, default=42) - parser.add_argument('--system_prompt', type=str, default=None) + parser.add_argument('--system_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT) parser.add_argument('--user_msg_fmt', type=str, default=None) parser.add_argument('--assistant_msg_fmt', type=str, default=None) parser.add_argument( @@ -363,13 +361,9 @@ def main(args: Namespace) -> None: autocast_context = nullcontext() print('NOT using autocast...') - chat_format = ChatFormatter(system=args.system_prompt, - user=args.user_msg_fmt, - assistant=args.assistant_msg_fmt) - conversation = Conversation(model=model, tokenizer=tokenizer, - chat_format=chat_format, + system_prompt=args.system_prompt, generate_kwargs=generate_kwargs, stop_tokens=args.stop_tokens.split()) @@ -378,7 +372,7 @@ def main(args: Namespace) -> None: print('Warming up...') with autocast_context: conversation.turn('Write a welcome message to the user.') - conversation.history = [] + conversation.history = conversation.history[:1] # keep system prompt print('Starting conversation...') with autocast_context: From a20bee9095974485ca360e8f684695519bca3890 Mon Sep 17 00:00:00 2001 From: Chuck Tang Date: Mon, 25 Mar 2024 23:01:02 +0000 Subject: [PATCH 2/4] fmt --- scripts/inference/hf_chat.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index cebc9e15f5..027a23c7b6 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -12,10 +12,11 @@ PreTrainedModel, PreTrainedTokenizerBase, StoppingCriteria, StoppingCriteriaList, TextStreamer) +DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.' + -DEFAULT_SYSTEM_PROMPT = "You are a friendly chatbot who aims to be helpful and honest." class ChatMessage: - """ A class that holds a chat message. + """A class that contains a chat message. Please see ChatML format for more information: https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates @@ -110,14 +111,22 @@ def _history_to_chat(self) -> List[Dict[str, str]]: def _history_as_formatted_str(self) -> str: chat_conversation = self._history_to_chat() - tokenized_chat = self.tokenizer.apply_chat_template(chat_conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt") + tokenized_chat = self.tokenizer.apply_chat_template( + chat_conversation, + tokenize=True, + add_generation_prompt=True, + return_tensors='pt') tokenized_chat = tokenized_chat.to(self.model.device) return self.tokenizer.decode(tokenized_chat[0]) def turn(self, user_inp: str) -> None: self.history.append(ChatMessage('user', user_inp)) chat_conversation = self._history_to_chat() - tokenized_chat = self.tokenizer.apply_chat_template(chat_conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt") + tokenized_chat = self.tokenizer.apply_chat_template( + chat_conversation, + tokenize=True, + add_generation_prompt=True, + return_tensors='pt') tokenized_chat = tokenized_chat.to(self.model.device) # also stream to stdout maybe_synchronize() @@ -247,7 +256,9 @@ def parse_args() -> Namespace: parser.add_argument('--device_map', type=str, default=None) parser.add_argument('--attn_impl', type=str, default=None) parser.add_argument('--seed', type=int, default=42) - parser.add_argument('--system_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT) + parser.add_argument('--system_prompt', + type=str, + default=DEFAULT_SYSTEM_PROMPT) parser.add_argument('--user_msg_fmt', type=str, default=None) parser.add_argument('--assistant_msg_fmt', type=str, default=None) parser.add_argument( @@ -372,7 +383,8 @@ def main(args: Namespace) -> None: print('Warming up...') with autocast_context: conversation.turn('Write a welcome message to the user.') - conversation.history = conversation.history[:1] # keep system prompt + conversation.history = conversation.history[: + 1] # keep system prompt print('Starting conversation...') with autocast_context: From fc56b9a60fd3ae3a03cde4a0ea987586c9f671e5 Mon Sep 17 00:00:00 2001 From: Chuck Tang Date: Mon, 25 Mar 2024 23:17:41 +0000 Subject: [PATCH 3/4] fix --- scripts/inference/hf_chat.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index 027a23c7b6..141cfa4f23 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -47,9 +47,9 @@ class Conversation: Attributes: model: The model to use for inference. tokenizer: The tokenizer to use for inference. - chat_format: The chat format to use for the conversation. streamer: The streamer to use for inference. generate_kwargs: The keyword arguments to pass to `model.generate`. + system_prompt: The system prompt used in the conversation chat. history: The conversation history. cli_instructions: The instructions to display to the user. """ @@ -100,28 +100,26 @@ def __call__(self, input_ids: torch.LongTensor, + "- Type 'clear' to restart the conversation\n- Type 'history' to see the conversation\n" + - "- Type 'quit' to end\n- Type 'system' to change the system prompt\n" + "- Type 'history_fmt' to see the conversation\n- Type 'quit' to end\n- Type 'system' to change the system prompt\n" ) - def _history_to_chat(self) -> List[Dict[str, str]]: + def _history_to_chat_conversation(self) -> List[Dict[str, str]]: msg_history = [] for chat_msg in self.history: msg_history.append(chat_msg.to_dict()) return msg_history def _history_as_formatted_str(self) -> str: - chat_conversation = self._history_to_chat() - tokenized_chat = self.tokenizer.apply_chat_template( + chat_conversation = self._history_to_chat_conversation() + return self.tokenizer.apply_chat_template( chat_conversation, - tokenize=True, - add_generation_prompt=True, - return_tensors='pt') - tokenized_chat = tokenized_chat.to(self.model.device) - return self.tokenizer.decode(tokenized_chat[0]) + tokenize=False, + add_generation_prompt=False, + ) def turn(self, user_inp: str) -> None: self.history.append(ChatMessage('user', user_inp)) - chat_conversation = self._history_to_chat() + chat_conversation = self._history_to_chat_conversation() tokenized_chat = self.tokenizer.apply_chat_template( chat_conversation, tokenize=True, @@ -131,7 +129,7 @@ def turn(self, user_inp: str) -> None: # also stream to stdout maybe_synchronize() start = time.time() - print(f'Assistant {self.model.name_or_path}:') + print(f'Assistant:') output_ids = self.model.generate(tokenized_chat, **self.generate_kwargs) maybe_synchronize() end = time.time() From c25b41d75f949831e66283e6a5ea83f4b3608a4a Mon Sep 17 00:00:00 2001 From: Chuck Tang Date: Mon, 1 Apr 2024 10:29:08 -0700 Subject: [PATCH 4/4] commit change --- scripts/inference/hf_chat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index 141cfa4f23..ab89364e30 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -257,8 +257,6 @@ def parse_args() -> Namespace: parser.add_argument('--system_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT) - parser.add_argument('--user_msg_fmt', type=str, default=None) - parser.add_argument('--assistant_msg_fmt', type=str, default=None) parser.add_argument( '--stop_tokens', type=str,