diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py
index 4f938f999e..ab89364e30 100644
--- a/scripts/inference/hf_chat.py
+++ b/scripts/inference/hf_chat.py
@@ -12,27 +12,25 @@
                            PreTrainedModel, PreTrainedTokenizerBase,
                            StoppingCriteria, StoppingCriteriaList, TextStreamer)
 
+DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
+
 
-class ChatFormatter:
-    """A class for formatting the chat history.
-
-    Args:
-        system: The system prompt. If None, a default ChatML-formatted prompt is used.
-        user: The user prompt. If None, a default ChatML value is used.
-        assistant: The assistant prompt. If None, a default ChatML value is used.
+class ChatMessage:
+    """A class that contains a chat message.
 
-    Attributes:
-        system: The system prompt.
-        user: The user prompt.
-        assistant: The assistant prompt.
-        response_prefix: The response prefix (anything before {} in the assistant format string)
+    See the chat templating docs for more information on the ChatML format:
+    https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates
     """
 
-    def __init__(self, system: str, user: str, assistant: str) -> None:
-        self.system = system if system else '<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n'
-        self.user = user if user else '<|im_start|>user\n{}<|im_end|>\n'
-        self.assistant = assistant if assistant else '<|im_start|>assistant\n{}<|im_end|>\n'
-        self.response_prefix = self.assistant.split('{}')[0]
+    def __init__(self, role: str, content: str) -> None:
+        self.role = role
+        self.content = content
+
+    def to_dict(self) -> Dict[str, str]:
+        return {'role': self.role, 'content': self.content}
+
+    def __repr__(self) -> str:
+        return f"{{ 'role': {self.role}, 'content': {self.content} }}"
 
 
 class Conversation:
@@ -41,6 +39,6 @@ class Conversation:
     Args:
        model: The model to use for inference.
         tokenizer: The tokenizer to use for inference.
-        chat_format: The chat format to use for the conversation.
+        system_prompt: The system prompt to use for the conversation.
         generate_kwargs: The keyword arguments to pass to `model.generate`.
         stop_tokens: The tokens to stop generation on.
@@ -48,9 +46,9 @@ class Conversation:
     Attributes:
         model: The model to use for inference.
         tokenizer: The tokenizer to use for inference.
-        chat_format: The chat format to use for the conversation.
         streamer: The streamer to use for inference.
         generate_kwargs: The keyword arguments to pass to `model.generate`.
+        system_prompt: The system prompt used in the conversation.
         history: The conversation history.
         cli_instructions: The instructions to display to the user.
     """
@@ -58,14 +56,13 @@ class Conversation:
     def __init__(self,
                  model: PreTrainedModel,
                  tokenizer: PreTrainedTokenizerBase,
-                 chat_format: ChatFormatter,
                  generate_kwargs: Dict[str, Any],
+                 system_prompt: str,
                  stop_tokens: Optional[List[str]] = None) -> None:
         if stop_tokens is None:
             stop_tokens = ['<|endoftext|>', '<|im_end|>']
         self.model = model
         self.tokenizer = tokenizer
-        self.chat_format = chat_format
 
         stop_token_ids = self.tokenizer.convert_tokens_to_ids(stop_tokens)
         if len(stop_token_ids) != len(stop_tokens):
@@ -94,44 +91,52 @@ def __call__(self, input_ids: torch.LongTensor,
             self.streamer,
         }
         self.history = []
+        system_prompt_msg = ChatMessage('system', system_prompt)
+        self.history.append(system_prompt_msg)
+
         self.cli_instructions = (
             'Enter your message below.\n- Hit return twice to send input to the model\n'
             +
             "- Type 'clear' to restart the conversation\n- Type 'history' to see the conversation\n"
             +
-            "- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
+            "- Type 'history_fmt' to see the formatted conversation\n- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
         )
 
+    def _history_to_chat_conversation(self) -> List[Dict[str, str]]:
+        msg_history = []
+        for chat_msg in self.history:
+            msg_history.append(chat_msg.to_dict())
+        return msg_history
+
     def _history_as_formatted_str(self) -> str:
-        text = self.chat_format.system + ''.join([
-            '\n'.join([
-                self.chat_format.user.format(item[0]),
-                self.chat_format.assistant.format(item[1]),
-            ]) for item in self.history[:-1]
-        ])
-        text += self.chat_format.user.format(self.history[-1][0])
-        text += self.chat_format.response_prefix
-        return text
+        chat_conversation = self._history_to_chat_conversation()
+        return self.tokenizer.apply_chat_template(
+            chat_conversation,
+            tokenize=False,
+            add_generation_prompt=False,
+        )
 
     def turn(self, user_inp: str) -> None:
-        self.history.append([user_inp, ''])
-        conversation = self._history_as_formatted_str()
-        input_ids = self.tokenizer(conversation, return_tensors='pt').input_ids
-        input_ids = input_ids.to(self.model.device)
+        self.history.append(ChatMessage('user', user_inp))
+        chat_conversation = self._history_to_chat_conversation()
+        tokenized_chat = self.tokenizer.apply_chat_template(
+            chat_conversation,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors='pt')
+        tokenized_chat = tokenized_chat.to(self.model.device)
         # also stream to stdout
         maybe_synchronize()
         start = time.time()
         print('Assistant:')
-        gkwargs = {**self.generate_kwargs, 'input_ids': input_ids}
-        # this will stream to stdout, but we need to keep track of the output_ids for saving history
-        output_ids = self.model.generate(**gkwargs)
+        output_ids = self.model.generate(tokenized_chat, **self.generate_kwargs)
         maybe_synchronize()
         end = time.time()
-        print(f'took {end - start:.2f} seconds')
-        new_tokens = output_ids[0, len(input_ids[0]):]
+        print(f'\nTook {end - start:.2f} seconds')
+        new_tokens = output_ids[0, len(tokenized_chat[0]):]
         assistant_response = self.tokenizer.decode(new_tokens,
                                                    skip_special_tokens=True)
-        self.history[-1][-1] = assistant_response
+        self.history.append(ChatMessage('assistant', assistant_response))
 
     def __call__(self) -> None:
         print(self.cli_instructions)
@@ -147,7 +152,7 @@ def __call__(self) -> None:
             if user_inp.lower() == 'quit':
                 break
             elif user_inp.lower() == 'clear':
-                self.history = []
+                self.history = self.history[:1]  # keep system prompt
                 continue
             elif user_inp == 'history':
                 print(f'history: {self.history}')
@@ -158,8 +163,7 @@ def __call__(self) -> None:
             elif user_inp == 'system':
                 print('Enter a new system prompt:')
                 new_system = input()
-                sys = f'<|im_start|>system\n{new_system.strip()}.<|im_end|>\n'
-                self.chat_format.system = sys
+                self.history[0].content = new_system
                 continue
             self.turn(user_inp)
 
@@ -249,9 +253,9 @@ def parse_args() -> Namespace:
     parser.add_argument('--device_map', type=str, default=None)
     parser.add_argument('--attn_impl', type=str, default=None)
     parser.add_argument('--seed', type=int, default=42)
-    parser.add_argument('--system_prompt', type=str, default=None)
-    parser.add_argument('--user_msg_fmt', type=str, default=None)
-    parser.add_argument('--assistant_msg_fmt', type=str, default=None)
+    parser.add_argument('--system_prompt',
+                        type=str,
+                        default=DEFAULT_SYSTEM_PROMPT)
     parser.add_argument(
         '--stop_tokens',
         type=str,
@@ -363,13 +367,9 @@ def main(args: Namespace) -> None:
         autocast_context = nullcontext()
         print('NOT using autocast...')
 
-    chat_format = ChatFormatter(system=args.system_prompt,
-                                user=args.user_msg_fmt,
-                                assistant=args.assistant_msg_fmt)
-
     conversation = Conversation(model=model,
                                 tokenizer=tokenizer,
-                                chat_format=chat_format,
+                                system_prompt=args.system_prompt,
                                 generate_kwargs=generate_kwargs,
                                 stop_tokens=args.stop_tokens.split())
 
@@ -378,7 +378,7 @@ def main(args: Namespace) -> None:
     print('Warming up...')
    with autocast_context:
         conversation.turn('Write a welcome message to the user.')
-        conversation.history = []
+        conversation.history = conversation.history[:1]  # keep system prompt
 
     print('Starting conversation...')
     with autocast_context:
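
For reviewers unfamiliar with tokenizer-side chat templating, here is a minimal, self-contained sketch of the flow this patch adopts. It is not code from the PR: the checkpoint name is a placeholder (any chat-tuned model whose tokenizer bundles a chat template behaves the same), and `max_new_tokens` stands in for the script's `generate_kwargs`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint: assumes its tokenizer_config.json ships a chat template.
name = 'HuggingFaceH4/zephyr-7b-beta'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

history = [
    {'role': 'system', 'content': 'You are a friendly chatbot who aims to be helpful and honest.'},
    {'role': 'user', 'content': 'Say hello.'},
]

# What `history_fmt` would print: the history rendered as one prompt string.
print(tokenizer.apply_chat_template(history, tokenize=False))

# What `turn` feeds to `generate`: token ids, with the assistant header
# appended (add_generation_prompt=True) so the model answers as the assistant.
input_ids = tokenizer.apply_chat_template(history,
                                          tokenize=True,
                                          add_generation_prompt=True,
                                          return_tensors='pt').to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=64)

# Only the tokens past the prompt are new; appending them to the history as
# an assistant message keeps the next turn's template input complete.
reply = tokenizer.decode(output_ids[0, input_ids.shape[1]:],
                         skip_special_tokens=True)
history.append({'role': 'assistant', 'content': reply})
```

Making the tokenizer the single source of truth for formatting is what lets the diff delete `ChatFormatter` outright rather than porting its hard-coded ChatML strings.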
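
A note on the CLI surface: `--user_msg_fmt` and `--assistant_msg_fmt` are removed with no replacement by design, since per-role formatting now comes from the tokenizer's bundled template; only the system prompt stays configurable. Assuming the script's existing model flag is `--name_or_path` (carried over from the repo's other inference scripts, not visible in this diff), a session would start like:

```
python scripts/inference/hf_chat.py \
    --name_or_path mosaicml/mpt-7b-chat \
    --system_prompt 'You are a terse assistant who answers in one sentence.'
```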