
Commit 229a8aa

bshethmeta authored and facebook-github-bot committed
Prevent reordering of imports by auto formatter to avoid crashes
Summary: Apparently this is the generally accepted way to do this: https://usort.readthedocs.io/en/stable/guide.html#import-blocks. What do you think?

Differential Revision: D62147522
1 parent 4683cc1 commit 229a8aa
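
For context, the usort guide linked in the summary describes two ways to keep imports from being reordered. A minimal sketch of both mechanisms is below; the snippet is illustrative, not part of the commit:

# Option 1: a "# usort: skip" comment pins an individual import in place.
import torch  # usort: skip
import faiss  # usort: skip
import faiss.contrib.torch_utils  # usort: skip

# Option 2: a blank line starts a new import block; usort sorts within a
# block but never across blocks, so faiss stays below torch here even
# though it sorts first alphabetically.
import torch

import faiss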

File tree: 2 files changed, +27 -31

tests/torch_test_contrib.py: +7 -6

@@ -3,11 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
-import unittest
-import numpy as np
-import faiss
-import faiss.contrib.torch_utils
+import torch  # usort: skip
+import unittest  # usort: skip
+import numpy as np  # usort: skip
+import faiss  # usort: skip
+import faiss.contrib.torch_utils  # usort: skip
+
 
 class TestTorchUtilsCPU(unittest.TestCase):
     # tests add, search
@@ -141,7 +142,7 @@ def test_assign(self):
         self.assertTrue(np.array_equal(labels, labels_ref))
 
         # Test assign with numpy output provided
-        labels = np.empty((xq.shape[0], 5), dtype='int64')
+        labels = np.empty((xq.shape[0], 5), dtype="int64")
         index.assign(xq.numpy(), 5, labels)
         self.assertTrue(np.array_equal(labels, labels_ref))
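
The import order pinned above is load-bearing: faiss.contrib.torch_utils is a side-effect import that patches faiss index methods so they accept torch tensors, so it has to run after both torch and faiss are loaded. A formatter hoisting it above "import faiss" is presumably the kind of reordering behind the crashes the commit title mentions. A minimal sketch of the dependency (the index size and data are illustrative):

import torch  # usort: skip
import faiss  # usort: skip
import faiss.contrib.torch_utils  # usort: skip  (patches faiss for torch tensors)

# After the side-effect import, faiss indexes take torch tensors directly.
index = faiss.IndexFlatL2(8)                # flat L2 index over 8-dim vectors
index.add(torch.randn(100, 8))              # no .numpy() conversion needed
D, I = index.search(torch.randn(5, 8), 3)   # distances/ids of 3 nearest neighbors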

tests/torch_test_neural_net.py: +20 -25

@@ -3,25 +3,25 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
-from torch import nn
-import unittest
-import numpy as np
-import faiss
+import torch  # usort: skip
+from torch import nn  # usort: skip
+import unittest  # usort: skip
+import numpy as np  # usort: skip
+import faiss  # usort: skip
 
-from faiss.contrib import datasets
-from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
+from faiss.contrib import datasets  # usort: skip
+from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks  # usort: skip
 
 
 class TestLayer(unittest.TestCase):
 
     @torch.no_grad()
     def test_Embedding(self):
-        """ verify that the Faiss Embedding works the same as in Pytorch """
+        """verify that the Faiss Embedding works the same as in Pytorch"""
         torch.manual_seed(123)
 
         emb = nn.Embedding(40, 50)
-        idx = torch.randint(40, (25, ))
+        idx = torch.randint(40, (25,))
         ref_batch = emb(idx)
 
         emb2 = faiss.Embedding(emb)
@@ -33,7 +33,7 @@ def test_Embedding(self):
 
     @torch.no_grad()
     def do_test_Linear(self, bias):
-        """ verify that the Faiss Linear works the same as in Pytorch """
+        """verify that the Faiss Linear works the same as in Pytorch"""
         torch.manual_seed(123)
         linear = nn.Linear(50, 40, bias=bias)
         x = torch.randn(25, 50)
@@ -50,6 +50,7 @@ def test_Linear(self):
     def test_Linear_nobias(self):
         self.do_test_Linear(False)
 
+
 ######################################################
 # QINCo Pytorch implementation copied from
 # https://github.com/facebookresearch/Qinco/blob/main/model_qinco.py
@@ -219,6 +220,7 @@ def encode(self, x, code0=None):
 # QINCo tests
 ######################################################
 
+
 def copy_QINCoStep(step):
     step2 = faiss.QINCoStep(step.d, step.K, step.L, step.h)
     step2.codebook.from_torch(step.codebook)
@@ -238,7 +240,7 @@ def test_decode(self):
         torch.manual_seed(123)
         step = QINCoStep(d=16, K=20, L=2, h=8)
 
-        codes = torch.randint(0, 20, (10, ))
+        codes = torch.randint(0, 20, (10,))
         xhat = torch.randn(10, 16)
         ref_decode = step.decode(xhat, codes)
 
@@ -247,28 +249,23 @@
         codes2 = faiss.Int32Tensor2D(codes[:, None].to(dtype=torch.int32))
 
         np.testing.assert_array_equal(
-            step.codebook(codes).numpy(),
-            step2.codebook(codes2).numpy()
+            step.codebook(codes).numpy(), step2.codebook(codes2).numpy()
         )
 
         xhat2 = faiss.Tensor2D(xhat)
         # xhat2 = faiss.Tensor2D(len(codes), step2.d)
 
         new_decode = step2.decode(xhat2, codes2)
 
-        np.testing.assert_allclose(
-            ref_decode.numpy(),
-            new_decode.numpy(),
-            atol=2e-6
-        )
+        np.testing.assert_allclose(ref_decode.numpy(), new_decode.numpy(), atol=2e-6)
 
     @torch.no_grad()
     def test_encode(self):
         torch.manual_seed(123)
         step = QINCoStep(d=16, K=20, L=2, h=8)
 
         # create plausible x for testing starting from actual codes
-        codes = torch.randint(0, 20, (10, ))
+        codes = torch.randint(0, 20, (10,))
         xhat = torch.zeros(10, 16)
         x = step.decode(xhat, codes)
         del codes
@@ -282,14 +279,11 @@ def test_encode(self):
         new_codes = step2.encode(xhat2, x2, toadd2)
 
         np.testing.assert_allclose(
-            ref_codes.numpy(),
-            new_codes.numpy().ravel(),
-            atol=2e-6
+            ref_codes.numpy(), new_codes.numpy().ravel(), atol=2e-6
         )
         np.testing.assert_allclose(toadd.numpy(), toadd2.numpy(), atol=2e-6)
 
 
-
 class TestQINCo(unittest.TestCase):
 
     @torch.no_grad()
@@ -327,11 +321,12 @@ def test_encode(self):
 # Test index
 ######################################################
 
+
 class TestIndexQINCo(unittest.TestCase):
 
     def test_search(self):
         """
-        We can't train qinco with just Faiss so we just train a RQ and use the
+        We can't train qinco with just Faiss so we just train a RQ and use the
         codebooks in QINCo with L = 0 residual blocks
         """
         ds = datasets.SyntheticDataset(32, 1000, 100, 0)
@@ -342,7 +337,7 @@ def test_search(self):
         rq = index_ref.rq
         # rq = faiss.ResidualQuantizer(ds.d, M, 4)
         rq.train_type = faiss.ResidualQuantizer.Train_default
-        rq.max_beam_size = 1 # beam search not implemented for QINCo (yet)
+        rq.max_beam_size = 1  # beam search not implemented for QINCo (yet)
         index_ref.train(ds.get_train())
         codebooks = get_additive_quantizer_codebooks(rq)