From 9590841fe58388173840b657d05819811c8d0ed9 Mon Sep 17 00:00:00 2001
From: whimo
Date: Fri, 25 Oct 2024 11:48:02 +0300
Subject: [PATCH] Support AzureOpenAI LLMs

---
 motleycrew/common/enums.py |  3 ++-
 motleycrew/common/llms.py  | 46 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/motleycrew/common/enums.py b/motleycrew/common/enums.py
index e945c2ee..0f5c6a67 100644
--- a/motleycrew/common/enums.py
+++ b/motleycrew/common/enums.py
@@ -8,8 +8,9 @@ class LLMProvider:
     TOGETHER = "together"
     GROQ = "groq"
     OLLAMA = "ollama"
+    AZURE_OPENAI = "azure_openai"

-    ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA}
+    ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA, AZURE_OPENAI}


 class LLMFramework:
diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py
index 6f4765b7..a7b06fe2 100644
--- a/motleycrew/common/llms.py
+++ b/motleycrew/common/llms.py
@@ -1,7 +1,6 @@
 """Helper functions to initialize Language Models (LLMs) from different frameworks."""

-from motleycrew.common import Defaults
-from motleycrew.common import LLMProvider, LLMFramework
+from motleycrew.common import Defaults, LLMFramework, LLMProvider
 from motleycrew.common.exceptions import LLMProviderNotSupported
 from motleycrew.common.utils import ensure_module_is_installed

@@ -209,6 +208,47 @@ def llama_index_ollama_llm(
     return Ollama(model=llm_name, temperature=llm_temperature, **kwargs)


+def langchain_azure_openai_llm(
+    llm_name: str = Defaults.DEFAULT_LLM_NAME,
+    llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,
+    **kwargs,
+):
+    """Initialize an Azure OpenAI LLM client for use with Langchain.
+
+    Args:
+        llm_name: Name of the LLM in Azure OpenAI API.
+        llm_temperature: Temperature for the LLM.
+    """
+    from langchain_openai import AzureChatOpenAI
+
+    return AzureChatOpenAI(model=llm_name, temperature=llm_temperature, **kwargs)
+
+
+def llama_index_azure_openai_llm(
+    llm_name: str = Defaults.DEFAULT_LLM_NAME,
+    llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,
+    **kwargs,
+):
+    """Initialize an Azure OpenAI LLM client for use with LlamaIndex.
+
+    Args:
+        llm_name: Name of the LLM in Azure OpenAI API.
+        llm_temperature: Temperature for the LLM.
+    """
+    ensure_module_is_installed("llama_index")
+    from llama_index.llms.azure_openai import AzureOpenAI
+
+    if "azure_deployment" in kwargs:
+        kwargs["engine"] = kwargs.pop("azure_deployment")
+
+    if "engine" not in kwargs:
+        raise ValueError(
+            "To use Azure OpenAI with LlamaIndex, you must specify an engine (deployment name)."
+        )
+
+    return AzureOpenAI(model=llm_name, temperature=llm_temperature, **kwargs)
+
+
 LLM_MAP = {
     (LLMFramework.LANGCHAIN, LLMProvider.OPENAI): langchain_openai_llm,
     (LLMFramework.LLAMA_INDEX, LLMProvider.OPENAI): llama_index_openai_llm,
@@ -222,6 +262,8 @@ def llama_index_ollama_llm(
     (LLMFramework.LLAMA_INDEX, LLMProvider.GROQ): llama_index_groq_llm,
     (LLMFramework.LANGCHAIN, LLMProvider.OLLAMA): langchain_ollama_llm,
     (LLMFramework.LLAMA_INDEX, LLMProvider.OLLAMA): llama_index_ollama_llm,
+    (LLMFramework.LANGCHAIN, LLMProvider.AZURE_OPENAI): langchain_azure_openai_llm,
+    (LLMFramework.LLAMA_INDEX, LLMProvider.AZURE_OPENAI): llama_index_azure_openai_llm,
 }
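
Not part of the patch: a brief usage sketch of the two new helpers, for reviewers.
It assumes Azure credentials are supplied the usual way for the underlying clients
(e.g. the AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT environment variables),
and the deployment name and API version below are placeholders.

    from motleycrew.common.llms import (
        langchain_azure_openai_llm,
        llama_index_azure_openai_llm,
    )

    # Langchain path: extra kwargs are forwarded as-is to AzureChatOpenAI.
    langchain_llm = langchain_azure_openai_llm(
        llm_name="gpt-4o",
        llm_temperature=0.0,
        azure_deployment="my-gpt-4o-deployment",  # placeholder deployment name
        api_version="2024-02-01",                 # placeholder API version
    )

    # LlamaIndex path: azure_deployment is remapped to the `engine` kwarg
    # expected by llama_index.llms.azure_openai.AzureOpenAI.
    llama_index_llm = llama_index_azure_openai_llm(
        llm_name="gpt-4o",
        llm_temperature=0.0,
        azure_deployment="my-gpt-4o-deployment",  # placeholder deployment name
        api_version="2024-02-01",                 # placeholder API version
    )

With the two LLM_MAP entries added by the patch, the new provider should also be
reachable through whatever factory dispatches on (framework, provider) pairs via
LLM_MAP, using LLMProvider.AZURE_OPENAI.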