Skip to content

Commit

Permalink
Merge pull request #562 from ScrapeGraphAI/support_structured_output_…
Browse files Browse the repository at this point in the history
…shema_openai

Support structured output schema openai
  • Loading branch information
VinciGit00 authored Aug 19, 2024
2 parents 6a08cc8 + 7d2fc67 commit d1f6b9f
Show file tree
Hide file tree
Showing 13 changed files with 44 additions and 29 deletions.
7 changes: 3 additions & 4 deletions examples/anthropic/search_graph_schema_haiku.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
"""

import os
from typing import List
from dotenv import load_dotenv
load_dotenv()

from pydantic import BaseModel, Field
from scrapegraphai.graphs import SearchGraph

from pydantic import BaseModel, Field
from typing import List
load_dotenv()

# ************************************************
# Define the output schema for the graph
Expand Down
3 changes: 2 additions & 1 deletion examples/azure/smart_scraper_schema_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Basic example of scraping pipeline using SmartScraper with schema
"""

import os, json
import os
import json
from typing import List
from pydantic import BaseModel, Field
from dotenv import load_dotenv
Expand Down
2 changes: 1 addition & 1 deletion examples/local_models/smart_scraper_schema_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Projects(BaseModel):

graph_config = {
"llm": {
"model": "ollama/llama3",
"model": "ollama/llama3.1",
"temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
Expand Down
3 changes: 2 additions & 1 deletion examples/openai/smart_scraper_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Basic example of scraping pipeline using SmartScraper
"""

import os, json
import os
import json
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from dotenv import load_dotenv
Expand Down
2 changes: 1 addition & 1 deletion examples/openai/smart_scraper_schema_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Projects(BaseModel):
graph_config = {
"llm": {
"api_key":openai_key,
"model": "gpt-4o",
"model": "gpt-4o-mini",
},
"verbose": True,
"headless": False,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ authors = [
]

dependencies = [
"langchain>=0.2.10",
"langchain>=0.2.14",
"langchain-fireworks>=0.1.3",
"langchain_community>=0.2.9",
"langchain-google-genai>=1.0.7",
"langchain-google-vertexai>=1.0.7",
"langchain-openai>=0.1.17",
"langchain-openai>=0.1.22",
"langchain-groq>=0.1.3",
"langchain-aws>=0.1.3",
"langchain-anthropic>=0.1.11",
Expand Down
6 changes: 3 additions & 3 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ jsonschema-specifications==2023.12.1
# via jsonschema
kiwisolver==1.4.5
# via matplotlib
langchain==0.2.12
langchain==0.2.14
# via langchain-community
# via scrapegraphai
langchain-anthropic==0.1.22
Expand All @@ -264,7 +264,7 @@ langchain-aws==0.1.16
# via scrapegraphai
langchain-community==0.2.11
# via scrapegraphai
langchain-core==0.2.29
langchain-core==0.2.33
# via langchain
# via langchain-anthropic
# via langchain-aws
Expand Down Expand Up @@ -292,7 +292,7 @@ langchain-mistralai==0.1.12
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.2.1
# via scrapegraphai
langchain-openai==0.1.21
langchain-openai==0.1.22
# via scrapegraphai
langchain-text-splitters==0.2.2
# via langchain
Expand Down
9 changes: 5 additions & 4 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ jinja2==3.1.4
# via torch
jiter==0.5.0
# via anthropic
# via openai
jmespath==1.0.1
# via boto3
# via botocore
Expand All @@ -187,7 +188,7 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.11
langchain==0.2.14
# via langchain-community
# via scrapegraphai
langchain-anthropic==0.1.20
Expand All @@ -196,7 +197,7 @@ langchain-aws==0.1.12
# via scrapegraphai
langchain-community==0.2.10
# via scrapegraphai
langchain-core==0.2.28
langchain-core==0.2.33
# via langchain
# via langchain-anthropic
# via langchain-aws
Expand Down Expand Up @@ -224,7 +225,7 @@ langchain-mistralai==0.1.12
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.7
# via scrapegraphai
langchain-openai==0.1.17
langchain-openai==0.1.22
# via scrapegraphai
langchain-text-splitters==0.2.2
# via langchain
Expand Down Expand Up @@ -264,7 +265,7 @@ numpy==1.26.4
# via sentence-transformers
# via shapely
# via transformers
openai==1.37.0
openai==1.41.0
# via langchain-fireworks
# via langchain-openai
orjson==3.10.6
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
langchain>=0.2.10
langchain>=0.2.14
langchain-fireworks>=0.1.3
langchain_community>=0.2.9
langchain-google-genai>=1.0.7
langchain-google-vertexai>=1.0.7
langchain-openai>=0.1.17
langchain-openai>=0.1.22
langchain-groq>=0.1.3
langchain-aws>=0.1.3
langchain-anthropic>=0.1.11
Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/helpers/models_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"gpt-4-32k-0613": 32768,
"gpt-4o": 128000,
"gpt-4o-mini":128000,
"chatgpt-4o-latest":128000
"chatgpt-4o-latest": 128000
},
"google_genai": {
"gemini-pro": 128000,
Expand Down
16 changes: 14 additions & 2 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_anthropic import ChatAnthropic
from langchain_groq import ChatGroq
from langchain_fireworks import ChatFireworks
from langchain_google_vertexai import ChatVertexAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from ..utils.logging import get_logger
Expand Down Expand Up @@ -88,12 +93,19 @@ def execute(self, state: dict) -> dict:
# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

# Use built-in structured output for providers that allow it
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")

else:
output_parser = JsonOutputParser()

format_instructions = output_parser.get_format_instructions()

if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
template_chunks_prompt = TEMPLATE_CHUNKS_MD
template_merge_prompt = TEMPLATE_MERGE_MD
Expand Down
11 changes: 6 additions & 5 deletions scrapegraphai/nodes/parse_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from semchunk import chunk
from langchain_community.document_transformers import Html2TextTransformer
from langchain_core.documents import Document
from ..utils.logging import get_logger
from .base_node import BaseNode

class ParseNode(BaseNode):
Expand Down Expand Up @@ -78,16 +77,18 @@ def execute(self, state: dict) -> dict:
else:
docs_transformed = docs_transformed[0]

# Adapt the chunk size, leaving room for the reply, the prompt and the schema
chunk_size = self.node_config.get("chunk_size", 4096)
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))

if isinstance(docs_transformed, Document):

chunks = chunk(text=docs_transformed.page_content,
chunk_size=self.node_config.get("chunk_size", 4096)-250,
chunk_size=chunk_size,
token_counter=lambda text: len(text.split()),
memoize=False)
else:

chunks = chunk(text=docs_transformed,
chunk_size=self.node_config.get("chunk_size", 4096)-250,
chunk_size=chunk_size,
token_counter=lambda text: len(text.split()),
memoize=False)

Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/utils/token_calculator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Module for truncatinh in chunks the messages
Module for truncating in chunks the messages
"""
from typing import List
import tiktoken
Expand Down Expand Up @@ -27,7 +27,7 @@ def truncate_text_tokens(text: str, model: str, encoding_name: str) -> List[str]
"""

encoding = tiktoken.get_encoding(encoding_name)
max_tokens = models_tokens[model] - 500
max_tokens = min(models_tokens[model] - 500, int(models_tokens[model] * 0.9))
encoded_text = encoding.encode(text)

chunks = [encoded_text[i:i + max_tokens]
Expand Down

0 comments on commit d1f6b9f

Please sign in to comment.