Skip to content

Commit

Permalink
Merge pull request #237 from nulib/5151-increase-k
Browse files Browse the repository at this point in the history
Chat: Up default k to 40 but leave size at 5
  • Loading branch information
bmquinn authored Aug 8, 2024
2 parents 9a85e51 + 89e7c5b commit 92526c2
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 7 deletions.
9 changes: 8 additions & 1 deletion chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
K_VALUE = 5
K_VALUE = 40
MAX_K = 100
MAX_TOKENS = 1000
SIZE = 5
TEMPERATURE = 0.2
TEXT_KEY = "id"
VERSION = "2024-02-01"
Expand Down Expand Up @@ -56,6 +57,7 @@ class EventConfig:
ref: str = field(init=False)
request_context: dict = field(init=False)
temperature: float = field(init=False)
size: int = field(init=False)
socket: Websocket = field(init=False, default=None)
stream_response: bool = field(init=False)
text_key: str = field(init=False)
Expand All @@ -76,6 +78,7 @@ def __post_init__(self):
self.request_context = self.event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.size = self._get_size()
self.stream_response = self.payload.get("stream_response", not self.debug_mode)
self.temperature = self._get_temperature()
self.text_key = self._get_text_key()
Expand Down Expand Up @@ -130,6 +133,9 @@ def _get_openai_api_version(self):
def _get_prompt_text(self):
return self._get_payload_value_with_superuser_check("prompt", prompt_template())

def _get_size(self):
return self._get_payload_value_with_superuser_check("size", SIZE)

def _get_temperature(self):
return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

Expand All @@ -151,6 +157,7 @@ def debug_message(self):
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
"size": self.ref,
"temperature": self.temperature,
"text_key": self.text_key,
},
Expand Down
2 changes: 1 addition & 1 deletion chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
RESPONSE_TYPES = {
"base": ["answer", "ref"],
"debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
"log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "source_documents", "temperature", "token_counts"]
"log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"]
}

def handler(event, context):
Expand Down
2 changes: 1 addition & 1 deletion chat/src/helpers/hybrid_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def filter(query: dict):

def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 10, **kwargs: Any):
result = {
"size": k,
"size": kwargs.get("size", 5),
"query": {
"hybrid": {
"queries": [
Expand Down
1 change: 1 addition & 0 deletions chat/src/helpers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def debug_response(config, response, original_question):
"prompt": config.prompt_text,
"question": config.question,
"ref": config.ref,
"size": config.size,
"source_documents": source_urls,
"temperature": config.temperature,
"text_key": config.text_key,
Expand Down
2 changes: 1 addition & 1 deletion chat/src/helpers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_and_send_original_question(docs):

def prepare_response(self):
try:
retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "_source": {"excludes": ["embedding"]}})
retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}})
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| self.original_question_passthrough()
Expand Down
6 changes: 4 additions & 2 deletions chat/test/helpers/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ def setUp(self):
"body": json.dumps({
"deployment_name": "test",
"index": "test",
"k": 5,
"k": 40,
"openai_api_version": "2019-05-06",
"prompt": "This is a test prompt.",
"question": self.question,
"ref": "test",
"size": 5,
"temperature": 0.5,
"text_key": "text",
"auth": "test123"
Expand All @@ -78,9 +79,10 @@ def setUp(self):
def test_debug_response(self):
result = debug_response(self.config, self.response, self.original_question)

self.assertEqual(result["k"], 5)
self.assertEqual(result["k"], 40)
self.assertEqual(result["question"], self.question)
self.assertEqual(result["ref"], "test")
self.assertEqual(result["size"], 5)
self.assertEqual(
result["source_documents"],
[
Expand Down
4 changes: 3 additions & 1 deletion chat/test/test_event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_attempt_override_without_superuser_status(self):
"openai_api_version": "2024-01-01",
"question": "test question",
"ref": "test ref",
"size": 90,
"temperature": 0.9,
"text_key": "accession_number",
}
Expand All @@ -52,9 +53,10 @@ def test_attempt_override_without_superuser_status(self):
expected_output = {
"attributes": EventConfig.DEFAULT_ATTRIBUTES,
"azure_endpoint": "https://test.openai.azure.com/",
"k": 5,
"k": 40,
"openai_api_version": "2024-02-01",
"question": "test question",
"size": 5,
"ref": "test ref",
"temperature": 0.2,
"text_key": "id",
Expand Down

0 comments on commit 92526c2

Please sign in to comment.