-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create http endpoint for chat for testing
- Loading branch information
Showing
6 changed files
with
463 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
|
||
import json | ||
import logging | ||
import os | ||
from http_event_config import HTTPEventConfig | ||
from helpers.http_response import HTTPResponse | ||
from honeybadger import honeybadger | ||
|
||
# Report unhandled errors to Honeybadger and mirror its log output to stderr
# so it shows up in CloudWatch.
honeybadger.configure()
logging.getLogger('honeybadger').addHandler(logging.StreamHandler())

# Whitelists of keys included in the JSON body for each response shape.
# "base" is the normal client response; "debug" is the superuser debug view;
# "log"/"error" mirror what gets logged / returned on failure.
RESPONSE_TYPES = {
    "base": ["answer", "ref"],
    "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
    "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"],
    "error": ["question", "error", "source_documents"]
}
|
||
def handler(event, context):
    """Synchronous HTTP chat endpoint (AWS Lambda entry point).

    Validates auth and the question, runs the retrieval/LLM chain, and
    returns an API-Gateway-style dict with statusCode and (usually) body.
    """
    print(f'Event: {event}')

    config = HTTPEventConfig(event)

    # Guard clauses: reject unauthenticated or empty requests up front.
    if not config.is_logged_in:
        return {"statusCode": 401, "body": "Unauthorized"}

    if config.question in (None, ""):
        return {"statusCode": 400, "body": "Question cannot be blank"}

    # Escape hatch for tests: skip the vector-store/LLM round trip entirely.
    if os.getenv("SKIP_WEAVIATE_SETUP"):
        return {"statusCode": 200}

    config.setup_llm_request()
    final_response = HTTPResponse(config).prepare_response()

    if "error" in final_response:
        logging.error(f'Error: {final_response["error"]}')
        return {"statusCode": 500, "body": "Internal Server Error"}

    # Superusers with debug enabled get the expanded key set.
    shape = 'debug' if config.debug_mode else 'base'
    return {"statusCode": 200, "body": json.dumps(reshape_response(final_response, shape))}
|
||
def reshape_response(response, type):
    """Project *response* down to the whitelisted keys for *type*.

    *type* is one of the RESPONSE_TYPES keys ("base", "debug", "log",
    "error"). Keys listed in the whitelist but missing from *response* are
    skipped rather than raising KeyError, so a partially populated response
    still serializes instead of crashing the handler outside its error path.
    (NOTE: the parameter name shadows the ``type`` builtin; kept as-is for
    interface compatibility.)
    """
    return {k: response[k] for k in RESPONSE_TYPES[type] if k in response}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from helpers.metrics import debug_response | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough | ||
|
||
def extract_prompt_value(value):
    """Normalize a metadata value for inclusion in source documents.

    Dicts carrying a 'label' key collapse to a one-element list of that
    label; lists are normalized element-wise (recursively); anything else
    passes through unchanged.
    """
    if isinstance(value, dict) and 'label' in value:
        return [value['label']]
    if isinstance(value, list):
        return [extract_prompt_value(entry) for entry in value]
    return value
