7
7
from copy import copy
8
8
from dataclasses import dataclass
9
9
from operator import itemgetter
10
- from statistics import median , mean
10
+ from statistics import mean , median
11
11
from typing import Any , Dict , List , Optional
12
12
13
- from .utils import dict_merge
14
- from .index import Index , IndexFromCodec , IndexFromFactory
15
- from .descriptors import DatasetDescriptor , IndexDescriptor
16
-
17
13
import faiss # @manual=//faiss/python:pyfaiss_gpu
18
14
19
15
import numpy as np
20
16
21
17
from scipy .optimize import curve_fit
22
18
19
+ from .descriptors import DatasetDescriptor , IndexDescriptor
20
+ from .index import Index , IndexFromCodec , IndexFromFactory
21
+
22
+ from .utils import dict_merge
23
+
23
24
logger = logging .getLogger (__name__ )
24
25
25
26
@@ -274,8 +275,8 @@ def range_search(
274
275
search_parameters : Optional [Dict [str , int ]],
275
276
radius : Optional [float ] = None ,
276
277
gt_radius : Optional [float ] = None ,
277
- range_search_metric_function = None ,
278
- gt_rsm = None ,
278
+ range_search_metric_function = None ,
279
+ gt_rsm = None ,
279
280
):
280
281
logger .info ("range_search: begin" )
281
282
if radius is None :
@@ -328,7 +329,13 @@ def knn_ground_truth(self):
328
329
logger .info ("knn_ground_truth: begin" )
329
330
flat_desc = self .get_index_desc ("Flat" )
330
331
self .build_index_wrapper (flat_desc )
331
- self .gt_knn_D , self .gt_knn_I , _ , _ , requires = flat_desc .index .knn_search (
332
+ (
333
+ self .gt_knn_D ,
334
+ self .gt_knn_I ,
335
+ _ ,
336
+ _ ,
337
+ requires ,
338
+ ) = flat_desc .index .knn_search (
332
339
dry_run = False ,
333
340
search_parameters = None ,
334
341
query_vectors = self .query_vectors ,
@@ -338,13 +345,13 @@ def knn_ground_truth(self):
338
345
logger .info ("knn_ground_truth: end" )
339
346
340
347
def search_benchmark (
341
- self ,
348
+ self ,
342
349
name ,
343
350
search_func ,
344
351
key_func ,
345
352
cost_metrics ,
346
353
perf_metrics ,
347
- results : Dict [str , Any ],
354
+ results : Dict [str , Any ],
348
355
index : Index ,
349
356
):
350
357
index_name = index .get_index_name ()
@@ -376,11 +383,18 @@ def experiment(parameters, cost_metric, perf_metric):
376
383
logger .info (f"{ name } _benchmark: end" )
377
384
return results , requires
378
385
379
- def knn_search_benchmark (self , dry_run , results : Dict [str , Any ], index : Index ):
386
+ def knn_search_benchmark (
387
+ self , dry_run , results : Dict [str , Any ], index : Index
388
+ ):
380
389
return self .search_benchmark (
381
390
name = "knn_search" ,
382
391
search_func = lambda parameters : index .knn_search (
383
- dry_run , parameters , self .query_vectors , self .k , self .gt_knn_I , self .gt_knn_D ,
392
+ dry_run ,
393
+ parameters ,
394
+ self .query_vectors ,
395
+ self .k ,
396
+ self .gt_knn_I ,
397
+ self .gt_knn_D ,
384
398
)[3 :],
385
399
key_func = lambda parameters : index .get_knn_search_name (
386
400
search_parameters = parameters ,
@@ -394,11 +408,17 @@ def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index):
394
408
index = index ,
395
409
)
396
410
397
- def reconstruct_benchmark (self , dry_run , results : Dict [str , Any ], index : Index ):
411
+ def reconstruct_benchmark (
412
+ self , dry_run , results : Dict [str , Any ], index : Index
413
+ ):
398
414
return self .search_benchmark (
399
415
name = "reconstruct" ,
400
416
search_func = lambda parameters : index .reconstruct (
401
- dry_run , parameters , self .query_vectors , self .k , self .gt_knn_I ,
417
+ dry_run ,
418
+ parameters ,
419
+ self .query_vectors ,
420
+ self .k ,
421
+ self .gt_knn_I ,
402
422
),
403
423
key_func = lambda parameters : index .get_knn_search_name (
404
424
search_parameters = parameters ,
@@ -426,31 +446,33 @@ def range_search_benchmark(
426
446
return self .search_benchmark (
427
447
name = "range_search" ,
428
448
search_func = lambda parameters : self .range_search (
429
- dry_run = dry_run ,
430
- index = index ,
431
- search_parameters = parameters ,
449
+ dry_run = dry_run ,
450
+ index = index ,
451
+ search_parameters = parameters ,
432
452
radius = radius ,
433
453
gt_radius = gt_radius ,
434
- range_search_metric_function = range_search_metric_function ,
454
+ range_search_metric_function = range_search_metric_function ,
435
455
gt_rsm = gt_rsm ,
436
456
)[4 :],
437
457
key_func = lambda parameters : index .get_range_search_name (
438
458
search_parameters = parameters ,
439
459
query_vectors = self .query_vectors ,
440
460
radius = radius ,
441
- ) + metric_key ,
461
+ )
462
+ + metric_key ,
442
463
cost_metrics = ["time" ],
443
464
perf_metrics = ["range_score_max_recall" ],
444
465
results = results ,
445
466
index = index ,
446
467
)
447
468
448
469
def build_index_wrapper (self , index_desc : IndexDescriptor ):
449
- if hasattr (index_desc , ' index' ):
470
+ if hasattr (index_desc , " index" ):
450
471
return
451
472
if index_desc .factory is not None :
452
473
training_vectors = copy (self .training_vectors )
453
- training_vectors .num_vectors = index_desc .training_size
474
+ if index_desc .training_size is not None :
475
+ training_vectors .num_vectors = index_desc .training_size
454
476
index = IndexFromFactory (
455
477
num_threads = self .num_threads ,
456
478
d = self .d ,
@@ -481,15 +503,24 @@ def clone_one(self, index_desc):
481
503
training_vectors = self .training_vectors ,
482
504
database_vectors = self .database_vectors ,
483
505
query_vectors = self .query_vectors ,
484
- index_descs = [self .get_index_desc ("Flat" ), index_desc ],
506
+ index_descs = [self .get_index_desc ("Flat" ), index_desc ],
485
507
range_ref_index_desc = self .range_ref_index_desc ,
486
508
k = self .k ,
487
509
distance_metric = self .distance_metric ,
488
510
)
489
- benchmark .set_io (self .io )
511
+ benchmark .set_io (self .io . clone () )
490
512
return benchmark
491
513
492
- def benchmark_one (self , dry_run , results : Dict [str , Any ], index_desc : IndexDescriptor , train , reconstruct , knn , range ):
514
+ def benchmark_one (
515
+ self ,
516
+ dry_run ,
517
+ results : Dict [str , Any ],
518
+ index_desc : IndexDescriptor ,
519
+ train ,
520
+ reconstruct ,
521
+ knn ,
522
+ range ,
523
+ ):
493
524
faiss .omp_set_num_threads (self .num_threads )
494
525
if not dry_run :
495
526
self .knn_ground_truth ()
@@ -531,9 +562,12 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
531
562
)
532
563
assert requires is None
533
564
534
- if self .range_ref_index_desc is None or not index_desc .index .supports_range_search ():
565
+ if (
566
+ self .range_ref_index_desc is None
567
+ or not index_desc .index .supports_range_search ()
568
+ ):
535
569
return results , None
536
-
570
+
537
571
ref_index_desc = self .get_index_desc (self .range_ref_index_desc )
538
572
if ref_index_desc is None :
539
573
raise ValueError (
@@ -550,7 +584,9 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
550
584
coefficients ,
551
585
coefficients_training_data ,
552
586
) = self .range_search_reference (
553
- ref_index_desc .index , ref_index_desc .search_params , range_metric
587
+ ref_index_desc .index ,
588
+ ref_index_desc .search_params ,
589
+ range_metric ,
554
590
)
555
591
gt_rsm = self .range_ground_truth (
556
592
gt_radius , range_search_metric_function
@@ -583,7 +619,15 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr
583
619
584
620
return results , None
585
621
586
- def benchmark (self , result_file = None , local = False , train = False , reconstruct = False , knn = False , range = False ):
622
+ def benchmark (
623
+ self ,
624
+ result_file = None ,
625
+ local = False ,
626
+ train = False ,
627
+ reconstruct = False ,
628
+ knn = False ,
629
+ range = False ,
630
+ ):
587
631
logger .info ("begin evaluate" )
588
632
589
633
faiss .omp_set_num_threads (self .num_threads )
@@ -656,20 +700,34 @@ def benchmark(self, result_file=None, local=False, train=False, reconstruct=Fals
656
700
657
701
if current_todo :
658
702
results_one = {"indices" : {}, "experiments" : {}}
659
- params = [(self .clone_one (index_desc ), results_one , index_desc , train , reconstruct , knn , range ) for index_desc in current_todo ]
660
- for result in self .io .launch_jobs (run_benchmark_one , params , local = local ):
703
+ params = [
704
+ (
705
+ index_desc ,
706
+ self .clone_one (index_desc ),
707
+ results_one ,
708
+ train ,
709
+ reconstruct ,
710
+ knn ,
711
+ range ,
712
+ )
713
+ for index_desc in current_todo
714
+ ]
715
+ for result in self .io .launch_jobs (
716
+ run_benchmark_one , params , local = local
717
+ ):
661
718
dict_merge (results , result )
662
719
663
- todo = next_todo
720
+ todo = next_todo
664
721
665
722
if result_file is not None :
666
723
self .io .write_json (results , result_file , overwrite = True )
667
724
logger .info ("end evaluate" )
668
725
return results
669
726
727
+
670
728
def run_benchmark_one (params ):
671
729
logger .info (params )
672
- benchmark , results , index_desc , train , reconstruct , knn , range = params
730
+ index_desc , benchmark , results , train , reconstruct , knn , range = params
673
731
results , requires = benchmark .benchmark_one (
674
732
dry_run = False ,
675
733
results = results ,
0 commit comments