diff --git a/motleycrew/applications/research_agent/question_prioritizer.py b/motleycrew/applications/research_agent/question_prioritizer.py index 91f057f7..e98c6b00 100644 --- a/motleycrew/applications/research_agent/question_prioritizer.py +++ b/motleycrew/applications/research_agent/question_prioritizer.py @@ -1,7 +1,10 @@ +from typing import Optional + from langchain.prompts import PromptTemplate from langchain_core.prompts.base import BasePromptTemplate from langchain_core.runnables import RunnableLambda, RunnablePassthrough, chain from langchain_core.tools import StructuredTool +from langchain_core.language_models import BaseLanguageModel from pydantic import BaseModel, Field from motleycrew.applications.research_agent.question import Question @@ -16,8 +19,9 @@ class QuestionPrioritizerTool(MotleyTool): def __init__( self, prompt: str | BasePromptTemplate = None, + llm: Optional[BaseLanguageModel] = None, ): - langchain_tool = create_question_prioritizer_langchain_tool(prompt=prompt) + langchain_tool = create_question_prioritizer_langchain_tool(prompt=prompt, llm=llm) super().__init__(langchain_tool) @@ -47,6 +51,7 @@ class QuestionPrioritizerInput(BaseModel, arbitrary_types_allowed=True): def create_question_prioritizer_langchain_tool( prompt: str | BasePromptTemplate = None, + llm: Optional[BaseLanguageModel] = None, ) -> StructuredTool: if prompt is None: prompt = _default_prompt @@ -56,6 +61,7 @@ def create_question_prioritizer_langchain_tool( name="Question prioritizer", description="Takes the original question and a list of derived questions, " "and selects from the latter the one most pertinent to the former", + llm=llm, ) @chain diff --git a/motleycrew/applications/research_agent/question_task.py b/motleycrew/applications/research_agent/question_task.py index 039703a3..78c39ae8 100644 --- a/motleycrew/applications/research_agent/question_task.py +++ b/motleycrew/applications/research_agent/question_task.py @@ -41,7 +41,7 @@ def __init__( self.n_iter = 0 self.question = Question(question=question) self.graph_store.insert_node(self.question) - self.question_prioritization_tool = QuestionPrioritizerTool() + self.question_prioritization_tool = QuestionPrioritizerTool(llm=llm) self.question_generation_tool = QuestionGeneratorTool( query_tool=query_tool, graph=self.graph_store, llm=llm ) diff --git a/motleycrew/common/enums.py b/motleycrew/common/enums.py index e945c2ee..0f5c6a67 100644 --- a/motleycrew/common/enums.py +++ b/motleycrew/common/enums.py @@ -8,8 +8,9 @@ class LLMProvider: TOGETHER = "together" GROQ = "groq" OLLAMA = "ollama" + AZURE_OPENAI = "azure_openai" - ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA} + ALL = {OPENAI, ANTHROPIC, REPLICATE, TOGETHER, GROQ, OLLAMA, AZURE_OPENAI} class LLMFramework: diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index 6f4765b7..a7b06fe2 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -1,7 +1,6 @@ """Helper functions to initialize Language Models (LLMs) from different frameworks.""" -from motleycrew.common import Defaults -from motleycrew.common import LLMProvider, LLMFramework +from motleycrew.common import Defaults, LLMFramework, LLMProvider from motleycrew.common.exceptions import LLMProviderNotSupported from motleycrew.common.utils import ensure_module_is_installed @@ -209,6 +208,47 @@ def llama_index_ollama_llm( return Ollama(model=llm_name, temperature=llm_temperature, **kwargs) +def langchain_azure_openai_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize an Azure OpenAI LLM client for use with Langchain. + + Args: + llm_name: Name of the LLM in Azure OpenAI API. + llm_temperature: Temperature for the LLM. + """ + from langchain_openai import AzureChatOpenAI + + return AzureChatOpenAI(model=llm_name, temperature=llm_temperature, **kwargs) + + +def llama_index_azure_openai_llm( + llm_name: str = Defaults.DEFAULT_LLM_NAME, + llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, + **kwargs, +): + """Initialize an Azure OpenAI LLM client for use with LlamaIndex. + + Args: + llm_name: Name of the LLM in Azure OpenAI API. + llm_temperature: Temperature for the LLM. + """ + ensure_module_is_installed("llama_index") + from llama_index.llms.azure_openai import AzureOpenAI + + if "azure_deployment" in kwargs: + kwargs["engine"] = kwargs.pop("azure_deployment") + + if "engine" not in kwargs: + raise ValueError( + "For using Azure OpenAI with LlamaIndex, you must specify an engine/deployment name." + ) + + return AzureOpenAI(model=llm_name, temperature=llm_temperature, **kwargs) + + LLM_MAP = { (LLMFramework.LANGCHAIN, LLMProvider.OPENAI): langchain_openai_llm, (LLMFramework.LLAMA_INDEX, LLMProvider.OPENAI): llama_index_openai_llm, @@ -222,6 +262,8 @@ def llama_index_ollama_llm( (LLMFramework.LLAMA_INDEX, LLMProvider.GROQ): llama_index_groq_llm, (LLMFramework.LANGCHAIN, LLMProvider.OLLAMA): langchain_ollama_llm, (LLMFramework.LLAMA_INDEX, LLMProvider.OLLAMA): llama_index_ollama_llm, + (LLMFramework.LANGCHAIN, LLMProvider.AZURE_OPENAI): langchain_azure_openai_llm, + (LLMFramework.LLAMA_INDEX, LLMProvider.AZURE_OPENAI): llama_index_azure_openai_llm, } diff --git a/motleycrew/tools/simple_retriever_tool.py b/motleycrew/tools/simple_retriever_tool.py index 45652dfe..39969429 100644 --- a/motleycrew/tools/simple_retriever_tool.py +++ b/motleycrew/tools/simple_retriever_tool.py @@ -65,14 +65,16 @@ def make_retriever_langchain_tool( # load the documents and create the index documents = SimpleDirectoryReader(data_dir).load_data() index = VectorStoreIndex.from_documents( - documents, transformations=[SentenceSplitter(chunk_size=512), embeddings] + documents, + transformations=[SentenceSplitter(chunk_size=512), embeddings], + embed_model=embeddings, ) # store it for later index.storage_context.persist(persist_dir=persist_dir) else: # load the existing index storage_context = StorageContext.from_defaults(persist_dir=persist_dir) - index = load_index_from_storage(storage_context) + index = load_index_from_storage(storage_context, embed_model=embeddings) retriever = index.as_retriever( similarity_top_k=10,