Skip to content

Commit

Permalink
Make OpenSearchNeuralSearch client injectable, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Mar 11, 2024
1 parent ff29824 commit dddb6db
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 20 deletions.
9 changes: 4 additions & 5 deletions chat/src/handlers/opensearch_neural_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class OpenSearchNeuralSearch(VectorStore):

def __init__(
self,
client: None,
endpoint: str,
index: str,
model_id: str,
Expand All @@ -17,7 +18,7 @@ def __init__(
text_field: str = "id",
**kwargs: Any,
):
self.client = OpenSearch(
self.client = client or OpenSearch(
hosts=[{"host": endpoint, "port": "443", "use_ssl": True}], **kwargs
)
self.index = index
Expand Down Expand Up @@ -57,15 +58,13 @@ def similarity_search_with_score(
}
},
}

if subquery:
dsl["query"]["hybrid"]["queries"].append(subquery)
dsl["query"]["hybrid"]["queries"].append(subquery)

for key, value in kwargs.items():
dsl[key] = value

print(f"OpenSearchNeuralSearch dsl: {dsl}")

response = self.client.search(index=self.index, body=dsl, params={"search_pipeline": self.search_pipeline} if self.search_pipeline else None)

documents_with_scores = [
Expand Down
30 changes: 15 additions & 15 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -172,27 +172,27 @@ Resources:
Action: lambda:InvokeFunction
FunctionName: !Ref ChatFunction
Principal: apigateway.amazonaws.com
# ChatDependencies:
# Type: AWS::Serverless::LayerVersion
# Properties:
# LayerName:
# Fn::Sub: "${AWS::StackName}-dependencies"
# Description: Dependencies for streaming chat function
# ContentUri: ./dependencies
# CompatibleRuntimes:
# - python3.10
# LicenseInfo: "Apache-2.0"
# Metadata:
# BuildMethod: python3.10
ChatDependencies:
Type: AWS::Serverless::LayerVersion
Properties:
LayerName:
Fn::Sub: "${AWS::StackName}-dependencies"
Description: Dependencies for streaming chat function
ContentUri: ./dependencies
CompatibleRuntimes:
- python3.10
LicenseInfo: "Apache-2.0"
Metadata:
BuildMethod: python3.10
ChatFunction:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.10
Architectures:
- x86_64
# Layers:
# - !Ref ChatDependencies
Layers:
- !Ref ChatDependencies
MemorySize: 1024
Handler: handlers/chat.handler
Timeout: 300
Expand All @@ -219,7 +219,7 @@ Resources:
- 'es:ESHttpPost'
Resource: '*'
Metadata:
BuildMethod: python3.10
BuildMethod: nodejs18.x
Deployment:
Type: AWS::ApiGatewayV2::Deployment
DependsOn:
Expand Down
43 changes: 43 additions & 0 deletions chat/test/handlers/test_opensearch_neural_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# ruff: noqa: E402
import sys
sys.path.append('./src')

from unittest import TestCase
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from langchain_core.documents import Document

class MockClient():
def search(self, index, body, params):
return {
"hits": {
"hits": [
{
"_source": {
"id": "test"
},
"_score": 0.12345
}
]
}
}

class TestOpenSearchNeuralSearch(TestCase):
def test_similarity_search(self):
docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search(query="test", subquery={"_source": {"excludes": ["embedding"]}}, size=10)
self.assertEqual(docs, [Document(page_content='test', metadata={'id': 'test'})])

def test_similarity_search_with_score(self):
docs = OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").similarity_search_with_score(query="test")
self.assertEqual(docs, [(Document(page_content='test', metadata={'id': 'test'}), 0.12345)])

def test_add_texts(self):
try:
OpenSearchNeuralSearch(client=MockClient(), endpoint="test", index="test", model_id="test").add_texts(texts=["test"], metadatas=[{"id": "test"}])
except Exception as e:
self.fail(f"from_texts raised an exception: {e}")

def test_from_texts(self):
try:
OpenSearchNeuralSearch.from_texts(clas="test", texts=["test"], metadatas=[{"id": "test"}])
except Exception as e:
self.fail(f"from_texts raised an exception: {e}")

0 comments on commit dddb6db

Please sign in to comment.