diff --git a/docs/source/conf.py b/docs/source/conf.py index 23154157..6ad39408 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ project = "Semantic Router" copyright = "2024, Aurelio AI" author = "Aurelio AI" -release = "0.0.71" +release = "0.0.72" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/poetry.lock b/poetry.lock index 0c6f20e6..11dbb7ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -327,7 +327,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "boto3" version = "1.35.32" description = "The AWS SDK for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "boto3-1.35.32-py3-none-any.whl", hash = "sha256:786a243f4b4827c6ae149442bf544c2ae449570cf23616a5d386f7a2633e0e08"}, @@ -346,7 +346,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.35.32" description = "Low-level, data-driven core of boto 3." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "botocore-1.35.32-py3-none-any.whl", hash = "sha256:2c0c2b62dd156daf904525f3f523ae22bf34ac109d727acf0bbfbca291440fc3"}, @@ -582,7 +582,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cohere" version = "5.10.0" description = "" -optional = false +optional = true python-versions = "<4.0,>=3.8" files = [ {file = "cohere-5.10.0-py3-none-any.whl", hash = "sha256:46e50e3e8514a99cf77b4c022c8077a6205fba948051c33087ddeb66ec706f0a"}, @@ -982,7 +982,7 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth name = "fastavro" version = "1.9.7" description = "Fast read/write of AVRO files" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"}, @@ -1056,7 +1056,7 @@ tqdm = ">=4.66,<5.0" name = "filelock" version = "3.16.1" description = "A platform independent file lock." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, @@ -1240,7 +1240,7 @@ files = [ name = "fsspec" version = "2024.9.0" description = "File-system specification" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"}, @@ -1796,7 +1796,7 @@ socks = ["socksio (==1.*)"] name = "httpx-sse" version = "0.4.0" description = "Consume Server-Sent Event (SSE) messages with HTTPX." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, @@ -1807,7 +1807,7 @@ files = [ name = "huggingface-hub" version = "0.25.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -optional = false +optional = true python-versions = ">=3.8.0" files = [ {file = "huggingface_hub-0.25.1-py3-none-any.whl", hash = "sha256:a5158ded931b3188f54ea9028097312cb0acd50bffaaa2612014c3c526b44972"}, @@ -2123,7 +2123,7 @@ files = [ name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, @@ -3211,7 +3211,7 @@ files = [ name = "parameterized" version = "0.9.0" description = "Parameterized testing with any Python test framework" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b"}, @@ -4347,7 +4347,7 @@ files = [ name = "s3transfer" version = "0.10.2" description = "An Amazon S3 Transfer Manager" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, @@ -4862,7 +4862,7 @@ files = [ name = "tokenizers" version = "0.20.0" description = "" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tokenizers-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6cff5c5e37c41bc5faa519d6f3df0679e4b37da54ea1f42121719c5e2b4905c0"}, @@ -5488,6 +5488,7 @@ type = ["pytest-mypy"] [extras] bedrock = ["boto3", "botocore"] +cohere = ["cohere"] docs = ["sphinx", "sphinxawesome-theme"] fastembed = ["fastembed"] google = ["google-cloud-aiplatform"] @@ -5503,4 +5504,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "3b6d8cef3e0d6c516a9d9704350e8ff6dac7277cabed851f8c4ccc84214df6ea" +content-hash = "b0ddd77f2b9a210601eba56f69630eaa6a53cb358cae95bace2ae080d51c7812" diff --git a/pyproject.toml b/pyproject.toml index 99f91f2d..edeb6578 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "semantic-router" -version = "0.0.71" +version = "0.0.72" description = "Super fast semantic router for AI decision making" authors = ["Aurelio AI "] readme = "README.md" @@ -11,7 +11,7 @@ license = "MIT" python = ">=3.9,<3.13" pydantic = "^2.5.3" openai = ">=1.10.0,<2.0.0" -cohere = ">=5.9.4,<6.00" +cohere = {version = ">=5.9.4,<6.00", optional = true} mistralai= {version = ">=0.0.12,<0.1.0", optional = true} numpy = "^1.25.2" colorlog = "^6.8.0" @@ -52,6 +52,7 @@ bedrock = ["boto3", "botocore"] postgres = ["psycopg2"] fastembed = ["fastembed"] docs = ["sphinx", "sphinxawesome-theme"] +cohere = ["cohere"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 0c9bf249..5ce4b911 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -4,4 +4,4 @@ __all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"] -__version__ = "0.0.71" +__version__ = "0.0.72" diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 01426e9f..cdc114bb 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -1,15 +1,15 @@ import os -from typing import List, Optional +from typing import Any, List, Optional -import cohere -from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse +from pydantic.v1 import PrivateAttr from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault class CohereEncoder(BaseEncoder): - client: Optional[cohere.Client] = None + _client: Any = PrivateAttr() + _embed_type: Any = PrivateAttr() type: str = "cohere" input_type: Optional[str] = "search_query" @@ -28,25 +28,47 @@ def __init__( input_type=input_type, # type: ignore ) self.input_type = input_type + self._client = self._initialize_client(cohere_api_key) + + def _initialize_client(self, cohere_api_key: Optional[str] = None): + """Initializes the Cohere client. + + :param cohere_api_key: The API key for the Cohere client, can also + be set via the COHERE_API_KEY environment variable. + + :return: An instance of the Cohere client. + """ + try: + import cohere + from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse + + self._embed_type = EmbeddingsByTypeEmbedResponse + except ImportError: + raise ImportError( + "Please install Cohere to use CohereEncoder. " + "You can install it with: " + "`pip install 'semantic-router[cohere]'`" + ) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: raise ValueError("Cohere API key cannot be 'None'.") try: - self.client = cohere.Client(cohere_api_key) + client = cohere.Client(cohere_api_key) except Exception as e: raise ValueError( f"Cohere API client failed to initialize. Error: {e}" ) from e + return client def __call__(self, docs: List[str]) -> List[List[float]]: - if self.client is None: + if self._client is None: raise ValueError("Cohere client is not initialized.") try: - embeds = self.client.embed( + embeds = self._client.embed( texts=docs, input_type=self.input_type, model=self.name ) # Check for unsupported type. - if isinstance(embeds, EmbeddingsByTypeEmbedResponse): + if isinstance(embeds, self._embed_type): raise NotImplementedError( "Handling of EmbedByTypeResponseEmbeddings is not implemented." ) diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py index 37eb4338..05a9b1bd 100644 --- a/semantic_router/llms/cohere.py +++ b/semantic_router/llms/cohere.py @@ -1,14 +1,14 @@ import os -from typing import List, Optional +from typing import Any, List, Optional -import cohere +from pydantic.v1 import PrivateAttr from semantic_router.llms import BaseLLM from semantic_router.schema import Message class CohereLLM(BaseLLM): - client: Optional[cohere.Client] = None + _client: Any = PrivateAttr() def __init__( self, @@ -18,21 +18,33 @@ def __init__( if name is None: name = os.getenv("COHERE_CHAT_MODEL_NAME", "command") super().__init__(name=name) + self._client = self._initialize_client(cohere_api_key) + + def _initialize_client(self, cohere_api_key: Optional[str] = None): + try: + import cohere + except ImportError: + raise ImportError( + "Please install Cohere to use CohereLLM. " + "You can install it with: " + "`pip install 'semantic-router[cohere]'`" + ) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: raise ValueError("Cohere API key cannot be 'None'.") try: - self.client = cohere.Client(cohere_api_key) + client = cohere.Client(cohere_api_key) except Exception as e: raise ValueError( f"Cohere API client failed to initialize. Error: {e}" ) from e + return client def __call__(self, messages: List[Message]) -> str: - if self.client is None: + if self._client is None: raise ValueError("Cohere client is not initialized.") try: - completion = self.client.chat( + completion = self._client.chat( model=self.name, chat_history=[m.to_cohere() for m in messages[:-1]], message=messages[-1].content, diff --git a/tests/unit/encoders/test_cohere.py b/tests/unit/encoders/test_cohere.py index 0f7607af..b4d81d24 100644 --- a/tests/unit/encoders/test_cohere.py +++ b/tests/unit/encoders/test_cohere.py @@ -11,7 +11,7 @@ def cohere_encoder(mocker): class TestCohereEncoder: def test_initialization_with_api_key(self, cohere_encoder): - assert cohere_encoder.client is not None, "Client should be initialized" + assert cohere_encoder._client is not None, "Client should be initialized" assert ( cohere_encoder.name == "embed-english-v3.0" ), "Default name not set correctly" @@ -25,38 +25,38 @@ def test_initialization_without_api_key(self, mocker, monkeypatch): def test_call_method(self, cohere_encoder, mocker): mock_embed = mocker.MagicMock() mock_embed.embeddings = [[0.1, 0.2, 0.3]] - cohere_encoder.client.embed.return_value = mock_embed + cohere_encoder._client.embed.return_value = mock_embed result = cohere_encoder(["test"]) assert isinstance(result, list), "Result should be a list" assert all( isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - cohere_encoder.client.embed.assert_called_once() + cohere_encoder._client.embed.assert_called_once() def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker): mock_embed = mocker.MagicMock() mock_embed.embeddings = [[0.1, 0.2, 0.3]] - cohere_encoder.client.embed.return_value = mock_embed + cohere_encoder._client.embed.return_value = mock_embed result = cohere_encoder(["test"]) assert isinstance(result, list), "Result should be a list" assert all( isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - cohere_encoder.client.embed.assert_called_once() + cohere_encoder._client.embed.assert_called_once() def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker): mock_embed = mocker.MagicMock() mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - cohere_encoder.client.embed.return_value = mock_embed + cohere_encoder._client.embed.return_value = mock_embed result = cohere_encoder(["test1", "test2"]) assert isinstance(result, list), "Result should be a list" assert all( isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - cohere_encoder.client.embed.assert_called_once() + cohere_encoder._client.embed.assert_called_once() def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) @@ -79,7 +79,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker): mocker.patch.object( - cohere_encoder.client, "embed", side_effect=Exception("API call failed") + cohere_encoder._client, "embed", side_effect=Exception("API call failed") ) with pytest.raises(ValueError): cohere_encoder(["test"]) diff --git a/tests/unit/llms/test_llm_cohere.py b/tests/unit/llms/test_llm_cohere.py index aaf8a7e5..dc72931c 100644 --- a/tests/unit/llms/test_llm_cohere.py +++ b/tests/unit/llms/test_llm_cohere.py @@ -12,7 +12,7 @@ def cohere_llm(mocker): class TestCohereLLM: def test_initialization_with_api_key(self, cohere_llm): - assert cohere_llm.client is not None, "Client should be initialized" + assert cohere_llm._client is not None, "Client should be initialized" assert cohere_llm.name == "command", "Default name not set correctly" def test_initialization_without_api_key(self, mocker, monkeypatch): @@ -24,12 +24,12 @@ def test_initialization_without_api_key(self, mocker, monkeypatch): def test_call_method(self, cohere_llm, mocker): mock_llm = mocker.MagicMock() mock_llm.text = "test" - cohere_llm.client.chat.return_value = mock_llm + cohere_llm._client.chat.return_value = mock_llm llm_input = [Message(role="user", content="test")] result = cohere_llm(llm_input) assert isinstance(result, str), "Result should be a str" - cohere_llm.client.chat.assert_called_once() + cohere_llm._client.chat.assert_called_once() def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): mocker.patch( @@ -46,7 +46,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker): mocker.patch.object( - cohere_llm.client, "__call__", side_effect=Exception("API call failed") + cohere_llm._client, "__call__", side_effect=Exception("API call failed") ) with pytest.raises(ValueError): cohere_llm("test")