Skip to content

Commit

Permalink
fix: generate answer node timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
VinciGit00 committed Nov 20, 2024
1 parent 86bf4f2 commit 32ef554
Showing 1 changed file with 23 additions and 28 deletions.
51 changes: 23 additions & 28 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def __init__(
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]

if hasattr(self.llm_model, 'request_timeout'):
self.llm_model.request_timeout = node_config.get("timeout", 30)

if isinstance(node_config["llm_model"], ChatOllama):
self.llm_model.format = "json"

Expand All @@ -63,7 +60,22 @@ def __init__(
self.script_creator = node_config.get("script_creator", False)
self.is_md_scraper = node_config.get("is_md_scraper", False)
self.additional_info = node_config.get("additional_info")
self.timeout = node_config.get("timeout", 30)
self.timeout = node_config.get("timeout", 120)

def invoke_with_timeout(self, chain, inputs, timeout):
    """Invoke ``chain`` with ``inputs``, enforcing a hard wall-clock timeout.

    The previous implementation measured elapsed time only *after*
    ``chain.invoke`` returned, so a hung LLM call could block forever and
    the ``Timeout`` was raised post-hoc. Here the call runs in a worker
    thread and is abandoned as soon as ``timeout`` seconds elapse.

    Args:
        chain: A runnable exposing ``.invoke(inputs)`` (e.g. a LangChain chain).
        inputs: The payload forwarded unchanged to ``chain.invoke``.
        timeout: Maximum number of seconds to wait for a response.

    Returns:
        Whatever ``chain.invoke(inputs)`` returns.

    Raises:
        Timeout: If no response arrives within ``timeout`` seconds.
        Exception: Any error raised by the chain itself (logged and re-raised).
    """
    # Local import keeps the file-level import block untouched.
    from concurrent.futures import ThreadPoolExecutor
    from concurrent.futures import TimeoutError as _FuturesTimeout

    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(chain.invoke, inputs)
        try:
            return future.result(timeout=timeout)
        except _FuturesTimeout:
            # Translate to the exception type callers already catch;
            # the outer handler below logs it before re-raising.
            raise Timeout(f"Response took longer than {timeout} seconds")
    except Timeout as e:
        self.logger.error(f"Timeout error: {str(e)}")
        raise
    except Exception as e:
        self.logger.error(f"Error during chain execution: {str(e)}")
        raise
    finally:
        # Detach without waiting for a possibly-hung worker thread.
        executor.shutdown(wait=False)

def execute(self, state: dict) -> dict:
"""
Expand Down Expand Up @@ -119,39 +131,22 @@ def execute(self, state: dict) -> dict:
template_chunks_prompt = self.additional_info + template_chunks_prompt
template_merge_prompt = self.additional_info + template_merge_prompt

def invoke_with_timeout(chain, inputs, timeout):
try:
with get_openai_callback() as cb:
start_time = time.time()
response = chain.invoke(inputs)
if time.time() - start_time > timeout:
raise Timeout(f"Response took longer than {timeout} seconds")
return response
except Timeout as e:
self.logger.error(f"Timeout error: {str(e)}")
raise
except Exception as e:
self.logger.error(f"Error during chain execution: {str(e)}")
raise

if len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks_prompt,
input_variables=["question"],
partial_variables={"context": doc, "format_instructions": format_instructions}
)
chain = prompt | self.llm_model
if output_parser:
chain = chain | output_parser

try:
raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
answer = self.invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
except Timeout:
state.update({self.output[0]: {"error": "Response timeout exceeded"}})
return state

if output_parser:
chain = chain | output_parser

answer = chain.invoke({"question": user_prompt})
state.update({self.output[0]: answer})
return state

Expand All @@ -171,9 +166,9 @@ def invoke_with_timeout(chain, inputs, timeout):

async_runner = RunnableParallel(**chains_dict)
try:
batch_results = invoke_with_timeout(
async_runner,
{"question": user_prompt},
batch_results = self.invoke_with_timeout(
async_runner,
{"question": user_prompt},
self.timeout
)
except Timeout:
Expand All @@ -190,7 +185,7 @@ def invoke_with_timeout(chain, inputs, timeout):
if output_parser:
merge_chain = merge_chain | output_parser
try:
answer = invoke_with_timeout(
answer = self.invoke_with_timeout(
merge_chain,
{"context": batch_results, "question": user_prompt},
self.timeout
Expand Down

0 comments on commit 32ef554

Please sign in to comment.