Skip to content

Commit 32ef554

Browse files
committed
fix: generate answer node timeout
1 parent 86bf4f2 commit 32ef554

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

scrapegraphai/nodes/generate_answer_node.py

+23-28
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,6 @@ def __init__(
5252
super().__init__(node_name, "node", input, output, 2, node_config)
5353
self.llm_model = node_config["llm_model"]
5454

55-
if hasattr(self.llm_model, 'request_timeout'):
56-
self.llm_model.request_timeout = node_config.get("timeout", 30)
57-
5855
if isinstance(node_config["llm_model"], ChatOllama):
5956
self.llm_model.format = "json"
6057

@@ -63,7 +60,22 @@ def __init__(
6360
self.script_creator = node_config.get("script_creator", False)
6461
self.is_md_scraper = node_config.get("is_md_scraper", False)
6562
self.additional_info = node_config.get("additional_info")
66-
self.timeout = node_config.get("timeout", 30)
63+
self.timeout = node_config.get("timeout", 120)
64+
65+
def invoke_with_timeout(self, chain, inputs, timeout):
    """Invoke *chain* with *inputs*, enforcing a hard *timeout* in seconds.

    Args:
        chain: Any object exposing ``invoke(inputs)`` (a LangChain runnable).
        inputs: The payload passed straight through to ``chain.invoke``.
        timeout: Maximum number of seconds to wait for the response.

    Returns:
        Whatever ``chain.invoke(inputs)`` returns.

    Raises:
        Timeout: If the chain does not produce a response within *timeout*
            seconds.
        Exception: Any error raised by the chain itself is logged and
            re-raised unchanged.
    """
    # Local import keeps this helper self-contained; only stdlib is used.
    from concurrent.futures import ThreadPoolExecutor
    from concurrent.futures import TimeoutError as _FuturesTimeout

    # NOTE: checking elapsed time only after invoke() returns (the naive
    # approach) cannot interrupt a hung call — the caller would block
    # forever.  Running the call in a worker thread and bounding the wait
    # with Future.result(timeout=...) gives a real wall-clock timeout.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(chain.invoke, inputs)
        try:
            response = future.result(timeout=timeout)
        except _FuturesTimeout:
            # Same message as before so downstream handlers/logs match.
            raise Timeout(f"Response took longer than {timeout} seconds")
        return response
    except Timeout as e:
        self.logger.error(f"Timeout error: {str(e)}")
        raise
    except Exception as e:
        self.logger.error(f"Error during chain execution: {str(e)}")
        raise
    finally:
        # Do not wait for an abandoned (possibly hung) worker thread;
        # releasing immediately keeps the caller unblocked.
        executor.shutdown(wait=False)
6779

6880
def execute(self, state: dict) -> dict:
6981
"""
@@ -119,39 +131,22 @@ def execute(self, state: dict) -> dict:
119131
template_chunks_prompt = self.additional_info + template_chunks_prompt
120132
template_merge_prompt = self.additional_info + template_merge_prompt
121133

122-
def invoke_with_timeout(chain, inputs, timeout):
123-
try:
124-
with get_openai_callback() as cb:
125-
start_time = time.time()
126-
response = chain.invoke(inputs)
127-
if time.time() - start_time > timeout:
128-
raise Timeout(f"Response took longer than {timeout} seconds")
129-
return response
130-
except Timeout as e:
131-
self.logger.error(f"Timeout error: {str(e)}")
132-
raise
133-
except Exception as e:
134-
self.logger.error(f"Error during chain execution: {str(e)}")
135-
raise
136-
137134
if len(doc) == 1:
138135
prompt = PromptTemplate(
139136
template=template_no_chunks_prompt,
140137
input_variables=["question"],
141138
partial_variables={"context": doc, "format_instructions": format_instructions}
142139
)
143140
chain = prompt | self.llm_model
141+
if output_parser:
142+
chain = chain | output_parser
144143

145144
try:
146-
raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
145+
answer = self.invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
147146
except Timeout:
148147
state.update({self.output[0]: {"error": "Response timeout exceeded"}})
149148
return state
150149

151-
if output_parser:
152-
chain = chain | output_parser
153-
154-
answer = chain.invoke({"question": user_prompt})
155150
state.update({self.output[0]: answer})
156151
return state
157152

@@ -171,9 +166,9 @@ def invoke_with_timeout(chain, inputs, timeout):
171166

172167
async_runner = RunnableParallel(**chains_dict)
173168
try:
174-
batch_results = invoke_with_timeout(
175-
async_runner,
176-
{"question": user_prompt},
169+
batch_results = self.invoke_with_timeout(
170+
async_runner,
171+
{"question": user_prompt},
177172
self.timeout
178173
)
179174
except Timeout:
@@ -190,7 +185,7 @@ def invoke_with_timeout(chain, inputs, timeout):
190185
if output_parser:
191186
merge_chain = merge_chain | output_parser
192187
try:
193-
answer = invoke_with_timeout(
188+
answer = self.invoke_with_timeout(
194189
merge_chain,
195190
{"context": batch_results, "question": user_prompt},
196191
self.timeout

0 commit comments

Comments
 (0)