diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py
index 5ec3381e..ff4c0d0c 100644
--- a/semantic_router/encoders/bedrock.py
+++ b/semantic_router/encoders/bedrock.py
@@ -17,7 +17,7 @@
 """
 import json
-from typing import List, Optional, Any
+from typing import Dict, List, Optional, Any, Union
 import os
 from time import sleep
 import tiktoken
 
@@ -138,11 +138,12 @@ def _initialize_client(
             ) from err
         return bedrock_client
 
-    def __call__(self, docs: List[str]) -> List[List[float]]:
+    def __call__(self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None) -> List[List[float]]:
         """Generates embeddings for the given documents.
 
         Args:
             docs: A list of strings representing the documents to embed.
+            model_kwargs: A dictionary of model-specific inference parameters.
 
         Returns:
             A list of lists, where each inner list contains the embedding values for a
@@ -168,11 +169,25 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
         embeddings = []
         if self.name and "amazon" in self.name:
             for doc in docs:
-                embedding_body = json.dumps(
-                    {
-                        "inputText": doc,
-                    }
-                )
+
+                embedding_body = {}
+
+                if isinstance(doc, dict):
+                    embedding_body['inputText'] = doc.get('text')
+                    embedding_body['inputImage'] = doc.get('image') # expects a base64-encoded image
+                else:
+                    embedding_body['inputText'] = doc
+
+                # Add model-specific inference parameters
+                if model_kwargs:
+                    embedding_body = embedding_body | model_kwargs
+
+                # Clean up null values
+                embedding_body = {k: v for k, v in embedding_body.items() if v}
+
+                # Format payload
+                embedding_body = json.dumps(embedding_body)
+
                 response = self.client.invoke_model(
                     body=embedding_body,
                     modelId=self.name,
@@ -184,9 +199,19 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
         elif self.name and "cohere" in self.name:
             chunked_docs = self.chunk_strings(docs)
             for chunk in chunked_docs:
-                chunk = json.dumps(
-                    {"texts": chunk, "input_type": self.input_type}
-                )
+                chunk = {
+                    'texts': chunk,
+                    'input_type': self.input_type
+                }
+
+                # Add model-specific inference parameters
+                # Note: if specified, input_type will be overwritten by model_kwargs
+                if model_kwargs:
+                    chunk = chunk | model_kwargs
+
+                # Format payload
+                chunk = json.dumps(chunk)
+
                 response = self.client.invoke_model(
                     body=chunk,
                     modelId=self.name,