23
23
import time
24
24
import json
25
25
import signal
26
- from . DAG import Task , DAG
26
+ from DAGflow import DAG
27
27
28
28
29
29
LOG = logging .getLogger (__name__ )
@@ -98,7 +98,7 @@ def ps():
98
98
return r
99
99
100
100
101
- def update_task_status (tasks ):
101
+ def update_task_status (tasks , stop_on_failure ):
102
102
"""
103
103
104
104
:param tasks:
@@ -133,12 +133,19 @@ def update_task_status(tasks):
133
133
134
134
# check recent done tasks on sge
135
135
if task .type == "sge" and task .run_id not in sge_running_task :
136
- task .check_done ()
136
+ status = task .check_done ()
137
+
138
+ if not status and stop_on_failure :
139
+ LOG .info ("Task %r failed, stop all tasks" % task .id )
140
+ del_online_tasks ()
137
141
continue
138
142
elif task .type == "local" :
139
- if task .run_id .poll ():
140
- task .check_done ()
143
+ if not task .run_id .poll ():
144
+ status = task .check_done ()
141
145
146
+ if not status and stop_on_failure :
147
+ LOG .info ("Task %r failed, stop all tasks" % task .id )
148
+ del_online_tasks ()
142
149
continue
143
150
else :
144
151
pass
@@ -192,7 +199,11 @@ def submit_tasks(tasks, concurrent_tasks):
192
199
return tasks
193
200
194
201
195
def del_task_hander(signum, frame):
    """Signal handler for SIGINT/SIGTERM: delegate to del_online_tasks().

    :param signum: signal number delivered by the OS (unused).
    :param frame: current stack frame at delivery time (unused).

    NOTE(review): "hander" looks like a typo for "handler"; renaming is
    not done here because this exact name is registered via
    signal.signal(...) in do_dag — confirm all call sites before fixing.
    """
    # All cleanup logic lives in del_online_tasks() so it can also be
    # invoked directly (e.g. on task failure), not only from a signal.
    del_online_tasks()
205
+
206
+ def del_online_tasks ():
196
207
LOG .info ("delete all running jobs, please wait" )
197
208
time .sleep (3 )
198
209
@@ -201,44 +212,37 @@ def qdel_online_tasks(signum, frame):
201
212
if task .status == "running" :
202
213
task .kill ()
203
214
204
- write_tasks (TASKS , TASK_NAME + ".json" )
215
+ write_tasks (TASKS )
205
216
206
217
sys .exit ("sorry, the program exit" )
207
218
208
219
209
def write_tasks(tasks):
    """Summarize final task states and exit with an error if any failed.

    NOTE(review): despite the name, this function no longer writes a
    JSON file (that code was removed); it only reports results.  The
    name is kept because other parts of the program still call
    ``write_tasks(TASKS)``.

    :param tasks: mapping of task id -> task object; each task is
                  expected to expose ``.id`` and ``.status`` (presumably
                  set earlier by update_task_status — confirm against
                  caller).
    :raises SystemExit: when at least one task is not in the "success"
                        state, after logging the failed task ids.
    """
    # Collect ids of every task that did not finish successfully.
    # (Comprehension over .values(): the dict key was unused, and the
    # old loop variable ``id`` shadowed the builtin.)
    failed_tasks = [task.id for task in tasks.values()
                    if task.status != "success"]

    if failed_tasks:
        LOG.info("""\
The following tasks were failed:
%s
""" % "\n".join(failed_tasks))
        sys.exit("sorry, the program exit with some jobs failed")
    else:
        LOG.info("All jobs were done!")
233
236
234
237
235
- def do_dag (dag , concurrent_tasks , refresh_time , log_name = "" ):
238
+ def do_dag (dag , concurrent_tasks = 200 , refresh_time = 60 , stop_on_failure = False ):
236
239
240
+ dag .to_json ()
237
241
start = time .time ()
238
242
239
243
logging .basicConfig (level = logging .DEBUG ,
240
244
format = "[%(levelname)s] %(asctime)s %(message)s" ,
241
- filename = log_name ,
245
+ filename = "%s.log" % dag . id ,
242
246
filemode = 'w' ,
243
247
)
244
248
@@ -253,15 +257,13 @@ def do_dag(dag, concurrent_tasks, refresh_time, log_name=""):
253
257
global TASKS
254
258
TASKS = dag .tasks
255
259
256
- signal .signal (signal .SIGINT , qdel_online_tasks )
257
- signal .signal (signal .SIGTERM , qdel_online_tasks )
260
+ signal .signal (signal .SIGINT , del_task_hander )
261
+ signal .signal (signal .SIGTERM , del_task_hander )
258
262
# signal.signal(signal.SIGKILL, qdel_online_tasks)
259
263
260
264
for id , task in TASKS .items ():
261
265
task .init ()
262
266
263
- failed_json = TASK_NAME + ".json"
264
-
265
267
loop = 0
266
268
267
269
while 1 :
@@ -292,10 +294,10 @@ def do_dag(dag, concurrent_tasks, refresh_time, log_name=""):
292
294
else :
293
295
time .sleep (refresh_time )
294
296
loop += 1
295
- update_task_status (TASKS )
297
+ update_task_status (TASKS , stop_on_failure )
296
298
297
299
# write failed
298
- write_tasks (TASKS , failed_json )
300
+ write_tasks (TASKS )
299
301
totalTime = time .time () - start
300
302
LOG .info ('Total time:' + time .strftime ("%H:%M:%S" , time .gmtime (totalTime )))
301
303
@@ -314,7 +316,8 @@ def get_args():
314
316
315
317
parser .add_argument ("json" , help = "The json file contain DAG information" )
316
318
parser .add_argument ("-m" , "--max_task" , type = int , default = 200 , help = "concurrent_tasks" )
317
- parser .add_argument ("-r" , "--refresh" , type = int , default = 30 , help = "refresh time of task status (seconds)" )
319
+ parser .add_argument ("-r" , "--refresh" , type = int , default = 60 , help = "refresh time of task status (seconds)" )
320
+ parser .add_argument ("-s" , "--stopOnFailure" , action = "store_true" , help = "stop all tasks when any task failure" )
318
321
args = parser .parse_args ()
319
322
320
323
return args
@@ -327,7 +330,7 @@ def main():
327
330
TASK_NAME = os .path .splitext (os .path .basename (args .json ))[0 ]
328
331
print (TASK_NAME )
329
332
dag = DAG .from_json (args .json )
330
- do_dag (dag , args .max_task , args .refresh , TASK_NAME + ".log" )
333
+ do_dag (dag , args .max_task , args .refresh , args . stopOnFailure )
331
334
332
335
333
336
if __name__ == "__main__" :
0 commit comments