From 4a3fcf92eed303c2577f0d51c08d8e887e48c093 Mon Sep 17 00:00:00 2001 From: Parker Smith Date: Sun, 3 Mar 2024 17:03:29 -0700 Subject: [PATCH] [MLC-38] server: Allow server to take custom instructions into account --- server/server.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/server/server.py b/server/server.py index 6b6a632..2868456 100644 --- a/server/server.py +++ b/server/server.py @@ -98,6 +98,28 @@ def format_messages(messages, context): return messages +def add_instructions(messages, instructions): + if len(instructions.get('personalization', '')) > 0: + messages[-1]['content'] = f""" +Before answering the prompt, remember that user wanted you to know: + +{instructions.get('personalization', '')} + +{messages[-1]['content']} +""".strip() + + if len(instructions.get('response', '')) > 0: + messages[-1]['content'] = f""" +Before answering the prompt, remember that the user asked for you to respond in the following way: + +{instructions.get('response', '')} + +{messages[-1]['content']} +""".strip() + + return messages + + class APIHandler(BaseHTTPRequestHandler): def _set_headers(self, status_code=200): self.send_response(status_code) @@ -177,6 +199,7 @@ def query(self, body): directory = body.get('directory', None) messages = body.get('messages', []) + instructions = body.get('instructions', None) if directory: # emperically better than `similarity_search` @@ -191,6 +214,7 @@ def query(self, body): print(('\n'+'--'*10+'\n').join([ f'{doc.metadata}\n{doc.page_content}' for doc in docs]), flush=True) + add_instructions(messages, instructions) print(messages, flush=True) prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template( messages,