Skip to content

Commit

Permalink
Merge pull request #155 from VinciGit00/graphs-iterator-node
Browse files Browse the repository at this point in the history
GraphIteratorNode and MergeAnswersNode
  • Loading branch information
VinciGit00 authored May 6, 2024
2 parents 8c5397f + d9a4ab2 commit e6387d7
Show file tree
Hide file tree
Showing 23 changed files with 384 additions and 75 deletions.
31 changes: 25 additions & 6 deletions examples/openai/custom_graph_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import os
from dotenv import load_dotenv

from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode, RobotsNode
Expand All @@ -20,7 +22,7 @@
"api_key": openai_key,
"model": "gpt-3.5-turbo",
"temperature": 0,
"streaming": True
"streaming": False
},
}

Expand All @@ -29,33 +31,50 @@
# ************************************************

llm_model = OpenAI(graph_config["llm"])
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)

# define the nodes for the graph
robot_node = RobotsNode(
input="url",
output=["is_scrapable"],
node_config={"llm_model": llm_model}
node_config={
"llm_model": llm_model,
"verbose": True,
}
)

fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={"headless": True, "verbose": True}
node_config={
"verbose": True,
"headless": True,
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": 4096}
node_config={
"chunk_size": 4096,
"verbose": True,
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={"llm_model": llm_model},
node_config={
"llm_model": llm_model,
"embedder_model": embedder,
"verbose": True,
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm_model": llm_model},
node_config={
"llm_model": llm_model,
"verbose": True,
}
)

# ************************************************
Expand Down
98 changes: 98 additions & 0 deletions examples/openai/search_graph_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Example of custom graph using existing nodes
"""

import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph, SmartScraperGraph
from scrapegraphai.nodes import SearchInternetNode, GraphIteratorNode, MergeAnswersNode
load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

openai_key = os.getenv("OPENAI_APIKEY")

graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
}

# ************************************************
# Create a SmartScraperGraph instance
# ************************************************

smart_scraper_graph = SmartScraperGraph(
prompt="",
source="",
config=graph_config
)

# ************************************************
# Define the graph nodes
# ************************************************

llm_model = OpenAI(graph_config["llm"])
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)

search_internet_node = SearchInternetNode(
input="user_prompt",
output=["urls"],
node_config={
"llm_model": llm_model,
"max_results": 5, # num of search results to fetch
"verbose": True,
}
)

graph_iterator_node = GraphIteratorNode(
input="user_prompt & urls",
output=["results"],
node_config={
"graph_instance": smart_scraper_graph,
"verbose": True,
}
)

merge_answers_node = MergeAnswersNode(
input="user_prompt & results",
output=["answer"],
node_config={
"llm_model": llm_model,
"verbose": True,
}
)

# ************************************************
# Create the graph by defining the connections
# ************************************************

graph = BaseGraph(
nodes=[
search_internet_node,
graph_iterator_node,
merge_answers_node
],
edges=[
(search_internet_node, graph_iterator_node),
(graph_iterator_node, merge_answers_node)
],
entry_point=search_internet_node
)

# ************************************************
# Execute the graph
# ************************************************

result, execution_info = graph.execute({
"user_prompt": "List me all the typical Chioggia dishes."
})

# get the answer from the result
result = result.get("answer", "No answer found.")
print(result)
4 changes: 3 additions & 1 deletion examples/openai/search_graph_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
"max_results": 5,
"verbose": True,
}

# ************************************************
# Create the SearchGraph instance and run it
# ************************************************

search_graph = SearchGraph(
prompt="List me top 5 eyeliner products for a gift.",
prompt="List me the best escursions near Trento",
config=graph_config
)

Expand Down
2 changes: 1 addition & 1 deletion examples/openai/smart_scraper_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
"verbose": True,
"verbose": False,
}

# ************************************************
Expand Down
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__init__.py file for graphs folder
"""

from .abstract_graph import AbstractGraph
from .base_graph import BaseGraph
from .smart_scraper_graph import SmartScraperGraph
from .speech_graph import SpeechGraph
Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.execution_info = None

# Set common configuration parameters
self.verbose = True if config is None else config.get("verbose", False)
self.verbose = False if config is None else config.get("verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
common_params = {"headless": self.headless,
Expand Down
70 changes: 37 additions & 33 deletions scrapegraphai/graphs/search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from .base_graph import BaseGraph
from ..nodes import (
SearchInternetNode,
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode
GraphIteratorNode,
MergeAnswersNode
)
from .abstract_graph import AbstractGraph
from .smart_scraper_graph import SmartScraperGraph


class SearchGraph(AbstractGraph):
Expand Down Expand Up @@ -38,6 +37,11 @@ class SearchGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, config: dict):

self.max_results = config.get("max_results", 3)
super().__init__(prompt, config)

def _create_graph(self) -> BaseGraph:
"""
Creates the graph of nodes representing the workflow for web scraping and searching.
Expand All @@ -46,53 +50,53 @@ def _create_graph(self) -> BaseGraph:
BaseGraph: A graph instance representing the web scraping and searching workflow.
"""

# ************************************************
# Create a SmartScraperGraph instance
# ************************************************

smart_scraper_instance = SmartScraperGraph(
prompt="",
source="",
config=self.config
)

# ************************************************
# Define the graph nodes
# ************************************************

search_internet_node = SearchInternetNode(
input="user_prompt",
output=["url"],
node_config={
"llm_model": self.llm_model
}
)
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
output=["urls"],
node_config={
"chunk_size": self.model_token
"llm_model": self.llm_model,
"max_results": self.max_results
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
graph_iterator_node = GraphIteratorNode(
input="user_prompt & urls",
output=["results"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
"graph_instance": smart_scraper_instance,
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",

merge_answers_node = MergeAnswersNode(
input="user_prompt & results",
output=["answer"],
node_config={
"llm_model": self.llm_model
"llm_model": self.llm_model,
}
)

return BaseGraph(
nodes=[
search_internet_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
graph_iterator_node,
merge_answers_node
],
edges=[
(search_internet_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
(search_internet_node, graph_iterator_node),
(graph_iterator_node, merge_answers_node)
],
entry_point=search_internet_node
)
Expand Down
2 changes: 2 additions & 0 deletions scrapegraphai/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
from .robots_node import RobotsNode
from .generate_answer_csv_node import GenerateAnswerCSVNode
from .generate_answer_pdf_node import GenerateAnswerPDFNode
from .graph_iterator_node import GraphIteratorNode
from .merge_answers_node import MergeAnswersNode
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_csv_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm

# Imports from Langchain
Expand Down Expand Up @@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""

def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm

# Imports from Langchain
Expand Down Expand Up @@ -33,7 +33,7 @@ class GenerateAnswerNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""

def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
node_name: str = "GenerateAnswer"):
super().__init__(node_name, "node", input, output, 2, node_config)

Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_node_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm

# Imports from Langchain
Expand Down Expand Up @@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""

def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Expand Down
Loading

0 comments on commit e6387d7

Please sign in to comment.