Skip to content

Commit 99141e1

Browse files
Xiao Fuabhinavdangeti
Xiao Fu
authored andcommittedJul 12, 2024
Add python tutorial on different indexs refinement and respect accuracy measurement (facebookresearch#3480)
Summary: Pull Request resolved: facebookresearch#3480 This tutorial summarize the methods to construct different indexs for PQFastScan refinement. It shows how the choice can impact on accuracy. Reviewed By: junjieqi Differential Revision: D57799598 fbshipit-source-id: a75c52c60a5217366f3361676da8f03f0c4a9feb
1 parent 22232d7 commit 99141e1

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed
 

‎tutorial/python/9-RefineComparison.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import faiss
7+
8+
from faiss.contrib.evaluation import knn_intersection_measure
9+
from faiss.contrib import datasets
10+
11+
# 64-dim vectors, 50000 vectors in the training, 100000 in database,
12+
# 10000 in queries, dtype ('float32')
13+
ds = datasets.SyntheticDataset(64, 50000, 100000, 10000)
14+
d = 64 # dimension
15+
16+
# Constructing the refine PQ index with SQfp16 with index factory
17+
index_fp16 = faiss.index_factory(d, 'PQ32x4fs,Refine(SQfp16)')
18+
index_fp16.train(ds.get_train())
19+
index_fp16.add(ds.get_database())
20+
21+
# Constructing the refine PQ index with SQ8
22+
index_sq8 = faiss.index_factory(d, 'PQ32x4fs,Refine(SQ8)')
23+
index_sq8.train(ds.get_train())
24+
index_sq8.add(ds.get_database())
25+
26+
# Parameterization on k factor while doing search for index refinement
27+
k_factor = 3.0
28+
params = faiss.IndexRefineSearchParameters(k_factor=k_factor)
29+
30+
# Perform index search using different index refinement
31+
D_fp16, I_fp16 = index_fp16.search(ds.get_queries(), 100, params=params)
32+
D_sq8, I_sq8 = index_sq8.search(ds.get_queries(), 100, params=params)
33+
34+
# Calculating knn intersection measure for different index types on refinement
35+
KIM_fp16 = knn_intersection_measure(I_fp16, ds.get_groundtruth())
36+
KIM_sq8 = knn_intersection_measure(I_sq8, ds.get_groundtruth())
37+
38+
# KNN intersection measure accuracy shows that choosing SQ8 impacts accuracy
39+
assert (KIM_fp16 > KIM_sq8)
40+
41+
print(I_sq8[:5])
42+
print(I_fp16[:5])

0 commit comments

Comments
 (0)