 import os
 from multiprocessing.pool import ThreadPool
 import threading
+import _thread
+from queue import Queue
+import traceback
+import datetime
 
 import numpy as np
 import faiss
@@ -60,14 +64,21 @@ def toc(self):
 
     def report(self, l):
         if self.verbose == 1 or (
-                l > 1000 and time.time() < self.t_display + 1.0):
+            self.verbose == 2 and (
+                l > 1000 and time.time() < self.t_display + 1.0
+            )
+        ):
             return
+        t = time.time() - self.t0
         print(
-            f"[{time.time() - self.t0:.1f} s] list {l}/{self.index.nlist} "
+            f"[{t:.1f} s] list {l}/{self.index.nlist} "
             f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
             f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-            f"wait {self.t_accu[4]:.3f} ",
-            end="\r", flush=True
+            f"wait {self.t_accu[4]:.3f} "
+            f"eta {datetime.timedelta(seconds=t * self.index.nlist / (l + 1) - t)} "
+            f"mem {faiss.get_mem_usage_kb()}",
+            end="\r" if self.verbose <= 2 else "\n",
+            flush=True,
         )
         self.t_display = time.time()
 
@@ -141,24 +152,25 @@ def add_results_to_heap(self, q_subset, D, list_ids, I):
     def sizes_in_checkpoint(self):
         return (self.xq.shape, self.index.nprobe, self.index.nlist)
 
-    def write_checkpoint(self, fname, cur_list_no):
+    def write_checkpoint(self, fname, completed):
         # write to temp file then move to final file
         tmpname = fname + ".tmp"
-        pickle.dump(
-            {
-                "sizes": self.sizes_in_checkpoint(),
-                "cur_list_no": cur_list_no,
-                "rh": (self.rh.D, self.rh.I),
-            }, open(tmpname, "wb"), -1
-        )
+        with open(tmpname, "wb") as f:
+            pickle.dump(
+                {
+                    "sizes": self.sizes_in_checkpoint(),
+                    "completed": completed,
+                    "rh": (self.rh.D, self.rh.I),
+                }, f, -1)
         os.replace(tmpname, fname)
 
     def read_checkpoint(self, fname):
-        ckp = pickle.load(open(fname, "rb"))
+        with open(fname, "rb") as f:
+            ckp = pickle.load(f)
         assert ckp["sizes"] == self.sizes_in_checkpoint()
         self.rh.D[:] = ckp["rh"][0]
         self.rh.I[:] = ckp["rh"][1]
-        return ckp["cur_list_no"]
+        return ckp["completed"]
 
 
 class BlockComputer:
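The write-to-temp-then-rename idiom in `write_checkpoint` is what makes the checkpoint crash-safe: `os.replace` is an atomic rename on both POSIX and Windows, so a crash mid-write leaves the previous checkpoint intact rather than a truncated pickle. A minimal standalone sketch of the same pattern (function name and arguments are illustrative, not part of the patch):

import os
import pickle

def atomic_pickle_dump(obj, fname):
    # write the pickle to a sibling temp file first...
    tmpname = fname + ".tmp"
    with open(tmpname, "wb") as f:
        pickle.dump(obj, f, -1)
    # ...then atomically swap it into place; readers see either the
    # old file or the new one, never a partial write
    os.replace(tmpname, fname)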
@@ -225,11 +237,11 @@ def big_batch_search(
         verbose=0,
         threaded=0,
         use_float16=False,
-        prefetch_threads=8,
-        computation_threads=0,
+        prefetch_threads=1,
+        computation_threads=1,
         q_assign=None,
         checkpoint=None,
-        checkpoint_freq=64,
+        checkpoint_freq=7200,
         start_list=0,
         end_list=None,
         crash_at=-1
@@ -251,7 +263,7 @@ def big_batch_search(
 
     threaded=0: sequential execution
     threaded=1: prefetch next bucket while computing the current one
-    threaded>1: prefetch this many buckets at a time.
+    threaded=2: prefetch prefetch_threads buckets at a time.
 
     compute_threads>1: the knn function will get an additional thread_no that
     tells which worker should handle this.
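The `compute_threads>1` contract in the docstring above can be sketched with a hypothetical user-supplied knn function; the extra worker-index argument (called `thread_no` in the docstring, passed as `thread_id` inside the computation code below) lets the callback route work to a per-worker resource. Everything in this body is illustrative:

import faiss

def my_knn(xq, xb, k, thread_no=None):
    # thread_no identifies which worker invoked us, e.g. to pick a
    # dedicated GPU or a per-worker scratch buffer; here we just call
    # the stock CPU brute-force search
    return faiss.knn(xq, xb, k)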
@@ -311,12 +323,13 @@ def big_batch_search(
     if end_list is None:
         end_list = index.nlist
 
+    completed = set()
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
             print("recovering checkpoint", checkpoint)
-            start_list = bbs.read_checkpoint(checkpoint)
-            print("   start at list", start_list)
+            completed = bbs.read_checkpoint(checkpoint)
+            print("   already completed", len(completed))
         else:
             print("no checkpoint: starting from scratch")
 
@@ -363,94 +376,130 @@ def add_results_and_prefetch(to_add, l):
             bbs.add_results_to_heap(*to_add)
         pool.close()
     else:
-        # run by batches with parallel prefetch and parallel comp
-        list_step = threaded
-        assert start_list % list_step == 0
 
-        if prefetch_threads == 0:
-            prefetch_map = map
-        else:
-            prefetch_pool = ThreadPool(prefetch_threads)
-            prefetch_map = prefetch_pool.map
-
-        if computation_threads > 0:
-            comp_pool = ThreadPool(computation_threads)
-
-        def add_results_and_prefetch_batch(to_add, l):
-            def add_results(to_add):
-                for ta in to_add:  # this one cannot be run in parallel...
-                    if ta is not None:
-                        bbs.add_results_to_heap(*ta)
-            if prefetch_threads == 0:
-                add_results(to_add)
-            else:
-                add_a = prefetch_pool.apply_async(add_results, (to_add,))
-            next_lists = range(l, min(l + list_step, index.nlist))
-            res = list(prefetch_map(bbs.prepare_bucket, next_lists))
-            if prefetch_threads > 0:
-                add_a.get()
-            return res
-
-        # used only when computation_threads > 1
-        thread_id_to_seq_lock = threading.Lock()
-        thread_id_to_seq = {}
-
-        def do_comp(bucket):
-            (q_subset, xq_l, list_ids, xb_l) = bucket
+        def task_manager_thread(
+            task,
+            pool_size,
+            start_task,
+            end_task,
+            completed,
+            output_queue,
+            input_queue,
+        ):
             try:
-                tid = thread_id_to_seq[threading.get_ident()]
-            except KeyError:
-                with thread_id_to_seq_lock:
-                    tid = len(thread_id_to_seq)
-                    thread_id_to_seq[threading.get_ident()] = tid
-            D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
-            return q_subset, D, list_ids, I
-
-        prefetched_buckets = add_results_and_prefetch_batch([], start_list)
-        to_add = []
-        pool = ThreadPool(1)
-        prefetched_buckets_a = None
-
-        # loop over inverted lists
-        for l in range(start_list, end_list, list_step):
-            bbs.report(l)
-            buckets = prefetched_buckets
-            prefetched_buckets_a = pool.apply_async(
-                add_results_and_prefetch_batch, (to_add, l + list_step))
-
-            bbs.start_t_accu()
-
-            to_add = []
-            if computation_threads == 0:
-                for q_subset, xq_l, list_ids, xb_l in buckets:
-                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
-                    to_add.append((q_subset, D, list_ids, I))
-            else:
-                to_add = list(comp_pool.map(do_comp, buckets))
-
-            bbs.stop_t_accu(2)
+                with ThreadPool(pool_size) as pool:
+                    res = [pool.apply_async(
+                        task,
+                        args=(i, output_queue, input_queue))
+                        for i in range(start_task, end_task)
+                        if i not in completed]
+                    for r in res:
+                        r.get()
+                    pool.close()
+                    pool.join()
+                output_queue.put(None)
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def task_manager(*args):
+            task_manager = threading.Thread(
+                target=task_manager_thread,
+                args=args,
+            )
+            task_manager.daemon = True
+            task_manager.start()
+            return task_manager
+
+        def prepare_task(task_id, output_queue, input_queue=None):
+            try:
+                # print(f"Prepare start: {task_id}")
+                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
+                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
+                # print(f"Prepare end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def compute_task(task_id, output_queue, input_queue):
+            try:
+                # print(f"Compute start: {task_id}")
+                t_wait = 0
+                while True:
+                    t0 = time.time()
+                    input_value = input_queue.get()
+                    t_wait += time.time() - t0
+                    if input_value is None:
+                        # signal for other compute tasks
+                        input_queue.put(None)
+                        break
+                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
+                    # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    if computation_threads > 1:
+                        D, I = comp.block_search(
+                            xq_l, xb_l, list_ids, k, thread_id=task_id
+                        )
+                    else:
+                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
+                    t_compute = time.time() - t0
+                    # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    output_queue.put(
+                        (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    )
+                    t_wait = time.time() - t0
+                    # print(f"Compute end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        prepare_to_compute_queue = Queue(2)
+        compute_to_main_queue = Queue(2)
+        compute_task_manager = task_manager(
+            compute_task,
+            computation_threads,
+            0,
+            computation_threads,
+            set(),
+            compute_to_main_queue,
+            prepare_to_compute_queue,
+        )
+        prepare_task_manager = task_manager(
+            prepare_task,
+            prefetch_threads,
+            start_list,
+            end_list,
+            completed,
+            prepare_to_compute_queue,
+            None,
+        )
 
+        t_checkpoint = time.time()
+        while True:
+            value = compute_to_main_queue.get()
+            if not value:
+                break
+            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
-            if l == crash_at:
+            if centroid == crash_at:
                 1 / 0
-
-            bbs.start_t_accu()
-            prefetched_buckets = prefetched_buckets_a.get()
-            bbs.stop_t_accu(4)
-
+            bbs.t_accu[2] += t_compute
+            bbs.t_accu[4] += t_wait
+            bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            completed.add(centroid)
+            bbs.report(centroid)
             if checkpoint is not None:
-                if (l // list_step) % checkpoint_freq == 0:
-                    print("writing checkpoint %s" % l)
-                    bbs.write_checkpoint(checkpoint, l)
+                if time.time() - t_checkpoint > checkpoint_freq:
+                    print("writing checkpoint")
+                    bbs.write_checkpoint(checkpoint, completed)
+                    t_checkpoint = time.time()
 
-        # flush add
-        for ta in to_add:
-            bbs.add_results_to_heap(*ta)
-        pool.close()
-        if prefetch_threads != 0:
-            prefetch_pool.close()
-        if computation_threads != 0:
-            comp_pool.close()
+        prepare_task_manager.join()
+        compute_task_manager.join()
 
     bbs.tic("finalize heap")
     bbs.rh.finalize()
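Taken together, here is a minimal resume-from-checkpoint usage sketch for the patched function. Assumptions, flagged as such: this file is faiss's `contrib/big_batch_search.py` (importable as `faiss.contrib.big_batch_search`), the leading positional arguments are `index, xq, k`, the default knn wiring suffices, and the function returns the finalized `(D, I)` arrays; the toy data and index parameters are illustrative only.

# Hypothetical usage; see assumptions in the paragraph above.
import numpy as np
import faiss
from faiss.contrib.big_batch_search import big_batch_search

d = 32
rs = np.random.RandomState(123)
xb = rs.rand(100_000, d).astype("float32")  # database vectors
xq = rs.rand(5_000, d).astype("float32")    # query vectors

index = faiss.index_factory(d, "IVF1024,Flat")
index.train(xb)
index.add(xb)
index.nprobe = 16

# If bbs.ckpt exists, read_checkpoint() restores the result heap plus the
# set of already-completed inverted lists and only the remaining lists are
# searched; otherwise the search starts from scratch. A checkpoint is
# written at most once per checkpoint_freq seconds (new default: 7200).
D, I = big_batch_search(
    index, xq, 10,
    threaded=2,             # producer/consumer pipeline from this patch
    checkpoint="bbs.ckpt",
    checkpoint_freq=600,    # checkpoint every 10 minutes instead
    verbose=1,
)

If the process crashes or is killed partway through, re-running the same call picks up from the last checkpoint, which is exactly what the `completed` set replacing `cur_list_no` enables: inverted lists no longer need to finish in order.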