Skip to content

Commit

Permalink
Support AzureOpenAI LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Oct 25, 2024
1 parent a9e6340 commit 9590841
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
3 changes: 2 additions & 1 deletion motleycrew/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ class LLMProvider:
TOGETHER = "together"
GROQ = "groq"
OLLAMA = "ollama"
AZURE_OPENAI = "azure_openai"

ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA}
ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA, AZURE_OPENAI}


class LLMFramework:
Expand Down
46 changes: 44 additions & 2 deletions motleycrew/common/llms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Helper functions to initialize Language Models (LLMs) from different frameworks."""

from motleycrew.common import Defaults
from motleycrew.common import LLMProvider, LLMFramework
from motleycrew.common import Defaults, LLMFramework, LLMProvider
from motleycrew.common.exceptions import LLMProviderNotSupported
from motleycrew.common.utils import ensure_module_is_installed

Expand Down Expand Up @@ -209,6 +208,47 @@ def llama_index_ollama_llm(
return Ollama(model=llm_name, temperature=llm_temperature, **kwargs)


def langchain_azure_openai_llm(
llm_name: str = Defaults.DEFAULT_LLM_NAME,
llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,
**kwargs,
):
"""Initialize an Azure OpenAI LLM client for use with Langchain.
Args:
llm_name: Name of the LLM in Azure OpenAI API.
llm_temperature: Temperature for the LLM.
"""
from langchain_openai import AzureChatOpenAI

return AzureChatOpenAI(model=llm_name, temperature=llm_temperature, **kwargs)


def llama_index_azure_openai_llm(
llm_name: str = Defaults.DEFAULT_LLM_NAME,
llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE,
**kwargs,
):
"""Initialize an Azure OpenAI LLM client for use with LlamaIndex.
Args:
llm_name: Name of the LLM in Azure OpenAI API.
llm_temperature: Temperature for the LLM.
"""
ensure_module_is_installed("llama_index")
from llama_index.llms.azure_openai import AzureOpenAI

if "azure_deployment" in kwargs:
kwargs["engine"] = kwargs.pop("azure_deployment")

if "engine" not in kwargs:
raise ValueError(
"For using Azure OpenAI with LlamaIndex, you must specify an engine/deployment name."
)

return AzureOpenAI(model=llm_name, temperature=llm_temperature, **kwargs)


LLM_MAP = {
(LLMFramework.LANGCHAIN, LLMProvider.OPENAI): langchain_openai_llm,
(LLMFramework.LLAMA_INDEX, LLMProvider.OPENAI): llama_index_openai_llm,
Expand All @@ -222,6 +262,8 @@ def llama_index_ollama_llm(
(LLMFramework.LLAMA_INDEX, LLMProvider.GROQ): llama_index_groq_llm,
(LLMFramework.LANGCHAIN, LLMProvider.OLLAMA): langchain_ollama_llm,
(LLMFramework.LLAMA_INDEX, LLMProvider.OLLAMA): llama_index_ollama_llm,
(LLMFramework.LANGCHAIN, LLMProvider.AZURE_OPENAI): langchain_azure_openai_llm,
(LLMFramework.LLAMA_INDEX, LLMProvider.AZURE_OPENAI): llama_index_azure_openai_llm,
}


Expand Down

0 comments on commit 9590841

Please sign in to comment.