28
28
import sys
29
29
import numpy as np
30
30
31
+ ##################################################################
32
+ # Equivalent of swig_ptr for Torch tensors
33
+ ##################################################################
34
+
31
35
def swig_ptr_from_UInt8Tensor (x ):
32
36
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
33
37
assert x .is_contiguous ()
34
38
assert x .dtype == torch .uint8
35
39
return faiss .cast_integer_to_uint8_ptr (
36
40
x .untyped_storage ().data_ptr () + x .storage_offset ())
37
41
42
+
38
43
def swig_ptr_from_HalfTensor (x ):
39
44
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
40
45
assert x .is_contiguous ()
@@ -43,27 +48,34 @@ def swig_ptr_from_HalfTensor(x):
43
48
return faiss .cast_integer_to_void_ptr (
44
49
x .untyped_storage ().data_ptr () + x .storage_offset () * 2 )
45
50
51
+
46
52
def swig_ptr_from_FloatTensor (x ):
47
53
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
48
54
assert x .is_contiguous ()
49
55
assert x .dtype == torch .float32
50
56
return faiss .cast_integer_to_float_ptr (
51
57
x .untyped_storage ().data_ptr () + x .storage_offset () * 4 )
52
58
59
+
53
60
def swig_ptr_from_IntTensor (x ):
54
61
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
55
62
assert x .is_contiguous ()
56
63
assert x .dtype == torch .int32 , 'dtype=%s' % x .dtype
57
64
return faiss .cast_integer_to_int_ptr (
58
65
x .untyped_storage ().data_ptr () + x .storage_offset () * 4 )
59
66
67
+
60
68
def swig_ptr_from_IndicesTensor (x ):
61
69
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
62
70
assert x .is_contiguous ()
63
71
assert x .dtype == torch .int64 , 'dtype=%s' % x .dtype
64
72
return faiss .cast_integer_to_idx_t_ptr (
65
73
x .untyped_storage ().data_ptr () + x .storage_offset () * 8 )
66
74
75
+ ##################################################################
76
+ # utilities
77
+ ##################################################################
78
+
67
79
@contextlib .contextmanager
68
80
def using_stream (res , pytorch_stream = None ):
69
81
""" Creates a scoping object to make Faiss GPU use the same stream
@@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement,
107
119
setattr (the_class , name + '_numpy' , orig_method )
108
120
setattr (the_class , name , replacement )
109
121
122
+ ##################################################################
123
+ # Setup wrappers
124
+ ##################################################################
125
+
110
126
def handle_torch_Index (the_class ):
111
127
def torch_replacement_add (self , x ):
112
128
if type (x ) is np .ndarray :
@@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None):
493
509
handle_torch_Index (the_class )
494
510
495
511
512
+ # allows torch tensor usage with knn
513
+ def torch_replacement_knn (xq , xb , k , metric = faiss .METRIC_L2 , metric_arg = 0 ):
514
+ if type (xb ) is np .ndarray :
515
+ # Forward to faiss __init__.py base method
516
+ return faiss .knn_numpy (xq , xb , k , metric = metric , metric_arg = metric_arg )
517
+
518
+ nb , d = xb .size ()
519
+ assert xb .is_contiguous ()
520
+ assert xb .dtype == torch .float32
521
+ assert not xb .is_cuda , "use knn_gpu for GPU tensors"
522
+
523
+ nq , d2 = xq .size ()
524
+ assert d2 == d
525
+ assert xq .is_contiguous ()
526
+ assert xq .dtype == torch .float32
527
+ assert not xq .is_cuda , "use knn_gpu for GPU tensors"
528
+
529
+ D = torch .empty (nq , k , device = xb .device , dtype = torch .float32 )
530
+ I = torch .empty (nq , k , device = xb .device , dtype = torch .int64 )
531
+ I_ptr = swig_ptr_from_IndicesTensor (I )
532
+ D_ptr = swig_ptr_from_FloatTensor (D )
533
+ xb_ptr = swig_ptr_from_FloatTensor (xb )
534
+ xq_ptr = swig_ptr_from_FloatTensor (xq )
535
+
536
+ if metric == faiss .METRIC_L2 :
537
+ faiss .knn_L2sqr (
538
+ xq_ptr , xb_ptr ,
539
+ d , nq , nb , k , D_ptr , I_ptr
540
+ )
541
+ elif metric == faiss .METRIC_INNER_PRODUCT :
542
+ faiss .knn_inner_product (
543
+ xq_ptr , xb_ptr ,
544
+ d , nq , nb , k , D_ptr , I_ptr
545
+ )
546
+ else :
547
+ faiss .knn_extra_metrics (
548
+ xq_ptr , xb_ptr ,
549
+ d , nq , nb , metric , metric_arg , k , D_ptr , I_ptr
550
+ )
551
+
552
+ return D , I
553
+
554
+
555
+ torch_replace_method (faiss_module , 'knn' , torch_replacement_knn , True , True )
556
+
557
+
496
558
# allows torch tensor usage with bfKnn
497
559
def torch_replacement_knn_gpu (res , xq , xb , k , D = None , I = None , metric = faiss .METRIC_L2 , device = - 1 , use_raft = False ):
498
560
if type (xb ) is np .ndarray :
0 commit comments