-
Notifications
You must be signed in to change notification settings - Fork 0
/
sagemaker.py
82 lines (61 loc) · 2.76 KB
/
sagemaker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# mypy: ignore-errors
import json
from typing import Any
import boto3
from llama_index.embeddings.base import BaseEmbedding
from pydantic import Field, PrivateAttr
class SagemakerEmbedding(BaseEmbedding):
    """Embedding model backed by a deployed Sagemaker inference endpoint.

    Texts are sent to the endpoint as a JSON body of the form
    ``{"inputs": [<text>, ...]}`` and the JSON response must hold one
    embedding per input under the ``"vectors"`` key.

    Only the endpoint name is configurable on this class. The
    ``sagemaker-runtime`` client is created with boto3's defaults, so the
    region and the credentials (including any named profile) are resolved
    by boto3 itself from the environment / shared config files:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    endpoint_name: str = Field(description="")
    # The client is created lazily on first use instead of in the class body:
    # building it at import time makes merely importing this module fail when
    # boto3 cannot resolve a region, even if the class is never instantiated.
    # TODO make it an optional field
    _boto_client: Any = PrivateAttr(default=None)
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string of this embedding class."""
        return "SagemakerEmbedding"

    def _get_boto_client(self) -> Any:
        """Return this instance's ``sagemaker-runtime`` client, creating it once."""
        if self._boto_client is None:
            self._boto_client = boto3.client("sagemaker-runtime")
        return self._boto_client

    def _async_not_implemented_warn_once(self) -> None:
        """Print the sync-fallback notice the first time an async method runs."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        """Embed a batch of texts with a single endpoint invocation.

        Args:
            sentences: Texts to embed.

        Returns:
            The value stored under ``"vectors"`` in the endpoint's JSON
            response, or an empty list when ``sentences`` is empty.

        Raises:
            KeyError: If the response JSON has no ``"vectors"`` key.
        """
        if not sentences:
            # Nothing to embed: skip the remote call entirely.
            return []

        resp = self._get_boto_client().invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps({"inputs": sentences}),
            ContentType="application/json",
        )
        # resp["Body"] is a stream; read it fully before decoding the JSON.
        response_json = json.loads(resp["Body"].read().decode("utf-8"))
        return response_json["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Get query embedding; runs the synchronous implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Get text embedding; runs the synchronous implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings."""
        return self._embed(texts)