Skip to content

Commit

Permalink
Provide CAII batch embedding for better performance (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
conradocloudera authored Dec 9, 2024
1 parent a5cc800 commit 2dac585
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion llm-service/app/services/CaiiEmbeddingModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import http.client as http_client
import json
import os
from typing import Any, Dict
from typing import Any, Dict, List

from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from pydantic import Field
Expand Down Expand Up @@ -86,6 +86,33 @@ def _get_embedding(self, query: str, input_type: str) -> Embedding:

return embedding

def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
model = self.endpoint["endpointmetadata"]["model_name"]
domain = os.environ["CAII_DOMAIN"]

connection = http_client.HTTPSConnection(domain, 443)
headers = self.build_auth_headers()
headers["Content-Type"] = "application/json"
body = json.dumps(
{
"input": texts,
"input_type": "passage",
"truncate": "END",
"model": model,
}
)
connection.request("POST", self.endpoint["url"], body=body, headers=headers)
res = connection.getresponse()
data = res.read()
json_response = data.decode("utf-8")
structured_response = json.loads(json_response)
embeddings = structured_response["data"][0]["embedding"]
assert isinstance(embeddings, list)
assert all(isinstance(x, list) for x in embeddings)
assert all(all(isinstance(y, float) for y in x) for x in embeddings)

return embeddings

def build_auth_headers(self) -> Dict[str, str]:
with open("/tmp/jwt", "r") as file:
jwt_contents = json.load(file)
Expand Down

0 comments on commit 2dac585

Please sign in to comment.