From a546cf02df1d48899d6f055643cbb71979436bc5 Mon Sep 17 00:00:00 2001 From: Parker Smith Date: Sun, 3 Mar 2024 17:03:29 -0700 Subject: [PATCH] [MLC-38] server: Allow server to take custom instructions into account --- server/server.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/server/server.py b/server/server.py index 6b6a632..21bdf37 100644 --- a/server/server.py +++ b/server/server.py @@ -98,6 +98,30 @@ def format_messages(messages, context): return messages +def add_instructions(messages, instructions): + personalization = instructions.get('personalization', '') + response = instructions.get('response', '') + if len(personalization) > 0: + messages[-1]['content'] = f""" +Before answering the prompt, remember that user wanted you to know: + +{personalization} + +{messages[-1]['content']} +""".strip() + + if len(response) > 0: + messages[-1]['content'] = f""" +Before answering the prompt, remember that the user asked for you to respond in the following way: + +{response} + +{messages[-1]['content']} +""".strip() + + return messages + + class APIHandler(BaseHTTPRequestHandler): def _set_headers(self, status_code=200): self.send_response(status_code) @@ -177,6 +201,7 @@ def query(self, body): directory = body.get('directory', None) messages = body.get('messages', []) + instructions = body.get('instructions', None) if directory: # emperically better than `similarity_search` @@ -191,6 +216,7 @@ def query(self, body): print(('\n'+'--'*10+'\n').join([ f'{doc.metadata}\n{doc.page_content}' for doc in docs]), flush=True) + add_instructions(messages, instructions) print(messages, flush=True) prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template( messages,