|
||
class HTTPResponse:
    """Builds and runs the retrieval-augmented chat chain for one request.

    Wraps an HTTPEventConfig (which supplies the question, retriever and
    LLM client) and produces either the chain's response payload or an
    error payload — see prepare_response.
    """

    def __init__(self, config):
        # config: an HTTPEventConfig; setup_llm_request() must have been
        # called on it before prepare_response() (it provides
        # config.opensearch and config.client).
        self.config = config
        self.store = {}  # NOTE(review): appears unused in this class — confirm before removing

    def debug_response_passthrough(self):
        # Final chain step: folds config, the LLM's string output and the
        # captured question/source documents into the response dict.
        # Relies on self.original_question being set earlier in the chain
        # by original_question_passthrough().
        return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question))

    def original_question_passthrough(self):
        def get_and_send_original_question(docs):
            # Trim each retrieved document's metadata to the configured
            # attributes (normalizing values via extract_prompt_value) and
            # collect them, plus the page content, as source documents.
            source_documents = []
            for doc in docs["context"]:
                doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata}
                source_document = doc.metadata.copy()
                source_document["content"] = doc.page_content
                source_documents.append(source_document)

            original_question = {
                "question": self.config.question,
                "source_documents": source_documents,
            }

            # Side effect only: stash for debug_response_passthrough(); the
            # docs themselves flow through the chain unchanged.
            self.original_question = original_question
            return docs

        return RunnablePassthrough(get_and_send_original_question)

    def prepare_response(self):
        """Run the chain for config.question.

        Returns the chain's response on success, or a dict with "question",
        "error" and empty "source_documents" on any failure — callers
        distinguish the two by checking for the "error" key.
        """
        try:
            retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}})
            chain = (
                {"context": retriever, "question": RunnablePassthrough()}
                | self.original_question_passthrough()
                | self.config.prompt
                | self.config.client
                | StrOutputParser()
                | self.debug_response_passthrough()
            )
            response = chain.invoke(self.config.question)
        except Exception as err:
            # Broad catch is deliberate: the HTTP handler maps any failure
            # here to a 500 response rather than crashing the Lambda.
            response = {
                "question": self.config.question,
                "error": str(err),
                "source_documents": [],
            }
        return response
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
import os | ||
import json | ||
|
||
from dataclasses import dataclass, field | ||
|
||
from langchain_core.prompts import ChatPromptTemplate | ||
from setup import ( | ||
opensearch_client, | ||
opensearch_vector_store, | ||
openai_chat_client, | ||
) | ||
from typing import List | ||
from helpers.apitoken import ApiToken | ||
from helpers.prompts import document_template, prompt_template | ||
|
||
# Defaults for the retrieval / LLM request; most can be overridden per
# request (superuser-only) — see HTTPEventConfig.
CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
K_VALUE = 40            # default number of documents to retrieve
MAX_K = 100             # hard ceiling on a payload-supplied "k"
MAX_TOKENS = 1000       # hard ceiling on completion tokens
SIZE = 5
TEMPERATURE = 0.2
TEXT_KEY = "id"
VERSION = "2024-02-01"  # default Azure OpenAI API version
|
||
@dataclass
class HTTPEventConfig:
    """
    The EventConfig class represents the configuration for an event.
    Default values are set for the following properties which can be overridden
    in the payload message. Most overrides are honored only when the request's
    API token is a superuser token (see _get_payload_value_with_superuser_check).
    """

    # Document metadata fields copied onto each source document unless a
    # superuser payload overrides "attributes".
    DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection",
                          "contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier",
                          "keywords", "language", "notes", "physical_description_material", "physical_description_size",
                          "provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail",
                          "title", "visibility", "work_type"]

    api_token: ApiToken = field(init=False)
    attributes: List[str] = field(init=False)
    azure_endpoint: str = field(init=False)
    azure_resource_name: str = field(init=False)
    debug_mode: bool = field(init=False)
    deployment_name: str = field(init=False)
    document_prompt: ChatPromptTemplate = field(init=False)
    event: dict = field(default_factory=dict)
    is_logged_in: bool = field(init=False)
    k: int = field(init=False)
    max_tokens: int = field(init=False)
    openai_api_version: str = field(init=False)
    payload: dict = field(default_factory=dict)
    prompt_text: str = field(init=False)
    prompt: ChatPromptTemplate = field(init=False)
    question: str = field(init=False)
    ref: str = field(init=False)
    request_context: dict = field(init=False)
    temperature: float = field(init=False)
    size: int = field(init=False)
    stream_response: bool = field(init=False)
    text_key: str = field(init=False)

    def __post_init__(self):
        # The Lambda event body is a JSON string; default to {} when absent.
        self.payload = json.loads(self.event.get("body", "{}"))
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        self.attributes = self._get_attributes()
        self.azure_endpoint = self._get_azure_endpoint()
        self.azure_resource_name = self._get_azure_resource_name()
        self.debug_mode = self._is_debug_mode_enabled()
        self.deployment_name = self._get_deployment_name()
        self.is_logged_in = self.api_token.is_logged_in()
        self.k = self._get_k()
        # Clamp to MAX_TOKENS regardless of who asks.
        self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
        self.openai_api_version = self._get_openai_api_version()
        self.prompt_text = self._get_prompt_text()
        self.request_context = self.event.get("requestContext", {})
        self.question = self.payload.get("question")
        self.ref = self.payload.get("ref")
        self.size = self._get_size()
        # Default: stream unless debug mode is on; any caller may override.
        self.stream_response = self.payload.get("stream_response", not self.debug_mode)
        self.temperature = self._get_temperature()
        self.text_key = self._get_text_key()
        self.document_prompt = self._get_document_prompt()
        self.prompt = ChatPromptTemplate.from_template(self.prompt_text)

    def _get_payload_value_with_superuser_check(self, key, default):
        """Return payload[key] only for superuser tokens; otherwise default."""
        if self.api_token.is_superuser():
            return self.payload.get(key, default)
        else:
            return default

    def _get_attributes_function(self):
        # NOTE(review): not referenced anywhere in this class — retained for
        # external callers; confirm before removing. Reads the field names
        # from the dc-v2-work index mapping.
        try:
            opensearch = opensearch_client()
            mapping = opensearch.indices.get_mapping(index="dc-v2-work")
            return list(next(iter(mapping.values()))['mappings']['properties'].keys())
        except StopIteration:
            return []

    def _get_attributes(self):
        return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)

    def _get_azure_endpoint(self):
        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
        return self._get_payload_value_with_superuser_check("azure_endpoint", default)

    def _get_azure_resource_name(self):
        """Resolve the Azure resource name from payload or environment.

        Raises EnvironmentError when neither source provides a value.
        """
        azure_resource_name = self._get_payload_value_with_superuser_check(
            "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
        )
        if not azure_resource_name:
            raise EnvironmentError(
                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
            )
        return azure_resource_name

    def _get_deployment_name(self):
        return self._get_payload_value_with_superuser_check(
            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
        )

    def _get_k(self):
        # Clamp even superuser-supplied k to MAX_K.
        value = self._get_payload_value_with_superuser_check("k", K_VALUE)
        return min(value, MAX_K)

    def _get_openai_api_version(self):
        return self._get_payload_value_with_superuser_check(
            "openai_api_version", VERSION
        )

    def _get_prompt_text(self):
        return self._get_payload_value_with_superuser_check("prompt", prompt_template())

    def _get_size(self):
        return self._get_payload_value_with_superuser_check("size", SIZE)

    def _get_temperature(self):
        return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

    def _get_text_key(self):
        return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

    def _get_document_prompt(self):
        return ChatPromptTemplate.from_template(document_template(self.attributes))

    def debug_message(self):
        """Return the debug payload describing this request's configuration."""
        return {
            "type": "debug",
            "message": {
                "attributes": self.attributes,
                "azure_endpoint": self.azure_endpoint,
                "deployment_name": self.deployment_name,
                "k": self.k,
                "openai_api_version": self.openai_api_version,
                "prompt": self.prompt_text,
                "question": self.question,
                "ref": self.ref,
                "size": self.size,  # bug fix: previously reported self.ref under "size"
                "temperature": self.temperature,
                "text_key": self.text_key,
            },
        }

    def setup_llm_request(self):
        """Initialize the vector store and chat client for this request."""
        self._setup_vector_store()
        self._setup_chat_client()

    def _setup_vector_store(self):
        self.opensearch = opensearch_vector_store()

    def _setup_chat_client(self):
        self.client = openai_chat_client(
            azure_deployment=self.deployment_name,
            azure_endpoint=self.azure_endpoint,
            openai_api_version=self.openai_api_version,
            max_tokens=self.max_tokens
        )

    def _is_debug_mode_enabled(self):
        # Debug output is superuser-only even when requested.
        debug = self.payload.get("debug", False)
        return debug and self.api_token.is_superuser()

    def _to_bool(self, val):
        """Converts a value to boolean. If the value is a string, it considers
        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
        """
        if isinstance(val, str):
            return val.lower() not in ["", "no", "false", "0"]
        return bool(val)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# ruff: noqa: E402 | ||
