Skip to content

Commit 1f465e6

Browse files
committed
feat: refactoring of generate answer node
1 parent 3e8c043 commit 1f465e6

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

scrapegraphai/nodes/generate_answer_node.py

+43-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
1717
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
1818
)
19+
from langchain.callbacks.manager import CallbackManager
20+
from langchain.callbacks import get_openai_callback
21+
from requests.exceptions import Timeout
22+
import time
1923

2024
class GenerateAnswerNode(BaseNode):
2125
"""
@@ -56,6 +60,7 @@ def __init__(
5660
self.script_creator = node_config.get("script_creator", False)
5761
self.is_md_scraper = node_config.get("is_md_scraper", False)
5862
self.additional_info = node_config.get("additional_info")
63+
self.timeout = node_config.get("timeout", 30)
5964

6065
def execute(self, state: dict) -> dict:
6166
"""
@@ -114,14 +119,33 @@ def execute(self, state: dict) -> dict:
114119
template_chunks_prompt = self.additional_info + template_chunks_prompt
115120
template_merge_prompt = self.additional_info + template_merge_prompt
116121

122+
def invoke_with_timeout(chain, inputs, timeout):
    """Invoke *chain* with *inputs*, enforcing a wall-clock *timeout* in seconds.

    The original implementation only measured elapsed time AFTER
    ``chain.invoke`` returned, so a hung LLM call was never actually
    interrupted — the Timeout was raised retroactively, after waiting
    the full duration anyway. Here the call runs in a worker thread and
    ``Future.result(timeout=...)`` enforces the deadline while the chain
    is still running.

    Args:
        chain: A LangChain runnable exposing ``invoke``.
        inputs: The input mapping passed through to ``chain.invoke``.
        timeout: Maximum seconds to wait for the response.

    Returns:
        Whatever ``chain.invoke(inputs)`` returns.

    Raises:
        requests.exceptions.Timeout: If the call exceeds *timeout* seconds
            (the exception type the callers in ``execute`` already catch).
        Exception: Any error raised by the chain itself is logged and re-raised.
    """
    # Local import keeps this self-contained helper's dependency visible here.
    from concurrent.futures import ThreadPoolExecutor
    from concurrent.futures import TimeoutError as FutureTimeoutError

    executor = ThreadPoolExecutor(max_workers=1)
    try:
        # get_openai_callback keeps token/cost tracking active for the call,
        # matching the original behavior (its handle was unused there too).
        with get_openai_callback():
            future = executor.submit(chain.invoke, inputs)
            try:
                return future.result(timeout=timeout)
            except FutureTimeoutError as exc:
                # Preserve the exception type callers catch.
                raise Timeout(
                    f"Response took longer than {timeout} seconds"
                ) from exc
    except Timeout as e:
        self.logger.error(f"Timeout error: {str(e)}")
        raise
    except Exception as e:
        self.logger.error(f"Error during chain execution: {str(e)}")
        raise
    finally:
        # NOTE(review): the worker thread cannot be killed; wait=False at
        # least lets us return immediately and raise the Timeout promptly.
        executor.shutdown(wait=False)
117137
if len(doc) == 1:
118138
prompt = PromptTemplate(
119139
template=template_no_chunks_prompt,
120140
input_variables=["question"],
121141
partial_variables={"context": doc, "format_instructions": format_instructions}
122142
)
123143
chain = prompt | self.llm_model
124-
raw_response = chain.invoke({"question": user_prompt})
144+
try:
145+
raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
146+
except Timeout:
147+
state.update({self.output[0]: {"error": "Response timeout exceeded"}})
148+
return state
125149

126150
if output_parser:
127151
try:
@@ -155,7 +179,15 @@ def execute(self, state: dict) -> dict:
155179
chains_dict[chain_name] = chains_dict[chain_name] | output_parser
156180

157181
async_runner = RunnableParallel(**chains_dict)
158-
batch_results = async_runner.invoke({"question": user_prompt})
182+
try:
183+
batch_results = invoke_with_timeout(
184+
async_runner,
185+
{"question": user_prompt},
186+
self.timeout
187+
)
188+
except Timeout:
189+
state.update({self.output[0]: {"error": "Response timeout exceeded during chunk processing"}})
190+
return state
159191

160192
merge_prompt = PromptTemplate(
161193
template=template_merge_prompt,
@@ -166,7 +198,15 @@ def execute(self, state: dict) -> dict:
166198
merge_chain = merge_prompt | self.llm_model
167199
if output_parser:
168200
merge_chain = merge_chain | output_parser
169-
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
201+
try:
202+
answer = invoke_with_timeout(
203+
merge_chain,
204+
{"context": batch_results, "question": user_prompt},
205+
self.timeout
206+
)
207+
except Timeout:
208+
state.update({self.output[0]: {"error": "Response timeout exceeded during merge"}})
209+
return state
170210

171211
state.update({self.output[0]: answer})
172212
return state

0 commit comments

Comments
 (0)