     TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
     TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
 )
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks import get_openai_callback
+from requests.exceptions import Timeout
+import time

 class GenerateAnswerNode(BaseNode):
     """
@@ -56,6 +60,7 @@ def __init__(
         self.script_creator = node_config.get("script_creator", False)
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
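+        # Per-call LLM timeout read from node_config; the Timeout message in
+        # execute() implies the unit is seconds, with a default of 30.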
+        self.timeout = node_config.get("timeout", 30)

     def execute(self, state: dict) -> dict:
         """
@@ -114,14 +119,33 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt

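+        # NOTE: the helper below checks elapsed wall-clock time only after
+        # chain.invoke() has returned, so a slow call is reported as a
+        # Timeout after the fact rather than being interrupted mid-flight.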
+        def invoke_with_timeout(chain, inputs, timeout):
+            try:
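+                # get_openai_callback tracks OpenAI token usage for the call;
+                # the collected stats (cb) are not read anywhere in this diff.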
+                with get_openai_callback() as cb:
+                    start_time = time.time()
+                    response = chain.invoke(inputs)
+                    if time.time() - start_time > timeout:
+                        raise Timeout(f"Response took longer than {timeout} seconds")
+                    return response
+            except Timeout as e:
+                self.logger.error(f"Timeout error: {str(e)}")
+                raise
+            except Exception as e:
+                self.logger.error(f"Error during chain execution: {str(e)}")
+                raise
+
         if len(doc) == 1:
             prompt = PromptTemplate(
                 template=template_no_chunks_prompt,
                 input_variables=["question"],
                 partial_variables={"context": doc, "format_instructions": format_instructions}
             )
             chain = prompt | self.llm_model
-            raw_response = chain.invoke({"question": user_prompt})
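+            # On timeout, store an error payload under the node's output key
+            # and return early instead of propagating the exception; the same
+            # pattern recurs for the chunk and merge invocations below.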
+            try:
+                raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
+            except Timeout:
+                state.update({self.output[0]: {"error": "Response timeout exceeded"}})
+                return state

             if output_parser:
                 try:
@@ -155,7 +179,15 @@ def execute(self, state: dict) -> dict:
             chains_dict[chain_name] = chains_dict[chain_name] | output_parser

         async_runner = RunnableParallel(**chains_dict)
-        batch_results = async_runner.invoke({"question": user_prompt})
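+        # Same timeout guard around the parallel per-chunk chains.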
+        try:
+            batch_results = invoke_with_timeout(
+                async_runner,
+                {"question": user_prompt},
+                self.timeout
+            )
+        except Timeout:
+            state.update({self.output[0]: {"error": "Response timeout exceeded during chunk processing"}})
+            return state

         merge_prompt = PromptTemplate(
             template=template_merge_prompt,
@@ -166,7 +198,15 @@ def execute(self, state: dict) -> dict:
         merge_chain = merge_prompt | self.llm_model
         if output_parser:
             merge_chain = merge_chain | output_parser
-        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
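+        # Guard the merge step as well; batch_results becomes the context
+        # for the merge prompt.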
+        try:
+            answer = invoke_with_timeout(
+                merge_chain,
+                {"context": batch_results, "question": user_prompt},
+                self.timeout
+            )
+        except Timeout:
+            state.update({self.output[0]: {"error": "Response timeout exceeded during merge"}})
+            return state

         state.update({self.output[0]: answer})
         return state