Skip to content

Commit

Permalink
Support custom LLMs in research agent (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo authored Sep 27, 2024
1 parent 246c683 commit 260ff53
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 421 deletions.
619 changes: 219 additions & 400 deletions examples/Multi-step research agent.ipynb

Large diffs are not rendered by default.

26 changes: 17 additions & 9 deletions examples/research_agent/research_agent_main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from pathlib import Path
import shutil
import os
import platform
import shutil
from pathlib import Path

import kuzu
from dotenv import load_dotenv

from motleycrew import MotleyCrew
from motleycrew.storage import MotleyKuzuGraphStore
from motleycrew.common import configure_logging
from motleycrew.applications.research_agent.question_task import QuestionTask
from motleycrew.applications.research_agent.answer_task import AnswerTask

from motleycrew.applications.research_agent.question_task import QuestionTask
from motleycrew.common import LLMFramework, configure_logging
from motleycrew.common.llms import init_llm
from motleycrew.storage import MotleyKuzuGraphStore
from motleycrew.tools.simple_retriever_tool import SimpleRetrieverTool


WORKING_DIR = Path(__file__).parent
if "Dropbox" in WORKING_DIR.parts and platform.system() == "Windows":
# On Windows, kuzu has file locking issues with Dropbox
Expand All @@ -31,21 +30,30 @@


def main():
llm = init_llm(
llm_framework=LLMFramework.LANGCHAIN
) # throughout this project, we use LangChain's LLM wrappers

load_dotenv()
configure_logging(verbose=True)

shutil.rmtree(DB_PATH)

# You can pass any LlamaIndex embedding to the retriever tool, default is OpenAI's text-embedding-ada-002
query_tool = SimpleRetrieverTool(DATA_DIR, PERSIST_DIR, return_strings_only=True)

db = kuzu.Database(DB_PATH)
graph_store = MotleyKuzuGraphStore(db)
crew = MotleyCrew(graph_store=graph_store)

question_task = QuestionTask(
crew=crew, question=QUESTION, query_tool=query_tool, max_iter=MAX_ITER
crew=crew,
question=QUESTION,
query_tool=query_tool,
max_iter=MAX_ITER,
llm=llm,
)
answer_task = AnswerTask(answer_length=ANSWER_LENGTH, crew=crew)
answer_task = AnswerTask(answer_length=ANSWER_LENGTH, crew=crew, llm=llm)

question_task >> answer_task

Expand Down
4 changes: 3 additions & 1 deletion motleycrew/applications/research_agent/answer_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from langchain_core.runnables import Runnable
from langchain_core.language_models import BaseLanguageModel

from motleycrew.applications.research_agent.question import Question
from motleycrew.applications.research_agent.question_answerer import AnswerSubQuestionTool
Expand All @@ -21,6 +22,7 @@ def __init__(
self,
crew: MotleyCrew,
answer_length: int = 1000,
llm: Optional[BaseLanguageModel] = None,
):
super().__init__(
name="AnswerTask",
Expand All @@ -30,7 +32,7 @@ def __init__(
)
self.answer_length = answer_length
self.answerer = AnswerSubQuestionTool(
graph=self.graph_store, answer_length=self.answer_length
graph=self.graph_store, answer_length=self.answer_length, llm=llm
)

def get_next_unit(self) -> QuestionAnsweringTaskUnit | None:
Expand Down
15 changes: 9 additions & 6 deletions motleycrew/applications/research_agent/question_answerer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Optional

from langchain.prompts import PromptTemplate
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import (
RunnablePassthrough,
RunnableLambda,
chain,
)
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, chain
from langchain_core.tools import Tool

from motleycrew.applications.research_agent.question import Question
from motleycrew.common.utils import print_passthrough
from motleycrew.storage import MotleyGraphStore
from motleycrew.tools import MotleyTool, LLMTool
from motleycrew.tools import LLMTool, MotleyTool

_default_prompt = PromptTemplate.from_template(
"""
Expand All @@ -37,11 +36,13 @@ def __init__(
graph: MotleyGraphStore,
answer_length: int,
prompt: str | BasePromptTemplate = None,
llm: Optional[BaseLanguageModel] = None,
):
langchain_tool = create_answer_question_langchain_tool(
graph=graph,
answer_length=answer_length,
prompt=prompt,
llm=llm,
)

super().__init__(langchain_tool)
Expand Down Expand Up @@ -70,6 +71,7 @@ def create_answer_question_langchain_tool(
graph: MotleyGraphStore,
answer_length: int,
prompt: str | BasePromptTemplate = None,
llm: Optional[BaseLanguageModel] = None,
) -> Tool:
if prompt is None:
prompt = _default_prompt
Expand All @@ -78,6 +80,7 @@ def create_answer_question_langchain_tool(
prompt=prompt.partial(answer_length=str(answer_length)),
name="Question answerer",
description="Tool to answer a question from notes and sub-questions",
llm=llm,
)
"""
Gets a valid question node ID, question, and context as input dict
Expand Down
4 changes: 3 additions & 1 deletion motleycrew/applications/research_agent/question_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import Runnable

from motleycrew.common import logger
Expand All @@ -26,6 +27,7 @@ def __init__(
crew: MotleyCrew,
max_iter: int = 10,
allow_async_units: bool = False,
llm: Optional[BaseLanguageModel] = None,
name: str = "QuestionTask",
):
super().__init__(
Expand All @@ -41,7 +43,7 @@ def __init__(
self.graph_store.insert_node(self.question)
self.question_prioritization_tool = QuestionPrioritizerTool()
self.question_generation_tool = QuestionGeneratorTool(
query_tool=query_tool, graph=self.graph_store
query_tool=query_tool, graph=self.graph_store, llm=llm
)

def get_next_unit(self) -> QuestionGenerationTaskUnit | None:
Expand Down
16 changes: 12 additions & 4 deletions motleycrew/tools/simple_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

from motleycrew.applications.research_agent.question import Question
Expand All @@ -26,6 +27,7 @@ def __init__(
return_strings_only: bool = False,
return_direct: bool = False,
exceptions_to_reflect: Optional[List[Exception]] = None,
embeddings: Optional[BaseEmbedding] = None,
):
"""
Args:
Expand All @@ -34,7 +36,7 @@ def __init__(
return_strings_only: Whether to return only the text of the retrieved documents.
"""
tool = make_retriever_langchain_tool(
data_dir, persist_dir, return_strings_only=return_strings_only
data_dir, persist_dir, return_strings_only=return_strings_only, embeddings=embeddings
)
super().__init__(
tool=tool, return_direct=return_direct, exceptions_to_reflect=exceptions_to_reflect
Expand All @@ -49,9 +51,15 @@ class RetrieverToolInput(BaseModel, arbitrary_types_allowed=True):
)


def make_retriever_langchain_tool(data_dir, persist_dir, return_strings_only: bool = False):
text_embedding_model = "text-embedding-ada-002"
embeddings = OpenAIEmbedding(model=text_embedding_model)
def make_retriever_langchain_tool(
data_dir,
persist_dir,
return_strings_only: bool = False,
embeddings: Optional[BaseEmbedding] = None,
):
if embeddings is None:
text_embedding_model = "text-embedding-ada-002"
embeddings = OpenAIEmbedding(model=text_embedding_model)

if not os.path.exists(persist_dir):
# load the documents and create the index
Expand Down

0 comments on commit 260ff53

Please sign in to comment.