Upgrade hf chat (#1061)
j316chuck authored Apr 2, 2024
1 parent 580a4b0 commit 394735b
Showing 1 changed file with 54 additions and 52 deletions.
scripts/inference/hf_chat.py
@@ -12,27 +12,25 @@
                           PreTrainedModel, PreTrainedTokenizerBase,
                           StoppingCriteria, StoppingCriteriaList, TextStreamer)

+DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'

-class ChatFormatter:
-    """A class for formatting the chat history.
-
-    Args:
-        system: The system prompt. If None, a default ChatML-formatted prompt is used.
-        user: The user prompt. If None, a default ChatML value is used.
-        assistant: The assistant prompt. If None, a default ChatML value is used.
+class ChatMessage:
+    """A class that contains a chat message.
-    Attributes:
-        system: The system prompt.
-        user: The user prompt.
-        assistant: The assistant prompt.
-        response_prefix: The response prefix (anything before {} in the assistant format string)
+    Please see ChatML format for more information:
+    https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates
     """

-    def __init__(self, system: str, user: str, assistant: str) -> None:
-        self.system = system if system else '<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n'
-        self.user = user if user else '<|im_start|>user\n{}<|im_end|>\n'
-        self.assistant = assistant if assistant else '<|im_start|>assistant\n{}<|im_end|>\n'
-        self.response_prefix = self.assistant.split('{}')[0]
+    def __init__(self, role: str, content: str) -> None:
+        self.role = role
+        self.content = content

+    def to_dict(self,) -> Dict[str, str]:
+        return {'role': self.role, 'content': self.content}

+    def __repr__(self) -> str:
+        return f"{{ 'role': {self.role}, 'content': {self.content} }}"
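
An illustrative usage sketch (not part of the diff): each ChatMessage serializes to the role/content dict schema that tokenizer.apply_chat_template expects.

    msg = ChatMessage('user', 'What is MPT?')
    msg.to_dict()  # {'role': 'user', 'content': 'What is MPT?'}
    print(msg)     # { 'role': user, 'content': What is MPT? }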


class Conversation:
@@ -41,31 +39,31 @@ class Conversation:
    Args:
        model: The model to use for inference.
        tokenizer: The tokenizer to use for inference.
+        system_prompt: The system prompt to use for the conversation.
-        chat_format: The chat format to use for the conversation.
        generate_kwargs: The keyword arguments to pass to `model.generate`.
        stop_tokens: The tokens to stop generation on.
    Attributes:
        model: The model to use for inference.
        tokenizer: The tokenizer to use for inference.
-        chat_format: The chat format to use for the conversation.
        streamer: The streamer to use for inference.
        generate_kwargs: The keyword arguments to pass to `model.generate`.
+        system_prompt: The system prompt used in the conversation chat.
        history: The conversation history.
        cli_instructions: The instructions to display to the user.
    """

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizerBase,
-                 chat_format: ChatFormatter,
                 generate_kwargs: Dict[str, Any],
+                 system_prompt: str,
                 stop_tokens: Optional[List[str]] = None) -> None:
        if stop_tokens is None:
            stop_tokens = ['<|endoftext|>', '<|im_end|>']
        self.model = model
        self.tokenizer = tokenizer
-        self.chat_format = chat_format

        stop_token_ids = self.tokenizer.convert_tokens_to_ids(stop_tokens)
        if len(stop_token_ids) != len(stop_tokens):
Expand Down Expand Up @@ -94,44 +92,52 @@ def __call__(self, input_ids: torch.LongTensor,
            self.streamer,
        }
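
The stopping-criteria definition itself is collapsed in this view. As a hedged sketch only (the class name and body are assumptions, not necessarily this repo's implementation), a token-based criterion built on the StoppingCriteria import above could look like:

    class StopOnTokens(StoppingCriteria):

        def __init__(self, stop_token_ids: List[int]) -> None:
            self.stop_token_ids = stop_token_ids

        def __call__(self, input_ids: torch.LongTensor,
                     scores: torch.FloatTensor, **kwargs: Any) -> bool:
            # Stop once the most recently generated token is a stop token.
            return int(input_ids[0, -1]) in self.stop_token_ids

Such a criterion would typically be handed to model.generate wrapped in StoppingCriteriaList([StopOnTokens(stop_token_ids)]).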
        self.history = []
+        system_prompt_msg = ChatMessage('system', system_prompt)
+        self.history.append(system_prompt_msg)

        self.cli_instructions = (
            'Enter your message below.\n- Hit return twice to send input to the model\n'
            +
            "- Type 'clear' to restart the conversation\n- Type 'history' to see the conversation\n"
            +
-            "- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
+            "- Type 'history_fmt' to see the conversation\n- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
        )

+    def _history_to_chat_conversation(self) -> List[Dict[str, str]]:
+        msg_history = []
+        for chat_msg in self.history:
+            msg_history.append(chat_msg.to_dict())
+        return msg_history

    def _history_as_formatted_str(self) -> str:
-        text = self.chat_format.system + ''.join([
-            '\n'.join([
-                self.chat_format.user.format(item[0]),
-                self.chat_format.assistant.format(item[1]),
-            ]) for item in self.history[:-1]
-        ])
-        text += self.chat_format.user.format(self.history[-1][0])
-        text += self.chat_format.response_prefix
-        return text
+        chat_conversation = self._history_to_chat_conversation()
+        return self.tokenizer.apply_chat_template(
+            chat_conversation,
+            tokenize=False,
+            add_generation_prompt=False,
+        )
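
For context (outside the diff), apply_chat_template with tokenize=False renders the message dicts through the tokenizer's own chat_template. A rough sketch, assuming a tokenizer that ships a ChatML-style template; the model name is only a placeholder:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('mosaicml/mpt-7b-8k-chat')  # placeholder
    text = tok.apply_chat_template(
        [{'role': 'system', 'content': 'You are helpful.'},
         {'role': 'user', 'content': 'Hi!'}],
        tokenize=False,
        add_generation_prompt=False,
    )
    # With a ChatML template, text is roughly:
    # '<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi!<|im_end|>\n'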

    def turn(self, user_inp: str) -> None:
-        self.history.append([user_inp, ''])
-        conversation = self._history_as_formatted_str()
-        input_ids = self.tokenizer(conversation, return_tensors='pt').input_ids
-        input_ids = input_ids.to(self.model.device)
+        self.history.append(ChatMessage('user', user_inp))
+        chat_conversation = self._history_to_chat_conversation()
+        tokenized_chat = self.tokenizer.apply_chat_template(
+            chat_conversation,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors='pt')
+        tokenized_chat = tokenized_chat.to(self.model.device)
        # also stream to stdout
        maybe_synchronize()
        start = time.time()
-        print('Assistant:')
-        gkwargs = {**self.generate_kwargs, 'input_ids': input_ids}
-        # this will stream to stdout, but we need to keep track of the output_ids for saving history
-        output_ids = self.model.generate(**gkwargs)
+        print(f'Assistant:')
+        output_ids = self.model.generate(tokenized_chat, **self.generate_kwargs)
        maybe_synchronize()
        end = time.time()
-        print(f'took {end - start:.2f} seconds')
-        new_tokens = output_ids[0, len(input_ids[0]):]
+        print(f'\nTook {end - start:.2f} seconds')
+        new_tokens = output_ids[0, len(tokenized_chat[0]):]
        assistant_response = self.tokenizer.decode(new_tokens,
                                                   skip_special_tokens=True)
-        self.history[-1][-1] = assistant_response
+        self.history.append(ChatMessage('assistant', assistant_response))
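
A note on the generation pattern above (illustrative, reusing tok from the earlier sketch and assuming a loaded model): add_generation_prompt=True appends the assistant header (e.g. '<|im_start|>assistant\n' under ChatML) so the model completes the assistant turn, and slicing at the prompt length keeps only the newly generated tokens.

    prompt_ids = tok.apply_chat_template(
        [{'role': 'user', 'content': 'Hi!'}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors='pt')
    out = model.generate(prompt_ids, max_new_tokens=64)
    reply = tok.decode(out[0, prompt_ids.shape[-1]:], skip_special_tokens=True)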

    def __call__(self) -> None:
        print(self.cli_instructions)
@@ -147,7 +153,7 @@ def __call__(self) -> None:
            if user_inp.lower() == 'quit':
                break
            elif user_inp.lower() == 'clear':
-                self.history = []
+                self.history = self.history[:1]  # keep system prompt
                continue
            elif user_inp == 'history':
                print(f'history: {self.history}')
@@ -158,8 +164,7 @@ def __call__(self) -> None:
            elif user_inp == 'system':
                print('Enter a new system prompt:')
                new_system = input()
-                sys = f'<|im_start|>system\n{new_system.strip()}.<|im_end|>\n'
-                self.chat_format.system = sys
+                self.history[0].content = new_system
                continue
            self.turn(user_inp)

@@ -249,9 +254,9 @@ def parse_args() -> Namespace:
    parser.add_argument('--device_map', type=str, default=None)
    parser.add_argument('--attn_impl', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)
-    parser.add_argument('--system_prompt', type=str, default=None)
-    parser.add_argument('--user_msg_fmt', type=str, default=None)
-    parser.add_argument('--assistant_msg_fmt', type=str, default=None)
+    parser.add_argument('--system_prompt',
+                        type=str,
+                        default=DEFAULT_SYSTEM_PROMPT)
    parser.add_argument(
        '--stop_tokens',
        type=str,
@@ -363,13 +368,9 @@ def main(args: Namespace) -> None:
        autocast_context = nullcontext()
        print('NOT using autocast...')

-    chat_format = ChatFormatter(system=args.system_prompt,
-                                user=args.user_msg_fmt,
-                                assistant=args.assistant_msg_fmt)
-
    conversation = Conversation(model=model,
                                tokenizer=tokenizer,
-                                chat_format=chat_format,
+                                system_prompt=args.system_prompt,
                                generate_kwargs=generate_kwargs,
                                stop_tokens=args.stop_tokens.split())

@@ -378,7 +379,8 @@
    print('Warming up...')
    with autocast_context:
        conversation.turn('Write a welcome message to the user.')
-    conversation.history = []
+    conversation.history = conversation.history[:
+                                                1]  # keep system prompt
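
For clarity (not in the diff): after the warm-up turn the history holds three messages, and the slice keeps only the system message so the warm-up exchange does not leak into the real conversation.

    # history == [system, user(warm-up prompt), assistant(warm-up reply)]
    conversation.history = conversation.history[:1]  # -> [system]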

    print('Starting conversation...')
    with autocast_context:
