@@ -52,9 +52,6 @@ def __init__(
         super().__init__(node_name, "node", input, output, 2, node_config)
         self.llm_model = node_config["llm_model"]
 
-        if hasattr(self.llm_model, 'request_timeout'):
-            self.llm_model.request_timeout = node_config.get("timeout", 30)
-
         if isinstance(node_config["llm_model"], ChatOllama):
             self.llm_model.format = "json"
@@ -63,7 +60,22 @@ def __init__(
         self.script_creator = node_config.get("script_creator", False)
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
-        self.timeout = node_config.get("timeout", 30)
+        self.timeout = node_config.get("timeout", 120)
+
+    def invoke_with_timeout(self, chain, inputs, timeout):
+        """Helper method to invoke chain with timeout"""
+        try:
+            start_time = time.time()
+            response = chain.invoke(inputs)
+            if time.time() - start_time > timeout:
+                raise Timeout(f"Response took longer than {timeout} seconds")
+            return response
+        except Timeout as e:
+            self.logger.error(f"Timeout error: {str(e)}")
+            raise
+        except Exception as e:
+            self.logger.error(f"Error during chain execution: {str(e)}")
+            raise
 
     def execute(self, state: dict) -> dict:
         """
@@ -119,39 +131,22 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt
 
-        def invoke_with_timeout(chain, inputs, timeout):
-            try:
-                with get_openai_callback() as cb:
-                    start_time = time.time()
-                    response = chain.invoke(inputs)
-                    if time.time() - start_time > timeout:
-                        raise Timeout(f"Response took longer than {timeout} seconds")
-                    return response
-            except Timeout as e:
-                self.logger.error(f"Timeout error: {str(e)}")
-                raise
-            except Exception as e:
-                self.logger.error(f"Error during chain execution: {str(e)}")
-                raise
-
         if len(doc) == 1:
             prompt = PromptTemplate(
                 template=template_no_chunks_prompt,
                 input_variables=["question"],
                 partial_variables={"context": doc, "format_instructions": format_instructions}
             )
             chain = prompt | self.llm_model
+            if output_parser:
+                chain = chain | output_parser
 
             try:
-                raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
+                answer = self.invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
             except Timeout:
                 state.update({self.output[0]: {"error": "Response timeout exceeded"}})
                 return state
 
-            if output_parser:
-                chain = chain | output_parser
-
-            answer = chain.invoke({"question": user_prompt})
             state.update({self.output[0]: answer})
             return state
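This hunk also fixes a subtle bug in the single-chunk branch: the old code computed `raw_response` under the timeout guard, then attached the parser and re-invoked the chain a second time, unguarded, to produce `answer` — two LLM calls, only one of them timed. The rewrite composes the parser into the chain up front and invokes exactly once. A minimal runnable sketch of the corrected composition, with a stub standing in for the real model:

```python
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

fake_llm = RunnableLambda(lambda prompt_value: "42")  # stand-in for self.llm_model

prompt = PromptTemplate.from_template("Answer: {question}")
chain = prompt | fake_llm | StrOutputParser()  # parser attached before invoking
print(chain.invoke({"question": "6 * 7?"}))   # a single invocation -> "42"
```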
@@ -171,9 +166,9 @@ def invoke_with_timeout(chain, inputs, timeout):
 
         async_runner = RunnableParallel(**chains_dict)
         try:
-            batch_results = invoke_with_timeout(
-                async_runner,
-                {"question": user_prompt},
+            batch_results = self.invoke_with_timeout(
+                async_runner,
+                {"question": user_prompt},
                 self.timeout
             )
         except Timeout:
@@ -190,7 +185,7 @@ def invoke_with_timeout(chain, inputs, timeout):
         if output_parser:
             merge_chain = merge_chain | output_parser
         try:
-            answer = invoke_with_timeout(
+            answer = self.invoke_with_timeout(
                 merge_chain,
                 {"context": batch_results, "question": user_prompt},
                 self.timeout
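To verify the shared helper's post-hoc behavior outside the node, a standalone sketch (a free-function copy of the new method; the `requests` origin of `Timeout` is again an assumption) shows that the exception fires only after the slow call completes:

```python
import time
import logging
from requests.exceptions import Timeout  # assumption, as in the sketch above

logger = logging.getLogger(__name__)

def invoke_with_timeout(chain, inputs, timeout):
    # Free-function copy of the new helper, for experimentation only.
    start_time = time.time()
    response = chain.invoke(inputs)
    if time.time() - start_time > timeout:
        raise Timeout(f"Response took longer than {timeout} seconds")
    return response

class SlowChain:
    def invoke(self, inputs):
        time.sleep(2)  # simulate a slow LLM response
        return "ok"

try:
    invoke_with_timeout(SlowChain(), {}, timeout=1)
except Timeout as e:
    logger.error(f"Timeout error: {e}")  # raised after ~2s, not at the 1s budget
```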