@@ -3,25 +3,25 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
-from torch import nn
-import unittest
-import numpy as np
-import faiss
+import torch  # usort: skip
+from torch import nn  # usort: skip
+import unittest  # usort: skip
+import numpy as np  # usort: skip
+import faiss  # usort: skip
 
-from faiss.contrib import datasets
-from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
+from faiss.contrib import datasets  # usort: skip
+from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks  # usort: skip
 
 
 class TestLayer(unittest.TestCase):
 
     @torch.no_grad()
     def test_Embedding(self):
-        """ verify that the Faiss Embedding works the same as in Pytorch """
+        """verify that the Faiss Embedding works the same as in Pytorch"""
         torch.manual_seed(123)
 
         emb = nn.Embedding(40, 50)
-        idx = torch.randint(40, (25, ))
+        idx = torch.randint(40, (25,))
         ref_batch = emb(idx)
 
         emb2 = faiss.Embedding(emb)
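For readers following along: the hunk cuts off right after the wrapped layer is built, before the test runs it and compares outputs. Below is a minimal standalone sketch of the check. The `faiss.Embedding(emb)` constructor is shown in the hunk; invoking the wrapper with a `faiss.Int32Tensor2D` is an assumption, patterned on the `step2.codebook(codes2)` call in the QINCo tests further down.

```python
# Standalone sketch of the test_Embedding equivalence check.
# Assumption: the wrapped layer is called with a faiss.Int32Tensor2D,
# as step2.codebook(codes2) is in the QINCo tests below; this hunk
# ends before the actual comparison.
import numpy as np
import torch
from torch import nn

import faiss

torch.manual_seed(123)
emb = nn.Embedding(40, 50)       # 40 entries of dimension 50
idx = torch.randint(40, (25,))   # 25 random lookups

ref_batch = emb(idx)             # PyTorch reference result
emb2 = faiss.Embedding(emb)      # copy the weights into the faiss layer

idx2 = faiss.Int32Tensor2D(idx[:, None].to(dtype=torch.int32))
new_batch = emb2(idx2)

# an embedding lookup copies rows verbatim, so exact equality is expected
np.testing.assert_array_equal(ref_batch.detach().numpy(), new_batch.numpy())
```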
@@ -33,7 +33,7 @@ def test_Embedding(self):
 
     @torch.no_grad()
     def do_test_Linear(self, bias):
-        """ verify that the Faiss Linear works the same as in Pytorch """
+        """verify that the Faiss Linear works the same as in Pytorch"""
         torch.manual_seed(123)
         linear = nn.Linear(50, 40, bias=bias)
         x = torch.randn(25, 50)
@@ -50,6 +50,7 @@ def test_Linear(self):
     def test_Linear_nobias(self):
         self.do_test_Linear(False)
 
+
 ######################################################
 # QINCo Pytorch implementation copied from
 # https://github.com/facebookresearch/Qinco/blob/main/model_qinco.py
@@ -219,6 +220,7 @@ def encode(self, x, code0=None):
 # QINCo tests
 ######################################################
 
+
 def copy_QINCoStep(step):
     step2 = faiss.QINCoStep(step.d, step.K, step.L, step.h)
     step2.codebook.from_torch(step.codebook)
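The `copy_QINCoStep` helper illustrates the transfer pattern these tests rely on: construct the faiss object with the same hyperparameters, then move each trained PyTorch parameter across with `from_torch`. A hedged sketch of just the codebook transfer, checked against plain PyTorch (the hunk cuts off after the codebook line, so the rest of the helper is not shown here):

```python
# Hedged sketch of the parameter-transfer pattern in copy_QINCoStep:
# build the faiss QINCoStep with the same hyperparameters, then copy
# the trained codebook across with from_torch. Only the codebook
# transfer is sketched, because that is all this hunk shows.
import numpy as np
import torch
from torch import nn

import faiss

torch.manual_seed(123)
codebook = nn.Embedding(20, 16)          # K=20 entries of dimension d=16

step2 = faiss.QINCoStep(16, 20, 2, 8)    # (d, K, L, h), as in the tests
step2.codebook.from_torch(codebook)      # copy the embedding weights

codes = torch.randint(0, 20, (10,))
codes2 = faiss.Int32Tensor2D(codes[:, None].to(dtype=torch.int32))

# the faiss-side lookup should match the PyTorch embedding exactly
np.testing.assert_array_equal(codebook(codes).detach().numpy(),
                              step2.codebook(codes2).numpy())
```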
@@ -238,7 +240,7 @@ def test_decode(self):
         torch.manual_seed(123)
         step = QINCoStep(d=16, K=20, L=2, h=8)
 
-        codes = torch.randint(0, 20, (10, ))
+        codes = torch.randint(0, 20, (10,))
         xhat = torch.randn(10, 16)
         ref_decode = step.decode(xhat, codes)
 
@@ -247,28 +249,23 @@ def test_decode(self):
         codes2 = faiss.Int32Tensor2D(codes[:, None].to(dtype=torch.int32))
 
         np.testing.assert_array_equal(
-            step.codebook(codes).numpy(),
-            step2.codebook(codes2).numpy()
+            step.codebook(codes).numpy(), step2.codebook(codes2).numpy()
         )
 
         xhat2 = faiss.Tensor2D(xhat)
         # xhat2 = faiss.Tensor2D(len(codes), step2.d)
 
         new_decode = step2.decode(xhat2, codes2)
 
-        np.testing.assert_allclose(
-            ref_decode.numpy(),
-            new_decode.numpy(),
-            atol=2e-6
-        )
+        np.testing.assert_allclose(ref_decode.numpy(), new_decode.numpy(), atol=2e-6)
 
     @torch.no_grad()
     def test_encode(self):
         torch.manual_seed(123)
         step = QINCoStep(d=16, K=20, L=2, h=8)
 
         # create plausible x for testing starting from actual codes
-        codes = torch.randint(0, 20, (10, ))
+        codes = torch.randint(0, 20, (10,))
         xhat = torch.zeros(10, 16)
         x = step.decode(xhat, codes)
         del codes
@@ -282,14 +279,11 @@ def test_encode(self):
         new_codes = step2.encode(xhat2, x2, toadd2)
 
         np.testing.assert_allclose(
-            ref_codes.numpy(),
-            new_codes.numpy().ravel(),
-            atol=2e-6
+            ref_codes.numpy(), new_codes.numpy().ravel(), atol=2e-6
         )
         np.testing.assert_allclose(toadd.numpy(), toadd2.numpy(), atol=2e-6)
 
 
-
 class TestQINCo(unittest.TestCase):
 
     @torch.no_grad()
@@ -327,11 +321,12 @@ def test_encode(self):
 # Test index
 ######################################################
 
+
 class TestIndexQINCo(unittest.TestCase):
 
     def test_search(self):
         """
-        We can't train qinco with just Faiss so we just train a RQ and use the 
+        We can't train qinco with just Faiss so we just train a RQ and use the
         codebooks in QINCo with L = 0 residual blocks
         """
         ds = datasets.SyntheticDataset(32, 1000, 100, 0)
@@ -342,7 +337,7 @@ def test_search(self):
         rq = index_ref.rq
         # rq = faiss.ResidualQuantizer(ds.d, M, 4)
         rq.train_type = faiss.ResidualQuantizer.Train_default
-        rq.max_beam_size = 1 # beam search not implemented for QINCo (yet)
+        rq.max_beam_size = 1  # beam search not implemented for QINCo (yet)
         index_ref.train(ds.get_train())
         codebooks = get_additive_quantizer_codebooks(rq)
 
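To make the trick in `test_search` concrete: a residual quantizer is trained through a standard Faiss index, and its codebooks are then pulled out with `get_additive_quantizer_codebooks` for reuse in QINCo. A minimal sketch of that flow, assuming `index_ref` is an `IndexResidualQuantizer` built via `index_factory` (the factory string `"RQ4x4"` is an assumption; the hunk only shows that the index exposes an `.rq` attribute):

```python
# Hedged sketch: train a ResidualQuantizer and extract its codebooks,
# as test_search does. The factory string "RQ4x4" (4 codebooks of
# 4 bits each) is an assumption; the hunk only shows index_ref.rq.
import faiss
from faiss.contrib import datasets
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks

ds = datasets.SyntheticDataset(32, 1000, 100, 0)   # d=32, 1000 train vectors

index_ref = faiss.index_factory(ds.d, "RQ4x4")
rq = index_ref.rq
rq.train_type = faiss.ResidualQuantizer.Train_default
rq.max_beam_size = 1   # greedy encoding; beam search not supported by QINCo
index_ref.train(ds.get_train())

codebooks = get_additive_quantizer_codebooks(rq)   # list of numpy arrays
for cb in codebooks:
    print(cb.shape)                                # one (K, d) array per step
```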