Skip to content

Commit

Permalink
Merge pull request #626 from LorenzoPaleari/598-fix-pydantic-errors
Browse files Browse the repository at this point in the history
Fixed pydantic errors when using `with_structured_output`
  • Loading branch information
VinciGit00 authored Sep 2, 2024
2 parents 5e99071 + 8e74ac5 commit 8442700
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 18 deletions.
25 changes: 20 additions & 5 deletions scrapegraphai/nodes/generate_answer_csv_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV

class GenerateAnswerCSVNode(BaseNode):
"""
Expand Down Expand Up @@ -92,9 +94,24 @@ def execute(self, state):

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()

TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
Expand All @@ -105,8 +122,6 @@ def execute(self, state):
TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV

format_instructions = output_parser.get_format_instructions()

chains_dict = {}

if len(doc) == 1:
Expand Down
15 changes: 10 additions & 5 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""
GenerateAnswerNode Module
"""
from sys import modules
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD

Expand Down Expand Up @@ -91,14 +90,20 @@ def execute(self, state: dict) -> dict:
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()

format_instructions = output_parser.get_format_instructions()
format_instructions = output_parser.get_format_instructions()

if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
Expand Down
22 changes: 20 additions & 2 deletions scrapegraphai/nodes/generate_answer_omni_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from .base_node import BaseNode
Expand Down Expand Up @@ -78,9 +81,25 @@ def execute(self, state: dict) -> dict:

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()

TEMPLATE_NO_CHUNKS_OMNI_prompt = TEMPLATE_NO_CHUNKS_OMNI
TEMPLATE_CHUNKS_OMNI_prompt = TEMPLATE_CHUNKS_OMNI
TEMPLATE_MERGE_OMNI_prompt= TEMPLATE_MERGE_OMNI
Expand All @@ -90,7 +109,6 @@ def execute(self, state: dict) -> dict:
TEMPLATE_CHUNKS_OMNI_prompt = self.additional_info + TEMPLATE_CHUNKS_OMNI_prompt
TEMPLATE_MERGE_OMNI_prompt = self.additional_info + TEMPLATE_MERGE_OMNI_prompt

format_instructions = output_parser.get_format_instructions()


chains_dict = {}
Expand Down
23 changes: 20 additions & 3 deletions scrapegraphai/nodes/generate_answer_pdf_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from ..utils.logging import get_logger
Expand Down Expand Up @@ -93,9 +96,25 @@ def execute(self, state):

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()

TEMPLATE_NO_CHUNKS_PDF_prompt = TEMPLATE_NO_CHUNKS_PDF
TEMPLATE_CHUNKS_PDF_prompt = TEMPLATE_CHUNKS_PDF
TEMPLATE_MERGE_PDF_prompt = TEMPLATE_MERGE_PDF
Expand All @@ -105,8 +124,6 @@ def execute(self, state):
TEMPLATE_CHUNKS_PDF_prompt = self.additional_info + TEMPLATE_CHUNKS_PDF_prompt
TEMPLATE_MERGE_PDF_prompt = self.additional_info + TEMPLATE_MERGE_PDF_prompt

format_instructions = output_parser.get_format_instructions()

if len(doc) == 1:
prompt = PromptTemplate(
template=TEMPLATE_NO_CHUNKS_PDF_prompt,
Expand Down
22 changes: 19 additions & 3 deletions scrapegraphai/nodes/merge_answers_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts import TEMPLATE_COMBINED
Expand Down Expand Up @@ -68,11 +71,24 @@ def execute(self, state: dict) -> dict:
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"

if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()

format_instructions = output_parser.get_format_instructions()
format_instructions = output_parser.get_format_instructions()

prompt_template = PromptTemplate(
template=TEMPLATE_COMBINED,
Expand Down

0 comments on commit 8442700

Please sign in to comment.