diff --git a/chat/src/handlers/chat_sync.py b/chat/src/handlers/chat_sync.py
new file mode 100644
index 00000000..8166870e
--- /dev/null
+++ b/chat/src/handlers/chat_sync.py
@@ -0,0 +1,43 @@
+
+import json
+import logging
+import os
+from http_event_config import HTTPEventConfig
+from helpers.http_response import HTTPResponse
+from honeybadger import honeybadger
+
+honeybadger.configure()
+logging.getLogger('honeybadger').addHandler(logging.StreamHandler())
+
+RESPONSE_TYPES = {
+    "base": ["answer", "ref"],
+    "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
+    "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"],
+    "error": ["question", "error", "source_documents"]
+}
+
+def handler(event, context):
+    print(f'Event: {event}')
+
+    config = HTTPEventConfig(event)
+
+    if not config.is_logged_in:
+        return {"statusCode": 401, "body": "Unauthorized"}
+
+    if config.question is None or config.question == "":
+        return {"statusCode": 400, "body": "Question cannot be blank"}
+
+    if not os.getenv("SKIP_WEAVIATE_SETUP"):
+        config.setup_llm_request()
+        response = HTTPResponse(config)
+        final_response = response.prepare_response()
+        if "error" in final_response:
+            logging.error(f'Error: {final_response["error"]}')
+            return {"statusCode": 500, "body": "Internal Server Error"}
+        else:
+            return {"statusCode": 200, "body": json.dumps(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))}
+
+    return {"statusCode": 200}
+
+def reshape_response(response, response_type):
+    return {k: response.get(k) for k in RESPONSE_TYPES[response_type]}
\ No newline at end of file
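The handler accepts a Lambda Function URL event whose `body` is a JSON string carrying `question` and a signed `auth` token, and `SKIP_WEAVIATE_SETUP` short-circuits the live OpenSearch/LLM setup. A minimal local invocation sketch (run from `chat/` like the tests; the token value is a placeholder):

```python
# Sketch: exercise the handler locally without live OpenSearch/Azure calls.
import json
import os
import sys

sys.path.append('./src')
os.environ["AZURE_OPENAI_RESOURCE_NAME"] = "test"
os.environ["SKIP_WEAVIATE_SETUP"] = "true"  # skip real vector-store/LLM setup

from handlers.chat_sync import handler

event = {"body": json.dumps({"question": "What is a lantern slide?", "auth": "<signed token>"})}
print(handler(event, None))  # -> {'statusCode': 401, 'body': 'Unauthorized'} without a valid token
```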
diff --git a/chat/src/helpers/http_response.py b/chat/src/helpers/http_response.py
new file mode 100644
index 00000000..fc6abc6f
--- /dev/null
+++ b/chat/src/helpers/http_response.py
@@ -0,0 +1,60 @@
+from helpers.metrics import debug_response
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+
+def extract_prompt_value(v):
+    if isinstance(v, list):
+        return [extract_prompt_value(item) for item in v]
+    elif isinstance(v, dict) and 'label' in v:
+        return [v.get('label')]
+    else:
+        return v
+
+class HTTPResponse:
+    def __init__(self, config):
+        self.config = config
+        self.store = {}
+
+    def debug_response_passthrough(self):
+        return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question))
+
+    def original_question_passthrough(self):
+        def get_and_send_original_question(docs):
+            source_documents = []
+            for doc in docs["context"]:
+                doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata}
+                source_document = doc.metadata.copy()
+                source_document["content"] = doc.page_content
+                source_documents.append(source_document)
+
+            original_question = {
+                "question": self.config.question,
+                "source_documents": source_documents,
+            }
+
+            self.original_question = original_question
+            return docs
+
+        return RunnablePassthrough(get_and_send_original_question)
+
+    def prepare_response(self):
+        try:
+            retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}})
+            chain = (
+                {"context": retriever, "question": RunnablePassthrough()}
+                | self.original_question_passthrough()
+                | self.config.prompt
+                | self.config.client
+                | StrOutputParser()
+                | self.debug_response_passthrough()
+            )
+            response = chain.invoke(self.config.question)
+        except Exception as err:
+            response = {
+                "question": self.config.question,
+                "error": str(err),
+                "source_documents": [],
+            }
+        return response
+
+
\ No newline at end of file
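`original_question_passthrough` relies on `RunnablePassthrough(func)` calling `func` purely for its side effect while forwarding the input unchanged; that side channel is how the retrieved source documents survive past `StrOutputParser` and reach the final debug step. A self-contained sketch of the same pattern (toy chain, no retriever):

```python
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

state = {}

def capture(inputs):
    # Side effect only: RunnablePassthrough forwards `inputs` unchanged.
    state["question"] = inputs["question"]

chain = (
    RunnablePassthrough(capture)
    | RunnableLambda(lambda x: f"stub answer to: {x['question']}")   # stands in for prompt | llm | parser
    | RunnableLambda(lambda answer: {"answer": answer, **state})     # merge captured state back in
)

print(chain.invoke({"context": [], "question": "who?"}))
# {'answer': 'stub answer to: who?', 'question': 'who?'}
```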
self.payload.get("stream_response", not self.debug_mode) + self.temperature = self._get_temperature() + self.text_key = self._get_text_key() + self.document_prompt = self._get_document_prompt() + self.prompt = ChatPromptTemplate.from_template(self.prompt_text) + + def _get_payload_value_with_superuser_check(self, key, default): + if self.api_token.is_superuser(): + return self.payload.get(key, default) + else: + return default + + def _get_attributes_function(self): + try: + opensearch = opensearch_client() + mapping = opensearch.indices.get_mapping(index="dc-v2-work") + return list(next(iter(mapping.values()))['mappings']['properties'].keys()) + except StopIteration: + return [] + + def _get_attributes(self): + return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES) + + def _get_azure_endpoint(self): + default = f"https://{self._get_azure_resource_name()}.openai.azure.com/" + return self._get_payload_value_with_superuser_check("azure_endpoint", default) + + def _get_azure_resource_name(self): + azure_resource_name = self._get_payload_value_with_superuser_check( + "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME") + ) + if not azure_resource_name: + raise EnvironmentError( + "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set" + ) + return azure_resource_name + + def _get_deployment_name(self): + return self._get_payload_value_with_superuser_check( + "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") + ) + + def _get_k(self): + value = self._get_payload_value_with_superuser_check("k", K_VALUE) + return min(value, MAX_K) + + def _get_openai_api_version(self): + return self._get_payload_value_with_superuser_check( + "openai_api_version", VERSION + ) + + def _get_prompt_text(self): + return self._get_payload_value_with_superuser_check("prompt", prompt_template()) + + def _get_size(self): + return self._get_payload_value_with_superuser_check("size", SIZE) + + def _get_temperature(self): + return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE) + + def _get_text_key(self): + return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY) + + def _get_document_prompt(self): + return ChatPromptTemplate.from_template(document_template(self.attributes)) + + def debug_message(self): + return { + "type": "debug", + "message": { + "attributes": self.attributes, + "azure_endpoint": self.azure_endpoint, + "deployment_name": self.deployment_name, + "k": self.k, + "openai_api_version": self.openai_api_version, + "prompt": self.prompt_text, + "question": self.question, + "ref": self.ref, + "size": self.ref, + "temperature": self.temperature, + "text_key": self.text_key, + }, + } + + def setup_llm_request(self): + self._setup_vector_store() + self._setup_chat_client() + + def _setup_vector_store(self): + self.opensearch = opensearch_vector_store() + + def _setup_chat_client(self): + self.client = openai_chat_client( + azure_deployment=self.deployment_name, + azure_endpoint=self.azure_endpoint, + openai_api_version=self.openai_api_version, + max_tokens=self.max_tokens + ) + + def _is_debug_mode_enabled(self): + debug = self.payload.get("debug", False) + return debug and self.api_token.is_superuser() + + def _to_bool(self, val): + """Converts a value to boolean. If the value is a string, it considers + "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value. 
+ """ + if isinstance(val, str): + return val.lower() not in ["", "no", "false", "0"] + return bool(val) diff --git a/chat/template.yaml b/chat/template.yaml index 04d59379..98ac2ed7 100644 --- a/chat/template.yaml +++ b/chat/template.yaml @@ -242,6 +242,48 @@ Resources: Resource: !Sub "${ChatMetricsLog.Arn}:*" #* Metadata: #* BuildMethod: nodejs20.x + ChatSyncFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./src + Runtime: python3.10 + Architectures: + - x86_64 + #* Layers: + #* - !Ref ChatDependencies + MemorySize: 1024 + Handler: handlers/chat_sync.handler + Timeout: 300 + Environment: + Variables: + API_TOKEN_SECRET: !Ref ApiTokenSecret + AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey + AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId + AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName + ENV_PREFIX: !Ref EnvironmentPrefix + HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey + HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv + HONEYBADGER_REVISION: !Ref HoneybadgerRevision + METRICS_LOG_GROUP: !Ref ChatMetricsLog + OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint + OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId + FunctionUrlConfig: + AuthType: NONE + Policies: + - Statement: + - Effect: Allow + Action: + - 'es:ESHttpGet' + - 'es:ESHttpPost' + Resource: '*' + # - Statement: + # - Effect: Allow + # Action: + # - logs:CreateLogStream + # - logs:PutLogEvents + # Resource: !Sub "${ChatMetricsLog.Arn}:*" + #* Metadata: + #* BuildMethod: nodejs20.x ChatMetricsLog: Type: AWS::Logs::LogGroup Properties: diff --git a/chat/test/handlers/test_chat_sync.py b/chat/test/handlers/test_chat_sync.py new file mode 100644 index 00000000..773ebfe0 --- /dev/null +++ b/chat/test/handlers/test_chat_sync.py @@ -0,0 +1,35 @@ +# ruff: noqa: E402 + +import os +import sys + +sys.path.append('./src') + +from unittest import mock, TestCase +from unittest.mock import patch +from handlers.chat_sync import handler +from helpers.apitoken import ApiToken + +class MockContext: + def __init__(self): + self.log_stream_name = 'test' + +@mock.patch.dict( + os.environ, + { + "AZURE_OPENAI_RESOURCE_NAME": "test", + }, +) +class TestHandler(TestCase): + def test_handler_unauthorized(self): + self.assertEqual(handler({"body": '{ "question": "Question?"}'}, MockContext()), {'body': 'Unauthorized', 'statusCode': 401}) + + @patch.object(ApiToken, 'is_logged_in') + def test_no_question(self, mock_is_logged_in): + mock_is_logged_in.return_value = True + self.assertEqual(handler({"body": '{ "question": ""}'}, MockContext()), {'statusCode': 400, 'body': 'Question cannot be blank'}) + + @patch.object(ApiToken, 'is_logged_in') + def test_handler_success(self, mock_is_logged_in): + mock_is_logged_in.return_value = True + self.assertEqual(handler({"body": '{"question": "Question?"}'}, MockContext()), {'statusCode': 200}) diff --git a/chat/test/helpers/test_http_event_config.py b/chat/test/helpers/test_http_event_config.py new file mode 100644 index 00000000..3bc67075 --- /dev/null +++ b/chat/test/helpers/test_http_event_config.py @@ -0,0 +1,95 @@ +# ruff: noqa: E402 +import json +import os +import sys +sys.path.append('./src') + +from http_event_config import HTTPEventConfig +from unittest import TestCase, mock + + +class TestEventConfigWithoutAzureResource(TestCase): + def test_requires_an_azure_resource(self): + with self.assertRaises(EnvironmentError): + HTTPEventConfig() + + +@mock.patch.dict( + os.environ, + { + "AZURE_OPENAI_RESOURCE_NAME": "test", + }, +) +class TestHTTPEventConfig(TestCase): + def 
diff --git a/chat/test/helpers/test_http_event_config.py b/chat/test/helpers/test_http_event_config.py
new file mode 100644
index 00000000..3bc67075
--- /dev/null
+++ b/chat/test/helpers/test_http_event_config.py
@@ -0,0 +1,96 @@
+# ruff: noqa: E402
+import json
+import os
+import sys
+sys.path.append('./src')
+
+from http_event_config import HTTPEventConfig
+from unittest import TestCase, mock
+
+
+class TestEventConfigWithoutAzureResource(TestCase):
+    def test_requires_an_azure_resource(self):
+        with self.assertRaises(EnvironmentError):
+            HTTPEventConfig()
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        "AZURE_OPENAI_RESOURCE_NAME": "test",
+    },
+)
+class TestHTTPEventConfig(TestCase):
+    def test_raises_without_azure_resource_name(self):
+        os.environ.pop("AZURE_OPENAI_RESOURCE_NAME", None)
+        with self.assertRaises(EnvironmentError):
+            HTTPEventConfig()
+
+    def test_defaults(self):
+        actual = HTTPEventConfig(event={"body": json.dumps({"attributes": ["title"]})})
+        expected_defaults = {"azure_endpoint": "https://test.openai.azure.com/"}
+        self.assertEqual(actual.azure_endpoint, expected_defaults["azure_endpoint"])
+
+    def test_attempt_override_without_superuser_status(self):
+        actual = HTTPEventConfig(
+            event={
+                "body": json.dumps(
+                    {
+                        "azure_resource_name": "new_name_for_test",
+                        "attributes": ["title", "subject", "date_created"],
+                        "index": "testIndex",
+                        "k": 100,
+                        "openai_api_version": "2024-01-01",
+                        "question": "test question",
+                        "ref": "test ref",
+                        "size": 90,
+                        "temperature": 0.9,
+                        "text_key": "accession_number",
+                    }
+                )
+            }
+        )
+        expected_output = {
+            "attributes": HTTPEventConfig.DEFAULT_ATTRIBUTES,
+            "azure_endpoint": "https://test.openai.azure.com/",
+            "k": 40,
+            "openai_api_version": "2024-02-01",
+            "question": "test question",
+            "size": 5,
+            "ref": "test ref",
+            "temperature": 0.2,
+            "text_key": "id",
+        }
+        self.assertEqual(actual.azure_endpoint, expected_output["azure_endpoint"])
+        self.assertEqual(actual.attributes, expected_output["attributes"])
+        self.assertEqual(actual.k, expected_output["k"])
+        self.assertEqual(
+            actual.openai_api_version, expected_output["openai_api_version"]
+        )
+        self.assertEqual(actual.question, expected_output["question"])
+        self.assertEqual(actual.ref, expected_output["ref"])
+        self.assertEqual(actual.size, expected_output["size"])
+        self.assertEqual(actual.temperature, expected_output["temperature"])
+        self.assertEqual(actual.text_key, expected_output["text_key"])
+
+    def test_debug_message(self):
+        self.assertEqual(
+            HTTPEventConfig(
+                event={"body": json.dumps({"attributes": ["source"]})}
+            ).debug_message()["type"],
+            "debug",
+        )
+
+    def test_to_bool(self):
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(""), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("0"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("no"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("false"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("False"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("FALSE"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(None), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("No"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("NO"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("true"), True)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(True), True)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(False), False)
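The `_to_bool` assertions repeat the same `HTTPEventConfig` construction for every case; if desired, they could be collapsed with `subTest` so each case reports independently. A sketch of such a case-table test, meant to live inside `TestHTTPEventConfig`:

```python
def test_to_bool_cases(self):
    config = HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})
    cases = [
        ("", False), ("0", False), ("no", False), ("No", False), ("NO", False),
        ("false", False), ("False", False), ("FALSE", False),
        ("true", True), ("yes", True), (True, True), (False, False), (None, False),
    ]
    for value, expected in cases:
        with self.subTest(value=value):
            self.assertEqual(config._to_bool(value), expected)
```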