|
||
import os | ||
import sys | ||
|
||
sys.path.append('./src') | ||
|
||
from unittest import mock, TestCase | ||
from unittest.mock import patch | ||
from handlers.chat_sync import handler | ||
from helpers.apitoken import ApiToken | ||
|
||
class MockContext:
    """Minimal stand-in for the AWS Lambda context object passed to handler()."""

    def __init__(self):
        # Only the log stream name is ever consulted, so that is all we fake.
        self.log_stream_name = 'test'
|
||
@mock.patch.dict(
    os.environ,
    {
        "AZURE_OPENAI_RESOURCE_NAME": "test",
    },
)
class TestHandler(TestCase):
    """Tests for the synchronous HTTP chat handler.

    AZURE_OPENAI_RESOURCE_NAME is stubbed so HTTPEventConfig can be built
    without real Azure credentials.
    """

    def test_handler_unauthorized(self):
        # No auth token in the payload -> 401 before any LLM work happens.
        response = handler({"body": '{ "question": "Question?"}'}, MockContext())
        self.assertEqual(response, {'body': 'Unauthorized', 'statusCode': 401})

    @patch.object(ApiToken, 'is_logged_in')
    def test_no_question(self, mock_is_logged_in):
        # Logged in, but an empty question is rejected with a 400.
        mock_is_logged_in.return_value = True
        response = handler({"body": '{ "question": ""}'}, MockContext())
        self.assertEqual(response, {'statusCode': 400, 'body': 'Question cannot be blank'})

    @patch.object(ApiToken, 'is_logged_in')
    def test_handler_success(self, mock_is_logged_in):
        # Valid login + question; expects the setup-skip path (bodyless 200).
        mock_is_logged_in.return_value = True
        response = handler({"body": '{"question": "Question?"}'}, MockContext())
        self.assertEqual(response, {'statusCode': 200})
Oops, something went wrong.