From 14f33edf5b9469bce6b049f058a754529ce5e19c Mon Sep 17 00:00:00 2001 From: whimo Date: Sat, 31 Aug 2024 13:36:26 +0400 Subject: [PATCH 1/5] Support Replicate-hosted LLMs --- motleycrew/common/enums.py | 3 ++- motleycrew/common/llms.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/motleycrew/common/enums.py b/motleycrew/common/enums.py index 1fa438c4..e5533d56 100644 --- a/motleycrew/common/enums.py +++ b/motleycrew/common/enums.py @@ -4,8 +4,9 @@ class LLMFamily: OPENAI = "openai" ANTHROPIC = "anthropic" + REPLICATE = "replicate" - ALL = {OPENAI, ANTHROPIC} + ALL = {OPENAI, ANTHROPIC, REPLICATE} class LLMFramework: diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index cef8a242..da9efd9f 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -74,11 +74,49 @@ def llama_index_anthropic_llm( return Anthropic(model=llm_name, temperature=llm_temperature, **kwargs) +def langchain_replicate_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize a Replicate LLM client for use with Langchain. + + Args: + llm_name: Name of the LLM in Replicate API. + llm_temperature: Temperature for the LLM. + """ + from langchain_community.llms import Replicate + + model_kwargs = kwargs.pop("model_kwargs", {}) + model_kwargs["temperature"] = llm_temperature + + return Replicate(model=llm_name, model_kwargs=model_kwargs, **kwargs) + + +def llama_index_replicate_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize a Replicate LLM client for use with LlamaIndex. + + Args: + llm_name: Name of the LLM in Replicate API. + llm_temperature: Temperature for the LLM. + """ + ensure_module_is_installed("llama_index") + from llama_index.llms.replicate import Replicate + + return Replicate(model=llm_name, temperature=llm_temperature, **kwargs) + + Defaults.LLM_MAP = { (LLMFramework.LANGCHAIN, LLMFamily.OPENAI): langchain_openai_llm, (LLMFramework.LLAMA_INDEX, LLMFamily.OPENAI): llama_index_openai_llm, (LLMFramework.LANGCHAIN, LLMFamily.ANTHROPIC): langchain_anthropic_llm, (LLMFramework.LLAMA_INDEX, LLMFamily.ANTHROPIC): llama_index_anthropic_llm, + (LLMFramework.LANGCHAIN, LLMFamily.REPLICATE): langchain_replicate_llm, + (LLMFramework.LLAMA_INDEX, LLMFamily.REPLICATE): llama_index_replicate_llm, } From ab38c7b1c77cab6847a3ba9b4813d980ecde8ee3 Mon Sep 17 00:00:00 2001 From: whimo Date: Mon, 2 Sep 2024 13:18:52 +0400 Subject: [PATCH 2/5] Add support for Together, Groq, and Ollama LLMs in Langchain and LlamaIndex integration. 
--- motleycrew/common/enums.py | 5 +- motleycrew/common/llms.py | 105 +++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/motleycrew/common/enums.py b/motleycrew/common/enums.py index e5533d56..e5a63657 100644 --- a/motleycrew/common/enums.py +++ b/motleycrew/common/enums.py @@ -5,8 +5,11 @@ class LLMFamily: OPENAI = "openai" ANTHROPIC = "anthropic" REPLICATE = "replicate" + TOGETHER = "together" + GROQ = "groq" + OLLAMA = "ollama" - ALL = {OPENAI, ANTHROPIC, REPLICATE} + ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA} class LLMFramework: diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index da9efd9f..d7d925fd 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -110,6 +110,105 @@ def llama_index_replicate_llm( return Replicate(model=llm_name, temperature=llm_temperature, **kwargs) +def langchain_together_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize a Together LLM client for use with Langchain. + + Args: + llm_name: Name of the LLM in Together API. + llm_temperature: Temperature for the LLM. + """ + from langchain_together import ChatTogether + + return ChatTogether(model=llm_name, temperature=llm_temperature, **kwargs) + + +def llama_index_together_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize a Together LLM client for use with LlamaIndex. + + Args: + llm_name: Name of the LLM in Together API. + llm_temperature: Temperature for the LLM. + """ + ensure_module_is_installed("llama_index") + from llama_index.llms.together import TogetherLLM + + return TogetherLLM(model=llm_name, temperature=llm_temperature, **kwargs) + + +def langchain_groq_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize a Groq LLM client for use with Langchain. + + Args: + llm_name: Name of the LLM in Groq API. + llm_temperature: Temperature for the LLM. + """ + from langchain_groq import ChatGroq + + return ChatGroq(model=llm_name, temperature=llm_temperature, **kwargs) + + +def llama_index_groq_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize a Groq LLM client for use with LlamaIndex. + + Args: + llm_name: Name of the LLM in Groq API. + llm_temperature: Temperature for the LLM. + """ + ensure_module_is_installed("llama_index") + from llama_index.llms.groq import Groq + + return Groq(model=llm_name, temperature=llm_temperature, **kwargs) + + +def langchain_ollama_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize an Ollama LLM client for use with Langchain. + + Args: + llm_name: Name of the LLM in Ollama API. + llm_temperature: Temperature for the LLM. + """ + from langchain_ollama.llms import OllamaLLM + + return OllamaLLM(model=llm_name, temperature=llm_temperature, **kwargs) + + +def llama_index_ollama_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize an Ollama LLM client for use with LlamaIndex. + + Args: + llm_name: Name of the LLM in Ollama API. + llm_temperature: Temperature for the LLM. 
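+
+    Note:
+        This assumes an Ollama server is reachable locally and that the
+        requested model has already been pulled.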
+ """ + ensure_module_is_installed("llama_index") + from llama_index.llms.ollama import Ollama + + return Ollama(model=llm_name, temperature=llm_temperature, **kwargs) + + Defaults.LLM_MAP = { (LLMFramework.LANGCHAIN, LLMFamily.OPENAI): langchain_openai_llm, (LLMFramework.LLAMA_INDEX, LLMFamily.OPENAI): llama_index_openai_llm, @@ -117,6 +216,12 @@ def llama_index_replicate_llm( (LLMFramework.LLAMA_INDEX, LLMFamily.ANTHROPIC): llama_index_anthropic_llm, (LLMFramework.LANGCHAIN, LLMFamily.REPLICATE): langchain_replicate_llm, (LLMFramework.LLAMA_INDEX, LLMFamily.REPLICATE): llama_index_replicate_llm, + (LLMFramework.LANGCHAIN, LLMFamily.TOGETHER): langchain_together_llm, + (LLMFramework.LLAMA_INDEX, LLMFamily.TOGETHER): llama_index_together_llm, + (LLMFramework.LANGCHAIN, LLMFamily.GROQ): langchain_groq_llm, + (LLMFramework.LLAMA_INDEX, LLMFamily.GROQ): llama_index_groq_llm, + (LLMFramework.LANGCHAIN, LLMFamily.OLLAMA): langchain_ollama_llm, + (LLMFramework.LLAMA_INDEX, LLMFamily.OLLAMA): llama_index_ollama_llm, } From 1b64eea399b47ea3678f0ebe84be226f33dd2fff Mon Sep 17 00:00:00 2001 From: whimo Date: Mon, 2 Sep 2024 14:04:24 +0400 Subject: [PATCH 3/5] Replace `OllamaLLM` with `ChatOllama` in `langchain_ollama_llm` function. --- motleycrew/common/llms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index d7d925fd..cf7d6aa6 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -187,9 +187,9 @@ def langchain_ollama_llm( llm_name: Name of the LLM in Ollama API. llm_temperature: Temperature for the LLM. """ - from langchain_ollama.llms import OllamaLLM + from langchain_ollama.llms import ChatOllama - return OllamaLLM(model=llm_name, temperature=llm_temperature, **kwargs) + return ChatOllama(model=llm_name, temperature=llm_temperature, **kwargs) def llama_index_ollama_llm( From b3d370983a373dec45197502c6821f2c11822b3b Mon Sep 17 00:00:00 2001 From: whimo Date: Mon, 2 Sep 2024 14:06:02 +0400 Subject: [PATCH 4/5] Fix import path for ChatOllama in langchain_ollama_llm function. --- motleycrew/common/llms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index cf7d6aa6..15e0d9e6 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -187,7 +187,7 @@ def langchain_ollama_llm( llm_name: Name of the LLM in Ollama API. llm_temperature: Temperature for the LLM. 
""" - from langchain_ollama.llms import ChatOllama + from langchain_ollama.chat_models import ChatOllama return ChatOllama(model=llm_name, temperature=llm_temperature, **kwargs) From 5c2a59c7c827c98f7a6c9a4672769c2c3f187903 Mon Sep 17 00:00:00 2001 From: whimo Date: Fri, 6 Sep 2024 13:14:10 +0400 Subject: [PATCH 5/5] LLM-related docs & refactoring --- docs/source/another_llm.nblink | 3 - docs/source/choosing_llms.rst | 108 ++++++++++++++++++++++ docs/source/usage.rst | 2 +- examples/Using another LLM.ipynb | 152 ------------------------------- motleycrew/common/__init__.py | 6 +- motleycrew/common/defaults.py | 5 +- motleycrew/common/enums.py | 2 +- motleycrew/common/exceptions.py | 10 +- motleycrew/common/llms.py | 38 ++++---- tests/test_agents/test_llms.py | 18 ++-- 10 files changed, 147 insertions(+), 197 deletions(-) delete mode 100644 docs/source/another_llm.nblink create mode 100644 docs/source/choosing_llms.rst delete mode 100644 examples/Using another LLM.ipynb diff --git a/docs/source/another_llm.nblink b/docs/source/another_llm.nblink deleted file mode 100644 index 8b31a621..00000000 --- a/docs/source/another_llm.nblink +++ /dev/null @@ -1,3 +0,0 @@ -{ - "path": "../../examples/Using another LLM.ipynb" -} \ No newline at end of file diff --git a/docs/source/choosing_llms.rst b/docs/source/choosing_llms.rst new file mode 100644 index 00000000..930d5b97 --- /dev/null +++ b/docs/source/choosing_llms.rst @@ -0,0 +1,108 @@ +Choosing LLMs +==================== + +Generally, the interaction with an LLM is up to the agent implementation. +However, as motleycrew integrates with several agent frameworks, there is some common ground for how to choose LLMs. + + +Providing an LLM to an agent +---------------------------- + +In general, you can pass a specific LLM to the agent you're using. + +.. code-block:: python + + from motleycrew.agents.langchain import ReActToolCallingMotleyAgent + from langchain_openai import ChatOpenAI + + llm = ChatOpenAI(model="gpt-4o", temperature=0) + agent = ReActToolCallingMotleyAgent(llm=llm, tools=[...]) + + +The LLM class depends on the agent framework you're using. +That's why we have an ``init_llm`` function to help you set up the LLM. + +.. code-block:: python + + from motleycrew.common.llms import init_llm + from motleycrew.common import LLMFramework, LLMProvider + + llm = init_llm( + llm_framework=LLMFramework.LANGCHAIN, + llm_provider=LLMProvider.ANTHROPIC, + llm_name="claude-3-5-sonnet-20240620", + llm_temperature=0 + ) + agent = ReActToolCallingMotleyAgent(llm=llm, tools=[...]) + + +The currently supported frameworks (:py:class:`motleycrew.common.enums.LLMFramework`) are: + +- :py:class:`Langchain ` for Langchain-based agents from Langchain, CrewAI, motelycrew etc. +- :py:class:`LlamaIndex ` for LlamaIndex-based agents. + +The currently supported LLM providers (:py:class:`motleycrew.common.enums.LLMProvider`) are: + +- :py:class:`OpenAI ` +- :py:class:`Anthropic ` +- :py:class:`Groq ` +- :py:class:`Together ` +- :py:class:`Replicate ` +- :py:class:`Ollama ` + +Please raise an issue if you need to add support for another LLM provider. + + +Default LLM +----------- + +At present, we default to OpenAI's latest ``gpt-4o`` model for our agents, +and rely on the user to set the `OPENAI_API_KEY` environment variable. + +You can control the default LLM as follows: + +.. 
+
+    from motleycrew.common import Defaults
+    Defaults.DEFAULT_LLM_PROVIDER = "the_new_default_LLM_provider"
+    Defaults.DEFAULT_LLM_NAME = "name_of_the_new_default_model_from_the_provider"
+
+
+Using custom LLMs
+-----------------
+
+To use a custom LLM provider as the default or via the ``init_llm`` function,
+you need to make sure that for all the frameworks you're using (currently at most Langchain and LlamaIndex),
+the ``LLM_MAP`` has an entry for the LLM provider, for example as follows:
+
+.. code-block:: python
+
+    from motleycrew.common import LLMFramework
+    from motleycrew.common.llms import LLM_MAP
+
+    LLM_MAP[(LLMFramework.LANGCHAIN, "MyLLMProvider")] = my_langchain_llm_factory
+    LLM_MAP[(LLMFramework.LLAMA_INDEX, "MyLLMProvider")] = my_llamaindex_llm_factory
+
+Here each LLM factory is a function with a signature
+``def llm_factory(llm_name: str, llm_temperature: float, **kwargs)`` that returns the model object for the relevant framework.
+
+For example, this is the built-in OpenAI model factory for Langchain:
+
+.. code-block:: python
+
+    def langchain_openai_llm(
+        llm_name: str = Defaults.DEFAULT_LLM_NAME,
+        llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,
+        **kwargs,
+    ):
+        from langchain_openai import ChatOpenAI
+
+        return ChatOpenAI(model=llm_name, temperature=llm_temperature, **kwargs)
+
+
+You can also overwrite the ``LLM_MAP`` values for e.g. the OpenAI models if, for example,
+you want to use an in-house wrapper for Langchain or LlamaIndex model adapters
+(for example, to use an internal gateway instead of directly hitting the OpenAI endpoints);
+a sketch of this is shown at the end of this page.
+
+Note that at present, if you use Autogen with motleycrew, you will need to separately control
+the models that Autogen uses, using the Autogen-specific APIs.
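+
+As an illustration of such an override, here is a minimal sketch for Langchain and OpenAI.
+The gateway URL environment variable is a made-up placeholder for whatever your
+OpenAI-compatible gateway exposes:
+
+.. code-block:: python
+
+    import os
+
+    from langchain_openai import ChatOpenAI
+
+    from motleycrew.common import Defaults, LLMFramework, LLMProvider
+    from motleycrew.common.llms import LLM_MAP
+
+    def gateway_openai_llm(
+        llm_name: str = Defaults.DEFAULT_LLM_NAME,
+        llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,
+        **kwargs,
+    ):
+        # Same factory signature as the built-in ones, but all requests go
+        # through an internal OpenAI-compatible gateway instead of api.openai.com.
+        return ChatOpenAI(
+            model=llm_name,
+            temperature=llm_temperature,
+            base_url=os.environ["INTERNAL_LLM_GATEWAY_URL"],  # hypothetical gateway
+            **kwargs,
+        )
+
+    # From here on, default OpenAI calls made via Langchain use the gateway.
+    LLM_MAP[(LLMFramework.LANGCHAIN, LLMProvider.OPENAI)] = gateway_openai_llm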
diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 904582a7..b3f86c34 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -9,5 +9,5 @@ Usage key_concepts agents knowledge_graph - another_llm + choosing_llms caching_observability diff --git a/examples/Using another LLM.ipynb b/examples/Using another LLM.ipynb deleted file mode 100644 index b95effd9..00000000 --- a/examples/Using another LLM.ipynb +++ /dev/null @@ -1,152 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "84dd9af7-ba68-429e-b9b2-eb2ab8ee54fc", - "metadata": {}, - "source": [ - "# Using another LLM" - ] - }, - { - "cell_type": "markdown", - "id": "860c41f2-d1aa-4cc9-a92e-bec59043c5db", - "metadata": {}, - "source": [ - "At present, we default to using OpenAI models, and rely on the user to set the `OPENAI_API_KEY` environment variable.\n", - "\n", - "You can control the default LLM as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fe153839-f056-4159-98d4-4f366184e78f", - "metadata": {}, - "outputs": [], - "source": [ - "from motleycrew.common import Defaults\n", - "Defaults.DEFAULT_LLM_FAMILY = the_new_default_LLM_family\n", - "Defaults.DEFAULT_LLM_NAME = name_of_the_new_default_model_within_the_family" - ] - }, - { - "cell_type": "markdown", - "id": "a9e7dad1-490e-4912-9b05-d2cbad00d1f1", - "metadata": {}, - "source": [ - "For this to work, you must make sure that for all the frameworks you're using (currently at most Langchain, LlamaIndex), the `LLM_MAP` has an entry for the new default LLM family, for example as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10fc3846-d14c-424a-a61a-f956dff2dbd4", - "metadata": {}, - "outputs": [], - "source": [ - "Defaults.LLM_MAP[ (LLMFramework.LANGCHAIN, \"MyLLMFamily\") ] = my_langchain_llm_factory\n", - "Defaults.LLM_MAP[ (LLMFramework.LLAMA_INDEX, \"MyLLMFamily\") ] = my_llamaindex_llm_factory" - ] - }, - { - "cell_type": "markdown", - "id": "27e97371-5f3f-42f1-a321-363d62d42080", - "metadata": {}, - "source": [ - "Here each llm factory is a function with a signature ```def llm_factory(llm_name: str, llm_temperature: float, **kwargs)``` that returns the model object for the relevant framework. 
\n", - "\n", - "For example, this is the built-in OpenAI model factory for Langchain:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a2ea0a52-a4c9-4c14-b744-1078ff48012b", - "metadata": {}, - "outputs": [], - "source": [ - "def langchain_openai_llm(\n", - " llm_name: str = Defaults.DEFAULT_LLM_NAME,\n", - " llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,\n", - " **kwargs,\n", - "):\n", - " from langchain_openai import ChatOpenAI\n", - "\n", - " return ChatOpenAI(model=llm_name, temperature=llm_temperature, **kwargs)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "84f61ad1-96d0-410d-9349-48addaf85aa9", - "metadata": {}, - "source": [ - "and here is the one for OpenAI and LlamaIndex:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17067877-d99c-483c-a23b-a738790ea45a", - "metadata": {}, - "outputs": [], - "source": [ - "def llama_index_openai_llm(\n", - " llm_name: str = Defaults.DEFAULT_LLM_NAME,\n", - " llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,\n", - " **kwargs,\n", - "):\n", - " ensure_module_is_installed(\"llama_index\")\n", - " from llama_index.llms.openai import OpenAI\n", - "\n", - " return OpenAI(model=llm_name, temperature=llm_temperature, **kwargs)" - ] - }, - { - "cell_type": "markdown", - "id": "6b36baf8-5491-49fa-8ea3-0c33501af932", - "metadata": {}, - "source": [ - "You can also overwrite the `LLM_MAP` values for e.g. the OpenAI models if, for example, you want to use an in-house wrapper for Langchain or Llamaindex model adapters (for example, to use an internal gateway instead of directly hitting the OpenAI endpoints) " - ] - }, - { - "cell_type": "markdown", - "id": "9884f013-63f0-44a5-936b-6bc81909b94e", - "metadata": {}, - "source": [ - "Note that at present, if you use Autogen with motleycrew, you will need to separately control the models that Autogen uses, using the Autogen-specific APIs." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9883c816-4f22-4ee9-87db-53ab7a802571", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/motleycrew/common/__init__.py b/motleycrew/common/__init__.py index 3011bc95..b1d9f4d8 100644 --- a/motleycrew/common/__init__.py +++ b/motleycrew/common/__init__.py @@ -3,14 +3,12 @@ from .defaults import Defaults from .enums import AsyncBackend from .enums import GraphStoreType -from .enums import LLMFamily from .enums import LLMFramework +from .enums import LLMProvider from .enums import LunaryEventName from .enums import LunaryRunType from .enums import TaskUnitStatus - from .logging import logger, configure_logging - from .types import MotleyAgentFactory from .types import MotleySupportedTool @@ -22,7 +20,7 @@ "configure_logging", "AsyncBackend", "GraphStoreType", - "LLMFamily", + "LLMProvider", "LLMFramework", "LunaryEventName", "LunaryRunType", diff --git a/motleycrew/common/defaults.py b/motleycrew/common/defaults.py index 5081336f..3618dd23 100644 --- a/motleycrew/common/defaults.py +++ b/motleycrew/common/defaults.py @@ -1,15 +1,14 @@ from motleycrew.common.enums import GraphStoreType -from motleycrew.common.enums import LLMFamily +from motleycrew.common.enums import LLMProvider class Defaults: """Default values for various settings.""" DEFAULT_REACT_AGENT_MAX_ITERATIONS = 15 - DEFAULT_LLM_FAMILY = LLMFamily.OPENAI + DEFAULT_LLM_PROVIDER = LLMProvider.OPENAI DEFAULT_LLM_NAME = "gpt-4o" DEFAULT_LLM_TEMPERATURE = 0.0 - LLM_MAP = {} DEFAULT_GRAPH_STORE_TYPE = GraphStoreType.KUZU diff --git a/motleycrew/common/enums.py b/motleycrew/common/enums.py index e5a63657..b4c0fbef 100644 --- a/motleycrew/common/enums.py +++ b/motleycrew/common/enums.py @@ -1,7 +1,7 @@ """Various enums used in the project.""" -class LLMFamily: +class LLMProvider: OPENAI = "openai" ANTHROPIC = "anthropic" REPLICATE = "replicate" diff --git a/motleycrew/common/exceptions.py b/motleycrew/common/exceptions.py index a02bc127..5d7dd432 100644 --- a/motleycrew/common/exceptions.py +++ b/motleycrew/common/exceptions.py @@ -5,15 +5,15 @@ from motleycrew.common import Defaults -class LLMFamilyNotSupported(Exception): - """Raised when an LLM family is not supported in motleycrew via a framework.""" +class LLMProviderNotSupported(Exception): + """Raised when an LLM provider is not supported in motleycrew via a framework.""" - def __init__(self, llm_framework: str, llm_family: str): + def __init__(self, llm_framework: str, llm_provider: str): self.llm_framework = llm_framework - self.llm_family = llm_family + self.llm_provider = llm_provider def __str__(self) -> str: - return f"LLM family `{self.llm_family}` is not supported via the framework `{self.llm_framework}`" + return f"LLM provider `{self.llm_provider}` is not supported via the framework `{self.llm_framework}`" class LLMFrameworkNotSupported(Exception): diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index 15e0d9e6..6f4765b7 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -1,8 +1,8 @@ """Helper functions to initialize 
Language Models (LLMs) from different frameworks.""" from motleycrew.common import Defaults -from motleycrew.common import LLMFamily, LLMFramework -from motleycrew.common.exceptions import LLMFamilyNotSupported, LLMFrameworkNotSupported +from motleycrew.common import LLMProvider, LLMFramework +from motleycrew.common.exceptions import LLMProviderNotSupported from motleycrew.common.utils import ensure_module_is_installed @@ -209,25 +209,25 @@ def llama_index_ollama_llm( return Ollama(model=llm_name, temperature=llm_temperature, **kwargs) -Defaults.LLM_MAP = { - (LLMFramework.LANGCHAIN, LLMFamily.OPENAI): langchain_openai_llm, - (LLMFramework.LLAMA_INDEX, LLMFamily.OPENAI): llama_index_openai_llm, - (LLMFramework.LANGCHAIN, LLMFamily.ANTHROPIC): langchain_anthropic_llm, - (LLMFramework.LLAMA_INDEX, LLMFamily.ANTHROPIC): llama_index_anthropic_llm, - (LLMFramework.LANGCHAIN, LLMFamily.REPLICATE): langchain_replicate_llm, - (LLMFramework.LLAMA_INDEX, LLMFamily.REPLICATE): llama_index_replicate_llm, - (LLMFramework.LANGCHAIN, LLMFamily.TOGETHER): langchain_together_llm, - (LLMFramework.LLAMA_INDEX, LLMFamily.TOGETHER): llama_index_together_llm, - (LLMFramework.LANGCHAIN, LLMFamily.GROQ): langchain_groq_llm, - (LLMFramework.LLAMA_INDEX, LLMFamily.GROQ): llama_index_groq_llm, - (LLMFramework.LANGCHAIN, LLMFamily.OLLAMA): langchain_ollama_llm, - (LLMFramework.LLAMA_INDEX, LLMFamily.OLLAMA): llama_index_ollama_llm, +LLM_MAP = { + (LLMFramework.LANGCHAIN, LLMProvider.OPENAI): langchain_openai_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.OPENAI): llama_index_openai_llm, + (LLMFramework.LANGCHAIN, LLMProvider.ANTHROPIC): langchain_anthropic_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.ANTHROPIC): llama_index_anthropic_llm, + (LLMFramework.LANGCHAIN, LLMProvider.REPLICATE): langchain_replicate_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.REPLICATE): llama_index_replicate_llm, + (LLMFramework.LANGCHAIN, LLMProvider.TOGETHER): langchain_together_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.TOGETHER): llama_index_together_llm, + (LLMFramework.LANGCHAIN, LLMProvider.GROQ): langchain_groq_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.GROQ): llama_index_groq_llm, + (LLMFramework.LANGCHAIN, LLMProvider.OLLAMA): langchain_ollama_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.OLLAMA): llama_index_ollama_llm, } def init_llm( llm_framework: str, - llm_family: str = Defaults.DEFAULT_LLM_FAMILY, + llm_provider: str = Defaults.DEFAULT_LLM_PROVIDER, llm_name: str = Defaults.DEFAULT_LLM_NAME, llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, **kwargs, @@ -236,13 +236,13 @@ def init_llm( Args: llm_framework: Framework of the LLM client. - llm_family: Family of the LLM. + llm_provider: Provider of the LLM. llm_name: Name of the LLM. llm_temperature: Temperature for the LLM. 
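+        **kwargs: Additional keyword arguments to pass to the LLM factory.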
""" - func = Defaults.LLM_MAP.get((llm_framework, llm_family), None) + func = LLM_MAP.get((llm_framework, llm_provider), None) if func is not None: return func(llm_name=llm_name, llm_temperature=llm_temperature, **kwargs) - raise LLMFamilyNotSupported(llm_framework=llm_framework, llm_family=llm_family) + raise LLMProviderNotSupported(llm_framework=llm_framework, llm_provider=llm_provider) diff --git a/tests/test_agents/test_llms.py b/tests/test_agents/test_llms.py index 49a1d215..3ef0babd 100644 --- a/tests/test_agents/test_llms.py +++ b/tests/test_agents/test_llms.py @@ -2,23 +2,23 @@ from langchain_openai import ChatOpenAI from llama_index.llms.openai import OpenAI -from motleycrew.common import LLMFamily, LLMFramework -from motleycrew.common.exceptions import LLMFamilyNotSupported +from motleycrew.common import LLMProvider, LLMFramework +from motleycrew.common.exceptions import LLMProviderNotSupported from motleycrew.common.llms import init_llm @pytest.mark.parametrize( - "llm_family, llm_framework, expected_class", + "llm_provider, llm_framework, expected_class", [ - (LLMFamily.OPENAI, LLMFramework.LANGCHAIN, ChatOpenAI), - (LLMFamily.OPENAI, LLMFramework.LLAMA_INDEX, OpenAI), + (LLMProvider.OPENAI, LLMFramework.LANGCHAIN, ChatOpenAI), + (LLMProvider.OPENAI, LLMFramework.LLAMA_INDEX, OpenAI), ], ) -def test_init_llm(llm_family, llm_framework, expected_class): - llm = init_llm(llm_family=llm_family, llm_framework=llm_framework) +def test_init_llm(llm_provider, llm_framework, expected_class): + llm = init_llm(llm_provider=llm_provider, llm_framework=llm_framework) assert isinstance(llm, expected_class) def test_raise_init_llm(): - with pytest.raises(LLMFamilyNotSupported): - llm = init_llm(llm_family=LLMFamily.OPENAI, llm_framework="unknown_framework") + with pytest.raises(LLMProviderNotSupported): + llm = init_llm(llm_provider=LLMProvider.OPENAI, llm_framework="unknown_framework")