Upgrade hf chat #1061

Merged · 10 commits · Apr 2, 2024
scripts/inference/hf_chat.py: 106 changes (54 additions & 52 deletions)
@@ -12,27 +12,25 @@
PreTrainedModel, PreTrainedTokenizerBase,
StoppingCriteria, StoppingCriteriaList, TextStreamer)

DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'

class ChatFormatter:
"""A class for formatting the chat history.

Args:
system: The system prompt. If None, a default ChatML-formatted prompt is used.
user: The user prompt. If None, a default ChatML value is used.
assistant: The assistant prompt. If None, a default ChatML value is used.
class ChatMessage:
"""A class that contains a chat message.

Attributes:
system: The system prompt.
user: The user prompt.
assistant: The assistant prompt.
response_prefix: The response prefix (anything before {} in the assistant format string)
Please see ChatML format for more information:
https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates
"""

def __init__(self, system: str, user: str, assistant: str) -> None:
self.system = system if system else '<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n'
self.user = user if user else '<|im_start|>user\n{}<|im_end|>\n'
self.assistant = assistant if assistant else '<|im_start|>assistant\n{}<|im_end|>\n'
self.response_prefix = self.assistant.split('{}')[0]
def __init__(self, role: str, content: str) -> None:
self.role = role
self.content = content

def to_dict(self,) -> Dict[str, str]:
return {'role': self.role, 'content': self.content}

def __repr__(self) -> str:
return f"{{ 'role': {self.role}, 'content': {self.content} }}"


class Conversation:
@@ -41,31 +39,31 @@ class Conversation:
Args:
model: The model to use for inference.
tokenizer: The tokenizer to use for inference.
system_prompt: The system prompt to use for the conversation.
chat_format: The chat format to use for the conversation.
generate_kwargs: The keyword arguments to pass to `model.generate`.
stop_tokens: The tokens to stop generation on.

Attributes:
model: The model to use for inference.
tokenizer: The tokenizer to use for inference.
chat_format: The chat format to use for the conversation.
streamer: The streamer to use for inference.
generate_kwargs: The keyword arguments to pass to `model.generate`.
system_prompt: The system prompt used in the conversation chat.
history: The conversation history.
cli_instructions: The instructions to display to the user.
"""

def __init__(self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
chat_format: ChatFormatter,
generate_kwargs: Dict[str, Any],
system_prompt: str,
stop_tokens: Optional[List[str]] = None) -> None:
if stop_tokens is None:
stop_tokens = ['<|endoftext|>', '<|im_end|>']
self.model = model
self.tokenizer = tokenizer
self.chat_format = chat_format

stop_token_ids = self.tokenizer.convert_tokens_to_ids(stop_tokens)
if len(stop_token_ids) != len(stop_tokens):
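
The elided lines below this hunk wire the stop-token ids into a `StoppingCriteria` (note the `def __call__(self, input_ids: torch.LongTensor, ...)` context in the next hunk header). A hedged sketch of the usual pattern, assuming the PR's implementation is similar in spirit:

```python
# Sketch of a typical stop-token criterion; the PR's elided
# implementation may differ in detail.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):

    def __init__(self, stop_token_ids: list):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the last generated token is any stop id.
        return input_ids[0, -1].item() in self.stop_token_ids

criteria = StoppingCriteriaList([StopOnTokens([50256])])  # example id
```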
@@ -94,44 +92,52 @@ def __call__(self, input_ids: torch.LongTensor,
self.streamer,
}
self.history = []
system_prompt_msg = ChatMessage('system', system_prompt)
self.history.append(system_prompt_msg)

self.cli_instructions = (
'Enter your message below.\n- Hit return twice to send input to the model\n'
+
"- Type 'clear' to restart the conversation\n- Type 'history' to see the conversation\n"
+
"- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
"- Type 'history_fmt' to see the conversation\n- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
)

def _history_to_chat_conversation(self) -> List[Dict[str, str]]:
msg_history = []
for chat_msg in self.history:
msg_history.append(chat_msg.to_dict())
return msg_history

def _history_as_formatted_str(self) -> str:
text = self.chat_format.system + ''.join([
'\n'.join([
self.chat_format.user.format(item[0]),
self.chat_format.assistant.format(item[1]),
]) for item in self.history[:-1]
])
text += self.chat_format.user.format(self.history[-1][0])
text += self.chat_format.response_prefix
return text
chat_conversation = self._history_to_chat_conversation()
return self.tokenizer.apply_chat_template(
chat_conversation,
tokenize=False,
add_generation_prompt=False,
)
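
For reference, `apply_chat_template(..., tokenize=False)` renders the message list through the tokenizer's built-in chat template, so the script no longer needs hand-rolled format strings. A minimal sketch, assuming a tokenizer whose template is ChatML (`'your-chat-model'` is a placeholder, not from the PR):

```python
# Sketch: rendering a message list as text via the tokenizer's built-in
# chat template. 'your-chat-model' stands in for any checkpoint whose
# tokenizer defines a chat_template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('your-chat-model')
text = tokenizer.apply_chat_template(
    [{'role': 'system', 'content': 'Be helpful.'},
     {'role': 'user', 'content': 'Hi!'}],
    tokenize=False,
    add_generation_prompt=False,
)
# With a ChatML template, text looks roughly like:
# <|im_start|>system\nBe helpful.<|im_end|>\n<|im_start|>user\nHi!<|im_end|>\n
```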

def turn(self, user_inp: str) -> None:
self.history.append([user_inp, ''])
conversation = self._history_as_formatted_str()
input_ids = self.tokenizer(conversation, return_tensors='pt').input_ids
input_ids = input_ids.to(self.model.device)
self.history.append(ChatMessage('user', user_inp))
chat_conversation = self._history_to_chat_conversation()
tokenized_chat = self.tokenizer.apply_chat_template(
chat_conversation,
tokenize=True,
add_generation_prompt=True,
return_tensors='pt')
tokenized_chat = tokenized_chat.to(self.model.device)
# also stream to stdout
maybe_synchronize()
start = time.time()
print('Assistant:')
gkwargs = {**self.generate_kwargs, 'input_ids': input_ids}
# this will stream to stdout, but we need to keep track of the output_ids for saving history
output_ids = self.model.generate(**gkwargs)
print(f'Assistant:')
output_ids = self.model.generate(tokenized_chat, **self.generate_kwargs)
maybe_synchronize()
end = time.time()
print(f'took {end - start:.2f} seconds')
new_tokens = output_ids[0, len(input_ids[0]):]
print(f'\nTook {end - start:.2f} seconds')
new_tokens = output_ids[0, len(tokenized_chat[0]):]
assistant_response = self.tokenizer.decode(new_tokens,
skip_special_tokens=True)
self.history[-1][-1] = assistant_response
self.history.append(ChatMessage('assistant', assistant_response))
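
The reworked `turn` relies on `add_generation_prompt=True` to append the assistant header (e.g. `<|im_start|>assistant\n` under ChatML) so the model continues as the assistant, then recovers the reply by slicing off the prompt tokens. A standalone sketch of that pattern, assuming `model` and `tokenizer` are loaded as elsewhere in this script:

```python
# Sketch of the generate-and-slice pattern used in turn().
import torch

tokenized_chat = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'Say hello.'}],
    tokenize=True,
    add_generation_prompt=True,  # appends the assistant header
    return_tensors='pt',
).to(model.device)
with torch.no_grad():
    output_ids = model.generate(tokenized_chat, max_new_tokens=32)
# Everything past the prompt length is the newly generated reply.
new_tokens = output_ids[0, tokenized_chat.shape[1]:]
reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
```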

def __call__(self) -> None:
print(self.cli_instructions)
@@ -147,7 +153,7 @@ def __call__(self) -> None:
if user_inp.lower() == 'quit':
break
elif user_inp.lower() == 'clear':
self.history = []
self.history = self.history[:1] # keep system prompt
continue
elif user_inp == 'history':
print(f'history: {self.history}')
@@ -158,8 +164,7 @@ def __call__(self) -> None:
elif user_inp == 'system':
print('Enter a new system prompt:')
new_system = input()
sys = f'<|im_start|>system\n{new_system.strip()}.<|im_end|>\n'
self.chat_format.system = sys
self.history[0].content = new_system
continue
self.turn(user_inp)
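
One behavioral note on this hunk: 'clear' now resets to `history[:1]` rather than `[]`, so the seeded system message survives the reset. A tiny illustration (not part of the diff):

```python
# Why 'clear' keeps history[:1]: index 0 holds the system ChatMessage.
history = [
    ChatMessage('system', 'Be helpful.'),
    ChatMessage('user', 'Hi'),
    ChatMessage('assistant', 'Hello!'),
]
history = history[:1]  # new behavior: system prompt survives
# history = []         # old behavior: system prompt was lost
```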

@@ -249,9 +254,9 @@ def parse_args() -> Namespace:
parser.add_argument('--device_map', type=str, default=None)
parser.add_argument('--attn_impl', type=str, default=None)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--system_prompt', type=str, default=None)
parser.add_argument('--user_msg_fmt', type=str, default=None)
parser.add_argument('--assistant_msg_fmt', type=str, default=None)
parser.add_argument('--system_prompt',
type=str,
default=DEFAULT_SYSTEM_PROMPT)
parser.add_argument(
'--stop_tokens',
type=str,
@@ -363,13 +368,9 @@ def main(args: Namespace) -> None:
autocast_context = nullcontext()
print('NOT using autocast...')

chat_format = ChatFormatter(system=args.system_prompt,
user=args.user_msg_fmt,
assistant=args.assistant_msg_fmt)

conversation = Conversation(model=model,
tokenizer=tokenizer,
chat_format=chat_format,
system_prompt=args.system_prompt,
generate_kwargs=generate_kwargs,
stop_tokens=args.stop_tokens.split())

@@ -378,7 +379,8 @@ def main(args: Namespace) -> None:
print('Warming up...')
with autocast_context:
conversation.turn('Write a welcome message to the user.')
conversation.history = []
conversation.history = conversation.history[:
1] # keep system prompt

print('Starting conversation...')
with autocast_context: