From dddb6dbd8ee1e0464e1995599e3d0e59ecfa9450 Mon Sep 17 00:00:00 2001 From: Brendan Quinn Date: Mon, 11 Mar 2024 21:52:03 +0000 Subject: [PATCH] Make OpenSearchNeuralSearch client injectable, add tests --- chat/src/handlers/opensearch_neural_search.py | 9 ++-- chat/template.yaml | 30 ++++++------- .../handlers/test_opensearch_neural_search.py | 43 +++++++++++++++++++ 3 files changed, 62 insertions(+), 20 deletions(-) create mode 100644 chat/test/handlers/test_opensearch_neural_search.py diff --git a/chat/src/handlers/opensearch_neural_search.py b/chat/src/handlers/opensearch_neural_search.py index 5c7e0011..09b59cb2 100644 --- a/chat/src/handlers/opensearch_neural_search.py +++ b/chat/src/handlers/opensearch_neural_search.py @@ -9,6 +9,7 @@ class OpenSearchNeuralSearch(VectorStore): def __init__( self, + client: None, endpoint: str, index: str, model_id: str, @@ -17,7 +18,7 @@ def __init__( text_field: str = "id", **kwargs: Any, ): - self.client = OpenSearch( + self.client = client or OpenSearch( hosts=[{"host": endpoint, "port": "443", "use_ssl": True}], **kwargs ) self.index = index @@ -57,15 +58,13 @@ def similarity_search_with_score( } }, } - + if subquery: - dsl["query"]["hybrid"]["queries"].append(subquery) + dsl["query"]["hybrid"]["queries"].append(subquery) for key, value in kwargs.items(): dsl[key] = value - print(f"OpenSearchNeuralSearch dsl: {dsl}") - response = self.client.search(index=self.index, body=dsl, params={"search_pipeline": self.search_pipeline} if self.search_pipeline else None) documents_with_scores = [ diff --git a/chat/template.yaml b/chat/template.yaml index 30307d17..24b95aac 100644 --- a/chat/template.yaml +++ b/chat/template.yaml @@ -172,18 +172,18 @@ Resources: Action: lambda:InvokeFunction FunctionName: !Ref ChatFunction Principal: apigateway.amazonaws.com - # ChatDependencies: - # Type: AWS::Serverless::LayerVersion - # Properties: - # LayerName: - # Fn::Sub: "${AWS::StackName}-dependencies" - # Description: Dependencies for streaming chat function - # ContentUri: ./dependencies - # CompatibleRuntimes: - # - python3.10 - # LicenseInfo: "Apache-2.0" - # Metadata: - # BuildMethod: python3.10 + ChatDependencies: + Type: AWS::Serverless::LayerVersion + Properties: + LayerName: + Fn::Sub: "${AWS::StackName}-dependencies" + Description: Dependencies for streaming chat function + ContentUri: ./dependencies + CompatibleRuntimes: + - python3.10 + LicenseInfo: "Apache-2.0" + Metadata: + BuildMethod: python3.10 ChatFunction: Type: AWS::Serverless::Function Properties: @@ -191,8 +191,8 @@ Resources: Runtime: python3.10 Architectures: - x86_64 - # Layers: - # - !Ref ChatDependencies + Layers: + - !Ref ChatDependencies MemorySize: 1024 Handler: handlers/chat.handler Timeout: 300 @@ -219,7 +219,7 @@ Resources: - 'es:ESHttpPost' Resource: '*' Metadata: - BuildMethod: python3.10 + BuildMethod: nodejs18.x Deployment: Type: AWS::ApiGatewayV2::Deployment DependsOn: diff --git a/chat/test/handlers/test_opensearch_neural_search.py b/chat/test/handlers/test_opensearch_neural_search.py new file mode 100644 index 00000000..d7448679 --- /dev/null +++ b/chat/test/handlers/test_opensearch_neural_search.py @@ -0,0 +1,43 @@ +# ruff: noqa: E402 +import sys +sys.path.append('./src') + +from unittest import TestCase +from handlers.opensearch_neural_search import OpenSearchNeuralSearch +from langchain_core.documents import Document + +class MockClient(): + def search(self, index, body, params): + return { + "hits": { + "hits": [ + { + "_source": { + "id": "test" + }, + "_score": 0.12345 + } + ] + } + } + +class TestOpenSearchNeuralSearch(TestCase): + def test_similarity_search(self): + docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search(query="test", subquery={"_source": {"excludes": ["embedding"]}}, size=10) + self.assertEqual(docs, [Document(page_content='test', metadata={'id': 'test'})]) + + def test_similarity_search_with_score(self): + docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search_with_score(query="test") + self.assertEqual(docs, [(Document(page_content='test', metadata={'id': 'test'}), 0.12345)]) + + def test_add_texts(self): + try: + OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").add_texts(texts=["test"], metadatas=[{"id": "test"}]) + except Exception as e: + self.fail(f"from_texts raised an exception: {e}") + + def test_from_texts(self): + try: + OpenSearchNeuralSearch.from_texts(clas="test", texts=["test"], metadatas=[{"id": "test"}]) + except Exception as e: + self.fail(f"from_texts raised an exception: {e}") \ No newline at end of file