@@ -544,6 +544,7 @@ def replacement_range_search(self, x, thresh, *, params=None):
544
544
n , d = x .shape
545
545
assert d == self .d
546
546
x = np .ascontiguousarray (x , dtype = 'float32' )
547
+ thresh = float (thresh )
547
548
548
549
res = RangeSearchResult (n )
549
550
self .range_search_c (n , swig_ptr (x ), thresh , res , params )
@@ -618,6 +619,64 @@ def replacement_search_preassigned(self, x, k, Iq, Dq, *, params=None, D=None, I
618
619
)
619
620
return D , I
620
621
622
+ def replacement_range_search_preassigned (self , x , thresh , Iq , Dq , * , params = None ):
623
+ """Search vectors that are within a distance of the query vectors.
624
+
625
+ Parameters
626
+ ----------
627
+ x : array_like
628
+ Query vectors, shape (n, d) where d is appropriate for the index.
629
+ `dtype` must be float32.
630
+ thresh : float
631
+ Threshold to select neighbors. All elements within this radius are returned,
632
+ except for maximum inner product indexes, where the elements above the
633
+ threshold are returned
634
+ Iq : array_like, optional
635
+ Nearest centroids, size (n, nprobe)
636
+ Dq : array_like, optional
637
+ Distance array to the centroids, size (n, nprobe)
638
+ params : SearchParameters
639
+ Search parameters of the current search (overrides the class-level params)
640
+
641
+
642
+ Returns
643
+ -------
644
+ lims: array_like
645
+ Starting index of the results for each query vector, size n+1.
646
+ D : array_like
647
+ Distances of the nearest neighbors, shape `lims[n]`. The distances for
648
+ query i are in `D[lims[i]:lims[i+1]]`.
649
+ I : array_like
650
+ Labels of nearest neighbors, shape `lims[n]`. The labels for query i
651
+ are in `I[lims[i]:lims[i+1]]`.
652
+
653
+ """
654
+ n , d = x .shape
655
+ assert d == self .d
656
+ x = np .ascontiguousarray (x , dtype = 'float32' )
657
+
658
+ Iq = np .ascontiguousarray (Iq , dtype = 'int64' )
659
+ assert params is None , "params not supported"
660
+ assert Iq .shape == (n , self .nprobe )
661
+
662
+ if Dq is not None :
663
+ Dq = np .ascontiguousarray (Dq , dtype = 'float32' )
664
+ assert Dq .shape == Iq .shape
665
+
666
+ thresh = float (thresh )
667
+ res = RangeSearchResult (n )
668
+ self .range_search_preassigned_c (
669
+ n , swig_ptr (x ), thresh ,
670
+ swig_ptr (Iq ), swig_ptr (Dq ),
671
+ res
672
+ )
673
+ # get pointers and copy them
674
+ lims = rev_swig_ptr (res .lims , n + 1 ).copy ()
675
+ nd = int (lims [- 1 ])
676
+ D = rev_swig_ptr (res .distances , nd ).copy ()
677
+ I = rev_swig_ptr (res .labels , nd ).copy ()
678
+ return lims , D , I
679
+
621
680
def replacement_sa_encode (self , x , codes = None ):
622
681
n , d = x .shape
623
682
assert d == self .d
@@ -675,8 +734,12 @@ def replacement_permute_entries(self, perm):
675
734
ignore_missing = True )
676
735
replace_method (the_class , 'search_and_reconstruct' ,
677
736
replacement_search_and_reconstruct , ignore_missing = True )
737
+
738
+ # these ones are IVF-specific
678
739
replace_method (the_class , 'search_preassigned' ,
679
740
replacement_search_preassigned , ignore_missing = True )
741
+ replace_method (the_class , 'range_search_preassigned' ,
742
+ replacement_range_search_preassigned , ignore_missing = True )
680
743
replace_method (the_class , 'sa_encode' , replacement_sa_encode )
681
744
replace_method (the_class , 'sa_decode' , replacement_sa_decode )
682
745
replace_method (the_class , 'add_sa_codes' , replacement_add_sa_codes ,
@@ -776,6 +839,36 @@ def replacement_range_search(self, x, thresh):
776
839
I = rev_swig_ptr (res .labels , nd ).copy ()
777
840
return lims , D , I
778
841
842
+ def replacement_range_search_preassigned (self , x , thresh , Iq , Dq , * , params = None ):
843
+ n , d = x .shape
844
+ x = _check_dtype_uint8 (x )
845
+ assert d * 8 == self .d
846
+
847
+ Iq = np .ascontiguousarray (Iq , dtype = 'int64' )
848
+ assert params is None , "params not supported"
849
+ assert Iq .shape == (n , self .nprobe )
850
+
851
+ if Dq is not None :
852
+ Dq = np .ascontiguousarray (Dq , dtype = 'int32' )
853
+ assert Dq .shape == Iq .shape
854
+
855
+ thresh = int (thresh )
856
+ res = RangeSearchResult (n )
857
+ self .range_search_preassigned_c (
858
+ n , swig_ptr (x ), thresh ,
859
+ swig_ptr (Iq ), swig_ptr (Dq ),
860
+ res
861
+ )
862
+ # get pointers and copy them
863
+ lims = rev_swig_ptr (res .lims , n + 1 ).copy ()
864
+ nd = int (lims [- 1 ])
865
+ D = rev_swig_ptr (res .distances , nd ).copy ()
866
+ I = rev_swig_ptr (res .labels , nd ).copy ()
867
+ return lims , D , I
868
+
869
+
870
+
871
+
779
872
def replacement_remove_ids (self , x ):
780
873
if isinstance (x , IDSelector ):
781
874
sel = x
@@ -794,6 +887,8 @@ def replacement_remove_ids(self, x):
794
887
replace_method (the_class , 'remove_ids' , replacement_remove_ids )
795
888
replace_method (the_class , 'search_preassigned' ,
796
889
replacement_search_preassigned , ignore_missing = True )
890
+ replace_method (the_class , 'range_search_preassigned' ,
891
+ replacement_range_search_preassigned , ignore_missing = True )
797
892
798
893
799
894
def handle_VectorTransform (the_class ):
@@ -937,7 +1032,7 @@ def handle_MapLong2Long(the_class):
937
1032
938
1033
def replacement_map_add (self , keys , vals ):
939
1034
n , = keys .shape
940
- assert (n ,) == keys .shape
1035
+ assert (n ,) == vals .shape
941
1036
self .add_c (n , swig_ptr (keys ), swig_ptr (vals ))
942
1037
943
1038
def replacement_map_search_multiple (self , keys ):
0 commit comments