
Commit 092606b

algoriddle authored and facebook-github-bot committed
bbs producer/consumer threading (facebookresearch#2901)
Summary:
Pull Request resolved: facebookresearch#2901

This diff allows each GPU to work independently: a hot centroid (e.g. out-of-distribution queries that hit a centroid heavily) will only block the one GPU that is processing it, while the others continue to pick up work independently.

Reviewed By: mdouze

Differential Revision: D46521298

fbshipit-source-id: 171cb06cce8b2d16b7bd744799b105b3cd525be3
1 parent d8a6350 commit 092606b
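
For readers skimming the diff below: the commit replaces the old batched prefetch loop with a bounded-queue producer/consumer pipeline that is shut down by a None sentinel. Here is a minimal, self-contained sketch of that pattern; the names produce, consume, and producer_manager are illustrative, not identifiers from this commit:

    import threading
    from multiprocessing.pool import ThreadPool
    from queue import Queue

    work_queue = Queue(2)  # bounded, so producers block rather than race ahead

    def produce(task_id, output_queue):
        # stand-in for the real work, e.g. preparing a bucket
        output_queue.put((task_id, task_id * task_id))

    def producer_manager(n_tasks, output_queue):
        # run producers in a pool, then post a single shutdown sentinel
        with ThreadPool(2) as pool:
            results = [pool.apply_async(produce, (i, output_queue))
                       for i in range(n_tasks)]
            for r in results:
                r.get()
        output_queue.put(None)

    def consume(worker_id, input_queue):
        while True:
            item = input_queue.get()
            if item is None:
                input_queue.put(None)  # re-post the sentinel for the other consumers
                break
            task_id, payload = item
            print(f"worker {worker_id} handled task {task_id} -> {payload}")

    consumers = [threading.Thread(target=consume, args=(w, work_queue))
                 for w in range(2)]
    for c in consumers:
        c.start()
    producer_manager(8, work_queue)
    for c in consumers:
        c.join()

The re-posted sentinel is the same trick compute_task uses in the diff, so that a single None ends every worker in the pool.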

File tree

2 files changed: +157 −108 lines changed


contrib/big_batch_search.py (+150 −101)

@@ -8,6 +8,10 @@
 import os
 from multiprocessing.pool import ThreadPool
 import threading
+import _thread
+from queue import Queue
+import traceback
+import datetime
 
 import numpy as np
 import faiss
@@ -60,14 +64,21 @@ def toc(self):
 
     def report(self, l):
         if self.verbose == 1 or (
-                l > 1000 and time.time() < self.t_display + 1.0):
+            self.verbose == 2 and (
+                l > 1000 and time.time() < self.t_display + 1.0
+            )
+        ):
             return
+        t = time.time() - self.t0
         print(
-            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
+            f"[{t:.1f} s] list {l}/{self.index.nlist} "
             f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
             f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-            f"wait {self.t_accu[4]:.3f}",
-            end="\r", flush=True
+            f"wait {self.t_accu[4]:.3f} "
+            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
+            f"mem {faiss.get_mem_usage_kb()}",
+            end="\r" if self.verbose <= 2 else "\n",
+            flush=True,
         )
         self.t_display = time.time()
 
@@ -141,24 +152,25 @@ def add_results_to_heap(self, q_subset, D, list_ids, I):
     def sizes_in_checkpoint(self):
         return (self.xq.shape, self.index.nprobe, self.index.nlist)
 
-    def write_checkpoint(self, fname, cur_list_no):
+    def write_checkpoint(self, fname, completed):
         # write to temp file then move to final file
         tmpname = fname + ".tmp"
-        pickle.dump(
-            {
-                "sizes": self.sizes_in_checkpoint(),
-                "cur_list_no": cur_list_no,
-                "rh": (self.rh.D, self.rh.I),
-            }, open(tmpname, "wb"), -1
-        )
+        with open(tmpname, "wb") as f:
+            pickle.dump(
+                {
+                    "sizes": self.sizes_in_checkpoint(),
+                    "completed": completed,
+                    "rh": (self.rh.D, self.rh.I),
+                }, f, -1)
         os.replace(tmpname, fname)
 
     def read_checkpoint(self, fname):
-        ckp = pickle.load(open(fname, "rb"))
+        with open(fname, "rb") as f:
+            ckp = pickle.load(f)
         assert ckp["sizes"] == self.sizes_in_checkpoint()
         self.rh.D[:] = ckp["rh"][0]
         self.rh.I[:] = ckp["rh"][1]
-        return ckp["cur_list_no"]
+        return ckp["completed"]
 
 
 class BlockComputer:
@@ -225,11 +237,11 @@ def big_batch_search(
         verbose=0,
         threaded=0,
         use_float16=False,
-        prefetch_threads=8,
-        computation_threads=0,
+        prefetch_threads=1,
+        computation_threads=1,
         q_assign=None,
         checkpoint=None,
-        checkpoint_freq=64,
+        checkpoint_freq=7200,
         start_list=0,
         end_list=None,
         crash_at=-1
@@ -251,7 +263,7 @@ def big_batch_search(
 
     threaded=0: sequential execution
     threaded=1: prefetch next bucket while computing the current one
-    threaded>1: prefetch this many buckets at a time.
+    threaded=2: prefetch prefetch_threads buckets at a time.
 
     compute_threads>1: the knn function will get an additional thread_no that
     tells which worker should handle this.
@@ -311,12 +323,13 @@ def big_batch_search(
     if end_list is None:
         end_list = index.nlist
 
+    completed = set()
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
             print("recovering checkpoint", checkpoint)
-            start_list = bbs.read_checkpoint(checkpoint)
-            print("   start at list", start_list)
+            completed = bbs.read_checkpoint(checkpoint)
+            print("   already completed", len(completed))
         else:
             print("no checkpoint: starting from scratch")
 
@@ -363,94 +376,130 @@ def add_results_and_prefetch(to_add, l):
         bbs.add_results_to_heap(*to_add)
         pool.close()
     else:
-        # run by batches with parallel prefetch and parallel comp
-        list_step = threaded
-        assert start_list % list_step == 0
 
-        if prefetch_threads == 0:
-            prefetch_map = map
-        else:
-            prefetch_pool = ThreadPool(prefetch_threads)
-            prefetch_map = prefetch_pool.map
-
-        if computation_threads > 0:
-            comp_pool = ThreadPool(computation_threads)
-
-        def add_results_and_prefetch_batch(to_add, l):
-            def add_results(to_add):
-                for ta in to_add:  # this one cannot be run in parallel...
-                    if ta is not None:
-                        bbs.add_results_to_heap(*ta)
-            if prefetch_threads == 0:
-                add_results(to_add)
-            else:
-                add_a = prefetch_pool.apply_async(add_results, (to_add, ))
-            next_lists = range(l, min(l + list_step, index.nlist))
-            res = list(prefetch_map(bbs.prepare_bucket, next_lists))
-            if prefetch_threads > 0:
-                add_a.get()
-            return res
-
-        # used only when computation_threads > 1
-        thread_id_to_seq_lock = threading.Lock()
-        thread_id_to_seq = {}
-
-        def do_comp(bucket):
-            (q_subset, xq_l, list_ids, xb_l) = bucket
+        def task_manager_thread(
+            task,
+            pool_size,
+            start_task,
+            end_task,
+            completed,
+            output_queue,
+            input_queue,
+        ):
             try:
-                tid = thread_id_to_seq[threading.get_ident()]
-            except KeyError:
-                with thread_id_to_seq_lock:
-                    tid = len(thread_id_to_seq)
-                    thread_id_to_seq[threading.get_ident()] = tid
-            D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
-            return q_subset, D, list_ids, I
-
-        prefetched_buckets = add_results_and_prefetch_batch([], start_list)
-        to_add = []
-        pool = ThreadPool(1)
-        prefetched_buckets_a = None
-
-        # loop over inverted lists
-        for l in range(start_list, end_list, list_step):
-            bbs.report(l)
-            buckets = prefetched_buckets
-            prefetched_buckets_a = pool.apply_async(
-                add_results_and_prefetch_batch, (to_add, l + list_step))
-
-            bbs.start_t_accu()
-
-            to_add = []
-            if computation_threads == 0:
-                for q_subset, xq_l, list_ids, xb_l in buckets:
-                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
-                    to_add.append((q_subset, D, list_ids, I))
-            else:
-                to_add = list(comp_pool.map(do_comp, buckets))
-
-            bbs.stop_t_accu(2)
+                with ThreadPool(pool_size) as pool:
+                    res = [pool.apply_async(
+                        task,
+                        args=(i, output_queue, input_queue))
+                        for i in range(start_task, end_task)
+                        if i not in completed]
+                    for r in res:
+                        r.get()
+                    pool.close()
+                    pool.join()
+                output_queue.put(None)
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def task_manager(*args):
+            task_manager = threading.Thread(
+                target=task_manager_thread,
+                args=args,
+            )
+            task_manager.daemon = True
+            task_manager.start()
+            return task_manager
+
+        def prepare_task(task_id, output_queue, input_queue=None):
+            try:
+                # print(f"Prepare start: {task_id}")
+                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
+                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
+                # print(f"Prepare end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def compute_task(task_id, output_queue, input_queue):
+            try:
+                # print(f"Compute start: {task_id}")
+                t_wait = 0
+                while True:
+                    t0 = time.time()
+                    input_value = input_queue.get()
+                    t_wait += time.time() - t0
+                    if input_value is None:
+                        # signal for other compute tasks
+                        input_queue.put(None)
+                        break
+                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
+                    # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    if computation_threads > 1:
+                        D, I = comp.block_search(
+                            xq_l, xb_l, list_ids, k, thread_id=task_id
+                        )
+                    else:
+                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
+                    t_compute = time.time() - t0
+                    # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    output_queue.put(
+                        (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    )
+                    t_wait = time.time() - t0
+                # print(f"Compute end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        prepare_to_compute_queue = Queue(2)
+        compute_to_main_queue = Queue(2)
+        compute_task_manager = task_manager(
+            compute_task,
+            computation_threads,
+            0,
+            computation_threads,
+            set(),
+            compute_to_main_queue,
+            prepare_to_compute_queue,
+        )
+        prepare_task_manager = task_manager(
+            prepare_task,
+            prefetch_threads,
+            start_list,
+            end_list,
+            completed,
+            prepare_to_compute_queue,
+            None,
+        )
 
+        t_checkpoint = time.time()
+        while True:
+            value = compute_to_main_queue.get()
+            if not value:
+                break
+            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
-            if l == crash_at:
+            if centroid == crash_at:
                 1 / 0
-
-            bbs.start_t_accu()
-            prefetched_buckets = prefetched_buckets_a.get()
-            bbs.stop_t_accu(4)
-
+            bbs.t_accu[2] += t_compute
+            bbs.t_accu[4] += t_wait
+            bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            completed.add(centroid)
+            bbs.report(centroid)
             if checkpoint is not None:
-                if (l // list_step) % checkpoint_freq == 0:
-                    print("writing checkpoint %s" % l)
-                    bbs.write_checkpoint(checkpoint, l)
+                if time.time() - t_checkpoint > checkpoint_freq:
+                    print("writing checkpoint")
+                    bbs.write_checkpoint(checkpoint, completed)
+                    t_checkpoint = time.time()
 
-        # flush add
-        for ta in to_add:
-            bbs.add_results_to_heap(*ta)
-        pool.close()
-        if prefetch_threads != 0:
-            prefetch_pool.close()
-        if computation_threads != 0:
-            comp_pool.close()
+        prepare_task_manager.join()
+        compute_task_manager.join()
 
     bbs.tic("finalize heap")
     bbs.rh.finalize()
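
Taken together, a hedged sketch of how the reworked entry point might be called after this change, mirroring the updated test below. The dataset setup is assumed for illustration and is not part of this commit; the comments on parameter semantics reflect the diff above (threaded=2 now selects the producer/consumer pipeline, and checkpoint_freq is now a time interval in seconds rather than a list count):

    import numpy as np
    import faiss
    from faiss.contrib import big_batch_search

    # Assumed setup, not part of this commit: a small trained IVF index.
    d, nlist, k = 32, 64, 10
    xb = np.random.rand(10000, d).astype('float32')
    xq = np.random.rand(100, d).astype('float32')
    index = faiss.index_factory(d, f"IVF{nlist},Flat")
    index.train(xb)
    index.add(xb)

    D, I = big_batch_search.big_batch_search(
        index, xq, k,
        method="knn_function",
        threaded=2,                  # producer/consumer pipeline
        prefetch_threads=4,          # producers preparing buckets
        computation_threads=2,       # consumers, e.g. one per GPU
        checkpoint="/tmp/bbs.ckpt",  # resumable across restarts
        checkpoint_freq=7200,        # seconds between checkpoint writes
    )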

tests/test_contrib.py (+7 −7)

@@ -9,6 +9,7 @@
 import platform
 import os
 import random
+import tempfile
 
 from faiss.contrib import datasets
 from faiss.contrib import inspect_tools
@@ -507,7 +508,7 @@ def do_test(self, factory_string, metric=faiss.METRIC_L2):
         Dref, Iref = index.search(ds.get_queries(), k)
         # faiss.omp_set_num_threads(1)
         for method in ("pairwise_distances", "knn_function", "index"):
-            for threaded in 0, 1, 3, 8:
+            for threaded in 0, 1, 2:
                 Dnew, Inew = big_batch_search.big_batch_search(
                     index, ds.get_queries(),
                     k, method=method,
@@ -537,16 +538,15 @@ def test_checkpoint(self):
         index.nprobe = 5
         Dref, Iref = index.search(ds.get_queries(), k)
 
-        r = random.randrange(1 << 60)
-        checkpoint = "/tmp/test_big_batch_checkpoint.%d" % r
+        checkpoint = tempfile.mktemp()
         try:
             # First big batch search
             try:
                 Dnew, Inew = big_batch_search.big_batch_search(
                     index, ds.get_queries(),
                     k, method="knn_function",
-                    threaded=4,
-                    checkpoint=checkpoint, checkpoint_freq=4,
+                    threaded=2,
+                    checkpoint=checkpoint, checkpoint_freq=0.1,
                     crash_at=20
                 )
             except ZeroDivisionError:
@@ -557,8 +557,8 @@ def test_checkpoint(self):
             Dnew, Inew = big_batch_search.big_batch_search(
                 index, ds.get_queries(),
                 k, method="knn_function",
-                threaded=4,
-                checkpoint=checkpoint, checkpoint_freq=4
+                threaded=2,
+                checkpoint=checkpoint, checkpoint_freq=5
             )
         self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4)
         np.testing.assert_almost_equal(Dnew, Dref, decimal=4)
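
test_checkpoint above exercises the reworked checkpointing: a deliberate crash (crash_at=20), then a second run that resumes from the set of completed centroids rather than a scan cursor. A minimal sketch of the underlying atomic-write idiom from write_checkpoint/read_checkpoint; save_checkpoint and load_checkpoint are illustrative names, not functions from this commit:

    import os
    import pickle

    def save_checkpoint(fname, completed):
        # write to a temp file, then rename: a crash never leaves a torn file
        tmpname = fname + ".tmp"
        with open(tmpname, "wb") as f:
            pickle.dump({"completed": completed}, f, -1)
        os.replace(tmpname, fname)  # atomic replacement

    def load_checkpoint(fname):
        if not os.path.exists(fname):
            return set()  # no checkpoint: starting from scratch
        with open(fname, "rb") as f:
            return pickle.load(f)["completed"]

    ckpt = "/tmp/demo.ckpt"
    completed = load_checkpoint(ckpt)
    for centroid in range(10):
        if centroid in completed:
            continue  # skip work already done before a restart
        completed.add(centroid)
        save_checkpoint(ckpt, completed)

Storing the set of completed centroids is what lets out-of-order completion (the whole point of the producer/consumer rework) survive a restart.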
