def add_instructions(messages, instructions):
    """Prepend user-supplied instructions to the last message's content.

    Args:
        messages: List of chat message dicts; each is expected to carry a
            'content' key. Only the last entry is modified, in place.
        instructions: Mapping that may contain the keys 'personalization'
            and 'response', or None when the request carried no
            instructions.

    Returns:
        The same ``messages`` list object that was passed in.

    Notes:
        The 'response' preamble is applied after the 'personalization' one,
        so when both are present it ends up first in the final content,
        followed by the personalization preamble and then the original
        prompt. The combined text is stripped of leading and trailing
        whitespace each time a preamble is added.
    """
    # The caller forwards `body.get('instructions', None)` and
    # `body.get('messages', [])` unconditionally, so either may be
    # missing/empty; in that case there is nothing to do.
    if not instructions or not messages:
        return messages

    # Order matters: each preamble wraps whatever content is already there.
    preambles = (
        (
            'personalization',
            'Before answering the prompt, remember that user wanted you '
            'to know:',
        ),
        (
            'response',
            'Before answering the prompt, remember that the user asked for '
            'you to respond in the following way:',
        ),
    )

    last_message = messages[-1]
    for key, preamble in preambles:
        # `or ''` also covers an explicit JSON null for the key.
        text = instructions.get(key) or ''
        if text:
            last_message['content'] = (
                f"{preamble}\n\n{text}\n\n{last_message['content']}"
            ).strip()

    return messages