diff --git a/server/server.py b/server/server.py index 6b6a632..1dd5e6a 100644 --- a/server/server.py +++ b/server/server.py @@ -98,6 +98,31 @@ def format_messages(messages, context): return messages +def add_instructions(messages, instructions): + personalization = instructions.get('personalization', '') + response = instructions.get('response', '') + if (messages[-2].role != 'system'): + # Make send to last message system message + messages[-1]['role'] = 'system' + if len(personalization) > 0: + messages[-1]['content'] = f""" +You are an assistant who knows the following about me: +{personalization} + +{messages[-1]['content']} +""".strip() + + if len(response) > 0: + messages[-1]['content'] = f""" +You are an assistant who responds based on the following specifications: +{response} + +{messages[-1]['content']} +""".strip() + + return messages + + class APIHandler(BaseHTTPRequestHandler): def _set_headers(self, status_code=200): self.send_response(status_code) @@ -177,6 +202,7 @@ def query(self, body): directory = body.get('directory', None) messages = body.get('messages', []) + instructions = body.get('instructions', None) if directory: # emperically better than `similarity_search` @@ -191,6 +217,7 @@ def query(self, body): print(('\n'+'--'*10+'\n').join([ f'{doc.metadata}\n{doc.page_content}' for doc in docs]), flush=True) + add_instructions(messages, instructions) print(messages, flush=True) prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template( messages,