Skip to content

Commit 1d0e8d4

Browse files
algoriddle authored and facebook-github-bot committed
index optimizer (facebookresearch#3154)
Summary:
Pull Request resolved: facebookresearch#3154

Use the benchmark to find Pareto-optimal indices, in this case on BigANN as an example. The coarse quantizer and the vector codec are optimized separately, and their Pareto-optimal configurations are used to construct IVF indices, which are then retested at various scales. See `optimize()` in `optimize.py` for the main function driving the process.

The results can be interpreted with `bench_fw_notebook.ipynb`, which allows:
* filtering by maximum code size
* filtering by maximum time
* filtering by minimum accuracy
* selecting space or time Pareto-optimal options
* visualizing the results and outputting them as a table

This version is intentionally limited to IVF(Flat|HNSW),PQ|SQ indices...

Reviewed By: mdouze

Differential Revision: D51781670

fbshipit-source-id: 2c0f800d374ea845255934f519cc28095c00a51f
1 parent 75ae0bf commit 1d0e8d4

8 files changed

+1318 -750 lines changed

benchs/bench_fw/benchmark.py

+90 -32
Original file line number | Diff line number | Diff line change
@@ -7,19 +7,20 @@
77
from copy import copy
88
from dataclasses import dataclass
99
from operator import itemgetter
10-
from statistics import median, mean
10+
from statistics import mean, median
1111
from typing import Any, Dict, List, Optional
1212

13-
from .utils import dict_merge
14-
from .index import Index, IndexFromCodec, IndexFromFactory
15-
from .descriptors import DatasetDescriptor, IndexDescriptor
16-
1713
import faiss # @manual=//faiss/python:pyfaiss_gpu
1814

1915
import numpy as np
2016

2117
from scipy.optimize import curve_fit
2218

19+
from .descriptors import DatasetDescriptor, IndexDescriptor
20+
from .index import Index, IndexFromCodec, IndexFromFactory
21+
22+
from .utils import dict_merge
23+
2324
logger = logging.getLogger(__name__)
2425

2526

@@ -274,8 +275,8 @@ def range_search(
274275
search_parameters: Optional[Dict[str, int]],
275276
radius: Optional[float] = None,
276277
gt_radius: Optional[float] = None,
277-
range_search_metric_function = None,
278-
gt_rsm = None,
278+
range_search_metric_function=None,
279+
gt_rsm=None,
279280
):
280281
logger.info("range_search: begin")
281282
if radius is None:
@@ -328,7 +329,13 @@ def knn_ground_truth(self):
328329
logger.info("knn_ground_truth: begin")
329330
flat_desc = self.get_index_desc("Flat")
330331
self.build_index_wrapper(flat_desc)
331-
self.gt_knn_D, self.gt_knn_I, _, _, requires = flat_desc.index.knn_search(
332+
(
333+
self.gt_knn_D,
334+
self.gt_knn_I,
335+
_,
336+
_,
337+
requires,
338+
) = flat_desc.index.knn_search(
332339
dry_run=False,
333340
search_parameters=None,
334341
query_vectors=self.query_vectors,
@@ -338,13 +345,13 @@ def knn_ground_truth(self):
338345
logger.info("knn_ground_truth: end")
339346

340347
def search_benchmark(
341-
self,
348+
self,
342349
name,
343350
search_func,
344351
key_func,
345352
cost_metrics,
346353
perf_metrics,
347-
results: Dict[str, Any],
354+
results: Dict[str, Any],
348355
index: Index,
349356
):
350357
index_name = index.get_index_name()
@@ -376,11 +383,18 @@ def experiment(parameters, cost_metric, perf_metric):
376383
logger.info(f"{name}_benchmark: end")
377384
return results, requires
378385

379-
def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
386+
def knn_search_benchmark(
387+
self, dry_run, results: Dict[str, Any], index: Index
388+
):
380389
return self.search_benchmark(
381390
name="knn_search",
382391
search_func=lambda parameters: index.knn_search(
383-
dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I, self.gt_knn_D,
392+
dry_run,
393+
parameters,
394+
self.query_vectors,
395+
self.k,
396+
self.gt_knn_I,
397+
self.gt_knn_D,
384398
)[3:],
385399
key_func=lambda parameters: index.get_knn_search_name(
386400
search_parameters=parameters,
@@ -394,11 +408,17 @@ def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
394408
index=index,
395409
)
396410

397-
def reconstruct_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
411+
def reconstruct_benchmark(
412+
self, dry_run, results: Dict[str, Any], index: Index
413+
):
398414
return self.search_benchmark(
399415
name="reconstruct",
400416
search_func=lambda parameters: index.reconstruct(
401-
dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I,
417+
dry_run,
418+
parameters,
419+
self.query_vectors,
420+
self.k,
421+
self.gt_knn_I,
402422
),
403423
key_func=lambda parameters: index.get_knn_search_name(
404424
search_parameters=parameters,
@@ -426,31 +446,33 @@ def range_search_benchmark(
426446
return self.search_benchmark(
427447
name="range_search",
428448
search_func=lambda parameters: self.range_search(
429-
dry_run=dry_run,
430-
index=index,
431-
search_parameters=parameters,
449+
dry_run=dry_run,
450+
index=index,
451+
search_parameters=parameters,
432452
radius=radius,
433453
gt_radius=gt_radius,
434-
range_search_metric_function=range_search_metric_function,
454+
range_search_metric_function=range_search_metric_function,
435455
gt_rsm=gt_rsm,
436456
)[4:],
437457
key_func=lambda parameters: index.get_range_search_name(
438458
search_parameters=parameters,
439459
query_vectors=self.query_vectors,
440460
radius=radius,
441-
) + metric_key,
461+
)
462+
+ metric_key,
442463
cost_metrics=["time"],
443464
perf_metrics=["range_score_max_recall"],
444465
results=results,
445466
index=index,
446467
)
447468

448469
def build_index_wrapper(self, index_desc: IndexDescriptor):
449-
if hasattr(index_desc, 'index'):
470+
if hasattr(index_desc, "index"):
450471
return
451472
if index_desc.factory is not None:
452473
training_vectors = copy(self.training_vectors)
453-
training_vectors.num_vectors = index_desc.training_size
474+
if index_desc.training_size is not None:
475+
training_vectors.num_vectors = index_desc.training_size
454476
index = IndexFromFactory(
455477
num_threads=self.num_threads,
456478
d=self.d,
@@ -481,15 +503,24 @@ def clone_one(self, index_desc):
481503
training_vectors=self.training_vectors,
482504
database_vectors=self.database_vectors,
483505
query_vectors=self.query_vectors,
484-
index_descs = [self.get_index_desc("Flat"), index_desc],
506+
index_descs=[self.get_index_desc("Flat"), index_desc],
485507
range_ref_index_desc=self.range_ref_index_desc,
486508
k=self.k,
487509
distance_metric=self.distance_metric,
488510
)
489-
benchmark.set_io(self.io)
511+
benchmark.set_io(self.io.clone())
490512
return benchmark
491513

492-
def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescriptor, train, reconstruct, knn, range):
514+
def benchmark_one(
515+
self,
516+
dry_run,
517+
results: Dict[str, Any],
518+
index_desc: IndexDescriptor,
519+
train,
520+
reconstruct,
521+
knn,
522+
range,
523+
):
493524
faiss.omp_set_num_threads(self.num_threads)
494525
if not dry_run:
495526
self.knn_ground_truth()
@@ -531,9 +562,12 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
531562
)
532563
assert requires is None
533564

534-
if self.range_ref_index_desc is None or not index_desc.index.supports_range_search():
565+
if (
566+
self.range_ref_index_desc is None
567+
or not index_desc.index.supports_range_search()
568+
):
535569
return results, None
536-
570+
537571
ref_index_desc = self.get_index_desc(self.range_ref_index_desc)
538572
if ref_index_desc is None:
539573
raise ValueError(
@@ -550,7 +584,9 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
550584
coefficients,
551585
coefficients_training_data,
552586
) = self.range_search_reference(
553-
ref_index_desc.index, ref_index_desc.search_params, range_metric
587+
ref_index_desc.index,
588+
ref_index_desc.search_params,
589+
range_metric,
554590
)
555591
gt_rsm = self.range_ground_truth(
556592
gt_radius, range_search_metric_function
@@ -583,7 +619,15 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
583619

584620
return results, None
585621

586-
def benchmark(self, result_file=None, local=False, train=False, reconstruct=False, knn=False, range=False):
622+
def benchmark(
623+
self,
624+
result_file=None,
625+
local=False,
626+
train=False,
627+
reconstruct=False,
628+
knn=False,
629+
range=False,
630+
):
587631
logger.info("begin evaluate")
588632

589633
faiss.omp_set_num_threads(self.num_threads)
@@ -656,20 +700,34 @@ def benchmark(self, result_file=None, local=False, train=False, reconstruct=Fals
656700

657701
if current_todo:
658702
results_one = {"indices": {}, "experiments": {}}
659-
params = [(self.clone_one(index_desc), results_one, index_desc, train, reconstruct, knn, range) for index_desc in current_todo]
660-
for result in self.io.launch_jobs(run_benchmark_one, params, local=local):
703+
params = [
704+
(
705+
index_desc,
706+
self.clone_one(index_desc),
707+
results_one,
708+
train,
709+
reconstruct,
710+
knn,
711+
range,
712+
)
713+
for index_desc in current_todo
714+
]
715+
for result in self.io.launch_jobs(
716+
run_benchmark_one, params, local=local
717+
):
661718
dict_merge(results, result)
662719

663-
todo = next_todo
720+
todo = next_todo
664721

665722
if result_file is not None:
666723
self.io.write_json(results, result_file, overwrite=True)
667724
logger.info("end evaluate")
668725
return results
669726

727+
670728
def run_benchmark_one(params):
671729
logger.info(params)
672-
benchmark, results, index_desc, train, reconstruct, knn, range = params
730+
index_desc, benchmark, results, train, reconstruct, knn, range = params
673731
results, requires = benchmark.benchmark_one(
674732
dry_run=False,
675733
results=results,

benchs/bench_fw/benchmark_io.py

+23 -11
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,13 @@
1010
import os
1111
import pickle
1212
from dataclasses import dataclass
13-
import submitit
1413
from typing import Any, List, Optional
1514
from zipfile import ZipFile
1615

1716
import faiss # @manual=//faiss/python:pyfaiss_gpu
1817

1918
import numpy as np
19+
import submitit
2020
from faiss.contrib.datasets import ( # @manual=//faiss/contrib:faiss_contrib_gpu
2121
dataset_from_name,
2222
)
@@ -47,6 +47,9 @@ def merge_rcq_itq(
4747
class BenchmarkIO:
4848
path: str
4949

50+
def clone(self):
51+
return BenchmarkIO(path=self.path)
52+
5053
def __post_init__(self):
5154
self.cached_ds = {}
5255

@@ -119,18 +122,27 @@ def write_file(
119122

120123
def get_dataset(self, dataset):
121124
if dataset not in self.cached_ds:
122-
if dataset.namespace is not None and dataset.namespace[:4] == "std_":
125+
if (
126+
dataset.namespace is not None
127+
and dataset.namespace[:4] == "std_"
128+
):
123129
if dataset.tablename not in self.cached_ds:
124130
self.cached_ds[dataset.tablename] = dataset_from_name(
125131
dataset.tablename,
126132
)
127133
p = dataset.namespace[4]
128134
if p == "t":
129-
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_train(dataset.num_vectors)
135+
self.cached_ds[dataset] = self.cached_ds[
136+
dataset.tablename
137+
].get_train(dataset.num_vectors)
130138
elif p == "d":
131-
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_database()
139+
self.cached_ds[dataset] = self.cached_ds[
140+
dataset.tablename
141+
].get_database()
132142
elif p == "q":
133-
self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_queries()
143+
self.cached_ds[dataset] = self.cached_ds[
144+
dataset.tablename
145+
].get_queries()
134146
else:
135147
raise ValueError
136148
elif dataset.namespace == "syn":
@@ -233,8 +245,8 @@ def launch_jobs(self, func, params, local=True):
233245
if local:
234246
results = [func(p) for p in params]
235247
return results
236-
print(f'launching {len(params)} jobs')
237-
executor = submitit.AutoExecutor(folder='/checkpoint/gsz/jobs')
248+
logger.info(f"launching {len(params)} jobs")
249+
executor = submitit.AutoExecutor(folder="/checkpoint/gsz/jobs")
238250
executor.update_parameters(
239251
nodes=1,
240252
gpus_per_node=8,
@@ -248,9 +260,9 @@ def launch_jobs(self, func, params, local=True):
248260
slurm_constraint="bldg1",
249261
)
250262
jobs = executor.map_array(func, params)
251-
print(f'launched {len(jobs)} jobs')
252-
# for job, param in zip(jobs, params):
253-
# print(f"{job.job_id=} {param=}")
263+
logger.info(f"launched {len(jobs)} jobs")
264+
for job, param in zip(jobs, params):
265+
logger.info(f"{job.job_id=} {param[0]=}")
254266
results = [job.result() for job in jobs]
255-
print(f'received {len(results)} results')
267+
print(f"received {len(results)} results")
256268
return results

benchs/bench_fw/descriptors.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import faiss # @manual=//faiss/python:pyfaiss_gpu
1111
from .utils import timer
12+
1213
logger = logging.getLogger(__name__)
1314

1415

@@ -101,7 +102,9 @@ def k_means(self, io, k, dry_run):
101102
tablename=f"{self.get_filename()}kmeans_{k}.npy"
102103
)
103104
meta_filename = kmeans_vectors.tablename + ".json"
104-
if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(meta_filename):
105+
if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(
106+
meta_filename
107+
):
105108
if dry_run:
106109
return None, None, kmeans_vectors.tablename
107110
x = io.get_dataset(self)

0 commit comments

Comments
 (0)