From 8e74ac55a16ca012b52affbc754e4b04130e65db Mon Sep 17 00:00:00 2001 From: Lorenzo Paleari <100212108+LorenzoPaleari@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:03:14 +0200 Subject: [PATCH] fix: correctly parsing output when using structured_output --- .../nodes/generate_answer_csv_node.py | 25 +++++++++++++++---- scrapegraphai/nodes/generate_answer_node.py | 15 +++++++---- .../nodes/generate_answer_omni_node.py | 22 ++++++++++++++-- .../nodes/generate_answer_pdf_node.py | 23 ++++++++++++++--- scrapegraphai/nodes/merge_answers_node.py | 22 +++++++++++++--- 5 files changed, 89 insertions(+), 18 deletions(-) diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index 0907dfb9..de127f47 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -6,11 +6,13 @@ from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel +from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_openai import ChatOpenAI +from langchain_mistralai import ChatMistralAI from tqdm import tqdm from ..utils.logging import get_logger from .base_node import BaseNode -from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV, - TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV) +from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV class GenerateAnswerCSVNode(BaseNode): """ @@ -92,9 +94,24 @@ def execute(self, state): # Initialize the output parser if self.node_config.get("schema", None) is not None: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + + if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): + self.llm_model = self.llm_model.with_structured_output( + schema = self.node_config["schema"], + method="function_calling") # json schema works only on specific models + + # 
default parser to empty lambda function + output_parser = lambda x: x + if is_basemodel_subclass(self.node_config["schema"]): + output_parser = dict + format_instructions = "NA" + else: + output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + format_instructions = output_parser.get_format_instructions() + else: output_parser = JsonOutputParser() + format_instructions = output_parser.get_format_instructions() TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV @@ -105,8 +122,6 @@ def execute(self, state): TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV - format_instructions = output_parser.get_format_instructions() - chains_dict = {} if len(doc) == 1: diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index a1c8ff22..f5b6b5c8 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -1,16 +1,15 @@ """ GenerateAnswerNode Module """ -from sys import modules from typing import List, Optional from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel +from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI, AzureChatOpenAI from langchain_mistralai import ChatMistralAI from langchain_community.chat_models import ChatOllama from tqdm import tqdm -from ..utils.logging import get_logger from .base_node import BaseNode from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD @@ -91,14 +90,20 @@ def execute(self, state: dict) -> dict: if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"], - 
method="json_schema") + method="function_calling") # json schema works only on specific models + + # default parser to empty lambda function + output_parser = lambda x: x + if is_basemodel_subclass(self.node_config["schema"]): + output_parser = dict + format_instructions = "NA" else: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + format_instructions = output_parser.get_format_instructions() else: output_parser = JsonOutputParser() - - format_instructions = output_parser.get_format_instructions() + format_instructions = output_parser.get_format_instructions() if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper: template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py index 34ee3e87..aabebce4 100644 --- a/scrapegraphai/nodes/generate_answer_omni_node.py +++ b/scrapegraphai/nodes/generate_answer_omni_node.py @@ -5,6 +5,9 @@ from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel +from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_openai import ChatOpenAI +from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama from .base_node import BaseNode @@ -78,9 +81,25 @@ def execute(self, state: dict) -> dict: # Initialize the output parser if self.node_config.get("schema", None) is not None: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + + if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): + self.llm_model = self.llm_model.with_structured_output( + schema = self.node_config["schema"], + method="function_calling") # json schema works only on specific models + + # default parser to empty lambda function + 
output_parser = lambda x: x + if is_basemodel_subclass(self.node_config["schema"]): + output_parser = dict + format_instructions = "NA" + else: + output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + format_instructions = output_parser.get_format_instructions() + else: output_parser = JsonOutputParser() + format_instructions = output_parser.get_format_instructions() + TEMPLATE_NO_CHUNKS_OMNI_prompt = TEMPLATE_NO_CHUNKS_OMNI TEMPLATE_CHUNKS_OMNI_prompt = TEMPLATE_CHUNKS_OMNI TEMPLATE_MERGE_OMNI_prompt= TEMPLATE_MERGE_OMNI @@ -90,7 +109,6 @@ def execute(self, state: dict) -> dict: TEMPLATE_CHUNKS_OMNI_prompt = self.additional_info + TEMPLATE_CHUNKS_OMNI_prompt TEMPLATE_MERGE_OMNI_prompt = self.additional_info + TEMPLATE_MERGE_OMNI_prompt - format_instructions = output_parser.get_format_instructions() chains_dict = {} diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index f3e68eab..5e1e2687 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -5,6 +5,9 @@ from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel +from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_openai import ChatOpenAI +from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama from ..utils.logging import get_logger @@ -93,9 +96,25 @@ def execute(self, state): # Initialize the output parser if self.node_config.get("schema", None) is not None: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + + if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): + self.llm_model = self.llm_model.with_structured_output( + schema = self.node_config["schema"], + method="function_calling") # json schema works only on specific models + + # 
default parser to empty lambda function + output_parser = lambda x: x + if is_basemodel_subclass(self.node_config["schema"]): + output_parser = dict + format_instructions = "NA" + else: + output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + format_instructions = output_parser.get_format_instructions() + else: output_parser = JsonOutputParser() + format_instructions = output_parser.get_format_instructions() + TEMPLATE_NO_CHUNKS_PDF_prompt = TEMPLATE_NO_CHUNKS_PDF TEMPLATE_CHUNKS_PDF_prompt = TEMPLATE_CHUNKS_PDF TEMPLATE_MERGE_PDF_prompt = TEMPLATE_MERGE_PDF @@ -105,8 +124,6 @@ def execute(self, state): TEMPLATE_CHUNKS_PDF_prompt = self.additional_info + TEMPLATE_CHUNKS_PDF_prompt TEMPLATE_MERGE_PDF_prompt = self.additional_info + TEMPLATE_MERGE_PDF_prompt - format_instructions = output_parser.get_format_instructions() - if len(doc) == 1: prompt = PromptTemplate( template=TEMPLATE_NO_CHUNKS_PDF_prompt, diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index f2559a09..1e25cccb 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -4,6 +4,9 @@ from typing import List, Optional from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser +from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_openai import ChatOpenAI +from langchain_mistralai import ChatMistralAI from ..utils.logging import get_logger from .base_node import BaseNode from ..prompts import TEMPLATE_COMBINED @@ -68,11 +71,24 @@ def execute(self, state: dict) -> dict: answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n" if self.node_config.get("schema", None) is not None: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + + if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): + self.llm_model = self.llm_model.with_structured_output( + schema = self.node_config["schema"], + 
method="function_calling") # the "json_schema" method is supported only by specific models + + # default the parser to an identity function (returns the model output unchanged) + output_parser = lambda x: x + if is_basemodel_subclass(self.node_config["schema"]): + output_parser = dict + format_instructions = "NA" + else: + output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + format_instructions = output_parser.get_format_instructions() + else: output_parser = JsonOutputParser() - - format_instructions = output_parser.get_format_instructions() + format_instructions = output_parser.get_format_instructions() prompt_template = PromptTemplate( template=TEMPLATE_COMBINED,