-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from ShoggothAI/research-agent
Research agent
- Loading branch information
Showing
22 changed files
with
1,647 additions
and
278 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from typing import Union, Sequence, List | ||
|
||
from dotenv import load_dotenv | ||
|
||
from llama_index.graph_stores.kuzu import KuzuGraphStore | ||
from langchain.schema import AIMessage, HumanMessage, SystemMessage, BaseMessage | ||
from langchain_core.prompts.chat import ChatPromptTemplate | ||
from motleycrew.agent.langchain.react import ReactMotleyAgent | ||
|
||
from motleycrew.tool.llm_tool import LLMTool | ||
from motleycrew import MotleyCrew, Task | ||
|
||
from .blog_post_input import text | ||
|
||
load_dotenv() | ||
|
||
# TODO: switch example to using URL instead of fixed text? | ||
# from langchain.document_loaders import UnstructuredURLLoader | ||
# from langchain.text_splitter import TokenTextSplitter | ||
# def urls_to_messages(urls: Union[str, Sequence[str]]) -> List[HumanMessage]: | ||
# if isinstance(urls, str): | ||
# urls = [urls] | ||
# loader = UnstructuredURLLoader(urls=urls) | ||
# data = loader.load() | ||
# text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=0) | ||
# texts = text_splitter.split_documents(data) | ||
# return [HumanMessage(content=d.page_content) for d in texts] | ||
|
||
|
||
max_words = 500 | ||
min_words = 450 | ||
|
||
editor_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
SystemMessage( | ||
content="You are an experienced online blog post editor with 10 years of experience." | ||
), | ||
HumanMessage( | ||
content="""Review the blog post draft below (delimited by triple backticks) | ||
and provide a critique and use specific examples from the text on what | ||
should be done to improve the draft, with data professionals as the intended audience. | ||
Also, suggest a catchy title for the story. | ||
```{input}``` | ||
""" | ||
), | ||
] | ||
) | ||
|
||
illustrator_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
SystemMessage( | ||
content="You are a professional illustrator with 10 years of experience." | ||
), | ||
HumanMessage( | ||
content="You are given the following draft story, delimited by triple back quotes: ```{second_draft}```" | ||
), | ||
HumanMessage( | ||
content="""Your task is to specify the illustrations that would fit this story. | ||
Make sure the illustrations are varied in style, eye-catching, and some of them humorous. | ||
Describe each illustration in a way suitable for entering in a Midjourney prompt. | ||
Each description should be detailed and verbose. Don't explain the purpose of the illustrations, | ||
just describe in great | ||
detail what each illustration should show, in a way suitable for a generative image prompt. | ||
There should be at most 5 and at least 3 illustrations. | ||
Return the illustration descriptions as a list in the format | ||
["...", "...", ..., "..."] | ||
""" | ||
), | ||
] | ||
) | ||
|
||
seo_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
SystemMessage( | ||
content="""Act as an SEO expert with 10 years of experience but ensure to | ||
explain any SEO jargon for clarity when using it.""" | ||
), | ||
HumanMessage( | ||
content="""Review the blog post below (delimited by triple back quotes) and provide specific | ||
examples from the text where to optimize its SEO content. | ||
Recommend SEO-friendly titles and subtitles that could be used. | ||
```{second_draft}``` | ||
""" | ||
), | ||
] | ||
) | ||
|
||
editor = LLMTool( | ||
name="editor", | ||
description="An editor providing constructive suggestions to improve the blog post submitted to it", | ||
prompt=editor_prompt, | ||
) | ||
|
||
# TODO: Turn it into an agent that calls the DALL-E tool | ||
# and returns a dict {image_description: image_url} | ||
illustrator = LLMTool( | ||
name="illustrator", | ||
description="An illustrator providing detailed descriptions of illustrations for a story", | ||
prompt=illustrator_prompt, | ||
) | ||
|
||
seo_expert = LLMTool( | ||
name="seo_expert", | ||
description="An SEO expert providing SEO optimization suggestions", | ||
prompt=seo_prompt, | ||
) | ||
|
||
|
||
writer = ReactMotleyAgent( | ||
prompt="You are a professional freelance copywriter with 10 years of experience.", | ||
tools=[editor, illustrator, seo_expert], | ||
) | ||
|
||
# Create tasks for your agents | ||
crew = MotleyCrew() | ||
task1 = Task( | ||
crew=crew, | ||
name="Write a blog post from the provided information", | ||
description=f"""Write a blog post of at most {max_words} words and at least {min_words} | ||
words based on the information provided. Keep the tone suitable for an audience of | ||
data professionals, avoid superlatives and an overly excitable tone. | ||
Don't discuss installation or testing. | ||
The summary will be provided in one or multiple chunks, followed by <END>. | ||
Proceed as follows: first, write a draft blog post as described above. | ||
Then, submit it in turn to the editor, illustrator, and SEO expert for feedback. | ||
In the case of the illustrator, insert the illustration descriptions it provides in | ||
square brackets into the appropriate places in the draft. | ||
In each case, revise the draft as per the response of the expert and submit it to the next expert. | ||
After you have implemented each expert's recommendations, return the final draft in markdown format. | ||
Return the blog post in markdown format. | ||
Information begins: {text} <END>""", | ||
agent=writer, | ||
) | ||
|
||
crew.run(verbose=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
text = """ | ||
Wise Pizza: A library for automated figuring out most unusual segments | ||
WisePizza is a library to find and visualise the most interesting slices in multidimensional data based on Lasso and LP solvers, which provides different functions to find segments whose average is most different from the global one or find segments most useful in explaining the difference between two datasets. | ||
The approach | ||
WisePizza assumes you have a dataset with a number of discrete dimensions (could be currency, region, etc). For each combination of dimensions, the dataset must have a total value (total of the metric over that segment, for example the total volume in that region and currency), and an optional size value (set to 1 if not specified), this could for example be the total number of customers for that region and currency. The average value of the outcome for the segment is defined as total divided by size, in this example it would be the average volume per customer. | ||
explain_levels takes such a dataset and looks for a small number of 'simple' segments (each only constraining a small number of dimensions) that between them explain most of the variation in the averages; you could also think of them as the segments whose size-weighted deviation from the overall dataset average is the largest. This trades off unusual averages (which will naturally occur more for smaller segments) against segment size. | ||
Yet another way of looking at it is that we look for segments which, if their average was reset to the overall dataset average, would move overall total the most. | ||
explain_changes_in_totals and explain_changes_in_average take two datasets of the kind described above, with the same column names, and apply the same kind of logic to find the segments that contribute most to the difference (in total or average, respectively) between the two datasets, optionally splitting that into contributions from changes in segment size and changes in segment total. | ||
Sometimes, rather than explaining the change in totals from one period to the next, one wishes to explain a change in averages. The analytics of this are a little different - for example, while (as long as all weights and totals are positive) increasing a segment size (other things remaining equal) always increases the overall total, it can increase or decrease the pverall average, depending on whether the average value of that segment is below or above the overall average. | ||
Table of Contents | ||
What can this do for you? | ||
Find interesting slices | ||
Comparison between two datasets | ||
Installation | ||
Quick Start | ||
For Developers | ||
Tests | ||
What can this do for you? | ||
The automated search for interesting segments can give you the following: | ||
1. Better information about segments and subsegments in your data | ||
By using WisePizza and defining initial segments, you can find a segment which maximizes a specific outcome, such as adoption rates. | ||
2. Understanding differences in two time periods or two dataframes | ||
If you have two time periods or two datasets, you can find segments that experience the largest change in the totals from previous period/dataset. | ||
Installation | ||
You can always get the newest wise_pizza release using pip: https://pypi.org/project/wise-pizza/ | ||
pip install wise-pizza | ||
From the command line (another way): | ||
pip install git+https://github.com/transferwise/wise-pizza.git | ||
From Jupyter notebook (another way): | ||
!pip install git+https://github.com/transferwise/wise-pizza.git | ||
Or you can clone and run from source, in which case you should pip -r requirements.txt before running. | ||
Quick Start | ||
The wisepizza package can be used for finding segments with unusual average: | ||
sf = explain_levels( | ||
df=data, | ||
dims=dims, | ||
total_name=totals, | ||
size_name=size, | ||
max_depth=2, | ||
min_segments=20, | ||
solver="lasso" | ||
) | ||
plot | ||
Or for finding changes between two datasets in totals: | ||
sf1 = explain_changes_in_totals( | ||
df1=pre_data, | ||
df2=data, | ||
dims=dims, | ||
total_name=totals, | ||
size_name=size, | ||
max_depth=2, | ||
min_segments=20, | ||
how="totals", | ||
solver="lasso" | ||
) | ||
plot | ||
Or for finding changes between two datasets in average: | ||
sf1 = explain_changes_in_average( | ||
df1=pre_data, | ||
df2=data, | ||
dims=dims, | ||
total_name=totals, | ||
size_name=size, | ||
max_depth=2, | ||
min_segments=20, | ||
how="totals", | ||
solver="lasso" | ||
) | ||
plot | ||
And then you can visualize differences: | ||
sf.plot() | ||
And check segments: | ||
sf.segments | ||
Please see the full example here | ||
For Developers | ||
Testing | ||
We use PyTest for testing. If you want to contribute code, make sure that the tests in tests/ run without errors. | ||
Wise-pizza is open sourced and maintained by Wise Plc. Copyright 2023 Wise Plc. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import logging | ||
import kuzu | ||
|
||
from motleycrew.storage import MotleyGraphStore | ||
|
||
from question_struct import Question | ||
from question_answerer import AnswerSubQuestionTool | ||
|
||
|
||
class AnswerOrchestrator: | ||
def __init__(self, storage: MotleyGraphStore, answer_length: int): | ||
self.storage = storage | ||
self.question_answering_tool = AnswerSubQuestionTool(graph=self.storage, answer_length=answer_length) | ||
|
||
def get_unanswered_available_questions(self) -> list[Question]: | ||
query = ( | ||
"MATCH (n1:{}) " | ||
"WHERE n1.answer IS NULL AND n1.context IS NOT NULL " | ||
"AND NOT EXISTS {{MATCH (n1)-[]->(n2:{}) " | ||
"WHERE n2.answer IS NULL AND n2.context IS NOT NULL}} " | ||
"RETURN n1" | ||
).format(self.storage.node_table_name, self.storage.node_table_name) | ||
|
||
query_result = self.storage.run_cypher_query(query) | ||
return [Question.deserialize(row[0]) for row in query_result] | ||
|
||
def __call__(self) -> Question | None: | ||
last_question = None | ||
|
||
while True: | ||
questions = self.get_unanswered_available_questions() | ||
logging.info("Available questions: %s", questions) | ||
|
||
if not len(questions): | ||
logging.info("All questions answered!") | ||
break | ||
else: | ||
last_question = questions[0] | ||
logging.info("Running answerer for question %s", last_question) | ||
self.question_answering_tool.invoke({"question": last_question}) | ||
|
||
if not last_question: | ||
logging.warning("Nothing to answer!") | ||
return | ||
|
||
return Question.deserialize(self.storage.get_entity(last_question.id)) | ||
|
||
|
||
if __name__ == "__main__": | ||
from pathlib import Path | ||
from dotenv import load_dotenv | ||
from motleycrew.storage import MotleyKuzuGraphStore | ||
from motleycrew.common.utils import configure_logging | ||
|
||
load_dotenv() | ||
configure_logging(verbose=True) | ||
|
||
here = Path(__file__).parent | ||
db_path = here / "research_db" | ||
|
||
db = kuzu.Database(db_path) | ||
storage = MotleyKuzuGraphStore( | ||
db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"} | ||
) | ||
|
||
orchestrator = AnswerOrchestrator(storage=storage, answer_length=30) | ||
result = orchestrator() | ||
print(result) |
Oops, something went wrong.