Skip to content

Commit 9fce55c

Browse files
Michael Norris authored and facebook-github-bot committed
Add reverse factory string util, add StringIOReader, add centralized JK (#3879)
Summary:
1. Adds JK `faiss/telemetry:use_faiss_telemetry_core` to the top-level logging util in `wrapper_logging_utils.h`. This is currently set to false. I plan to deprecate the other knobs under https://www.internalfb.com/intern/justknobs/?name=faiss%2Ftelemetry and just use one, as Unicorn can't easily have its own JK (it subclasses a lot of FAISS classes too).
2. Copied StringIOReader from Unicorn to the telemetry wrapper in `io.h`. This will be deleted from Unicorn in the follow-up diff.
3. Updated Laser tests to reflect the correct index_read factory string changes.
4. Adds reverse_index_factory. More tests for it will come in a subsequent diff.

Reviewed By: junjieqi
Differential Revision: D62670316
1 parent 03f1d2a commit 9fce55c

File tree

3 files changed

+230
-0
lines changed

3 files changed

+230
-0
lines changed

faiss/cppcontrib/factory_tools.cpp

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
// -*- c++ -*-
9+
10+
#include <faiss/cppcontrib/factory_tools.h>
11+
#include <map>
12+
13+
namespace faiss {
14+
15+
namespace {
16+
17+
const std::map<faiss::ScalarQuantizer::QuantizerType, std::string> sq_types = {
18+
{faiss::ScalarQuantizer::QT_8bit, "SQ8"},
19+
{faiss::ScalarQuantizer::QT_4bit, "SQ4"},
20+
{faiss::ScalarQuantizer::QT_6bit, "SQ6"},
21+
{faiss::ScalarQuantizer::QT_fp16, "SQfp16"},
22+
{faiss::ScalarQuantizer::QT_bf16, "SQbf16"},
23+
{faiss::ScalarQuantizer::QT_8bit_direct_signed, "SQ8_direct_signed"},
24+
{faiss::ScalarQuantizer::QT_8bit_direct, "SQ8_direct"},
25+
};
26+
27+
/**
 * Derive the HNSW `M` value used in factory strings from a built index.
 *
 * @param index HNSW index to inspect; must not be null.
 * @return hnsw.cum_nneighbor_per_level[1] / 2, or 0 when that entry does
 *         not exist.
 *
 * NOTE(review): presumably entry 1 holds the level-0 neighbor count (2 * M);
 * confirm against the HNSW implementation.
 */
int get_hnsw_M(const faiss::IndexHNSW* index) {
    const auto& cum_nneighbor = index->hnsw.cum_nneighbor_per_level;
    // Element [1] is read below, so at least two entries are required. The
    // previous `size() >= 1` guard still allowed an out-of-bounds read on a
    // single-element vector.
    if (cum_nneighbor.size() > 1) {
        return cum_nneighbor[1] / 2;
    }
    // Avoid runtime error, just return 0.
    return 0;
}
34+
35+
} // namespace
36+
37+
// Reference for reverse_index_factory:
38+
// https://github.com/facebookresearch/faiss/blob/838612c9d7f2f619811434ec9209c020f44107cb/contrib/factory_tools.py#L81
39+
/**
 * Best-effort reconstruction of the index_factory string for `index`.
 *
 * The concrete type of `index` is probed with dynamic_cast, and nested
 * indexes (IVF coarse quantizers, pre-transform sub-indexes, refine
 * base/refine indexes) are handled recursively.
 *
 * This helper is meant for logging: it never throws for an unsupported
 * layout. Anything it cannot describe (unknown index type, pre-transform
 * chain whose length is not 1, unknown vector transform, PCA with a non-zero
 * eigen_power, scalar quantizer type missing from sq_types) yields "".
 *
 * @param index index to describe; a null pointer yields "".
 * @return the factory string, or "" when the index cannot be described.
 */
std::string reverse_index_factory(const faiss::Index* index) {
    // Name of a scalar quantizer type, or "" when it has no entry in
    // sq_types. map::find is used instead of map::at so that an unlisted
    // type cannot raise std::out_of_range from a logging helper.
    const auto sq_name = [](faiss::ScalarQuantizer::QuantizerType qtype) {
        const auto it = sq_types.find(qtype);
        return it != sq_types.end() ? it->second : std::string();
    };

    std::string prefix;
    if (dynamic_cast<const faiss::IndexFlat*>(index)) {
        return "Flat";
    } else if (
            const faiss::IndexIVF* ivf_index =
                    dynamic_cast<const faiss::IndexIVF*>(index)) {
        const faiss::Index* quantizer = ivf_index->quantizer;

        // First describe the coarse quantizer.
        if (dynamic_cast<const faiss::IndexFlat*>(quantizer)) {
            prefix = "IVF" + std::to_string(ivf_index->nlist);
        } else if (
                const faiss::MultiIndexQuantizer* miq =
                        dynamic_cast<const faiss::MultiIndexQuantizer*>(
                                quantizer)) {
            prefix = "IMI" + std::to_string(miq->pq.M) + "x" +
                    std::to_string(miq->pq.nbits);
        } else if (
                const faiss::IndexHNSW* hnsw_index =
                        dynamic_cast<const faiss::IndexHNSW*>(quantizer)) {
            prefix = "IVF" + std::to_string(ivf_index->nlist) + "_HNSW" +
                    std::to_string(get_hnsw_M(hnsw_index));
        } else {
            prefix = "IVF" + std::to_string(ivf_index->nlist) + "(" +
                    reverse_index_factory(quantizer) + ")";
        }

        // Then append the encoding used inside the inverted lists.
        if (dynamic_cast<const faiss::IndexIVFFlat*>(ivf_index)) {
            return prefix + ",Flat";
        } else if (
                auto sq_index =
                        dynamic_cast<const faiss::IndexIVFScalarQuantizer*>(
                                ivf_index)) {
            const std::string name = sq_name(sq_index->sq.qtype);
            if (name.empty()) {
                // Avoid runtime error, just return empty string for logging.
                return "";
            }
            return prefix + "," + name;
        } else if (
                const faiss::IndexIVFPQ* ivfpq_index =
                        dynamic_cast<const faiss::IndexIVFPQ*>(ivf_index)) {
            return prefix + ",PQ" + std::to_string(ivfpq_index->pq.M) + "x" +
                    std::to_string(ivfpq_index->pq.nbits);
        } else if (
                const faiss::IndexIVFPQFastScan* ivfpqfs_index =
                        dynamic_cast<const faiss::IndexIVFPQFastScan*>(
                                ivf_index)) {
            return prefix + ",PQ" + std::to_string(ivfpqfs_index->pq.M) + "x" +
                    std::to_string(ivfpqfs_index->pq.nbits) + "fs";
        }
        // Other IVF variants fall through to the final empty-string return.
    } else if (
            const faiss::IndexPreTransform* pretransform_index =
                    dynamic_cast<const faiss::IndexPreTransform*>(index)) {
        if (pretransform_index->chain.size() != 1) {
            // Avoid runtime error, just return empty string for logging.
            return "";
        }
        const faiss::VectorTransform* vt = pretransform_index->chain.at(0);
        if (const faiss::OPQMatrix* opq_matrix =
                    dynamic_cast<const faiss::OPQMatrix*>(vt)) {
            prefix = "OPQ" + std::to_string(opq_matrix->M) + "_" +
                    std::to_string(opq_matrix->d_out);
        } else if (
                const faiss::ITQTransform* itq_transform =
                        dynamic_cast<const faiss::ITQTransform*>(vt)) {
            prefix = "ITQ" + std::to_string(itq_transform->itq.d_out);
        } else if (
                const faiss::PCAMatrix* pca_matrix =
                        dynamic_cast<const faiss::PCAMatrix*>(vt)) {
            if (pca_matrix->eigen_power != 0) {
                // The strings built below cannot express a non-zero
                // eigen_power. Report "unknown" instead of asserting (abort
                // in debug builds, wrong string in NDEBUG builds).
                return "";
            }
            prefix = "PCA" +
                    std::string(pca_matrix->random_rotation ? "R" : "") +
                    std::to_string(pca_matrix->d_out);
        } else {
            // Avoid runtime error, just return empty string for logging.
            return "";
        }
        return prefix + "," + reverse_index_factory(pretransform_index->index);
    } else if (
            const faiss::IndexHNSW* hnsw_index =
                    dynamic_cast<const faiss::IndexHNSW*>(index)) {
        return "HNSW" + std::to_string(get_hnsw_M(hnsw_index));
    } else if (
            const faiss::IndexRefine* refine_index =
                    dynamic_cast<const faiss::IndexRefine*>(index)) {
        return reverse_index_factory(refine_index->base_index) + ",Refine(" +
                reverse_index_factory(refine_index->refine_index) + ")";
    } else if (
            const faiss::IndexPQFastScan* pqfs_index =
                    dynamic_cast<const faiss::IndexPQFastScan*>(index)) {
        return std::string("PQ") + std::to_string(pqfs_index->pq.M) + "x" +
                std::to_string(pqfs_index->pq.nbits) + "fs";
    } else if (
            const faiss::IndexPQ* pq_index =
                    dynamic_cast<const faiss::IndexPQ*>(index)) {
        return std::string("PQ") + std::to_string(pq_index->pq.M) + "x" +
                std::to_string(pq_index->pq.nbits);
    } else if (
            const faiss::IndexLSH* lsh_index =
                    dynamic_cast<const faiss::IndexLSH*>(index)) {
        std::string result = "LSH";
        if (lsh_index->rotate_data) {
            result += "r";
        }
        if (lsh_index->train_thresholds) {
            result += "t";
        }
        return result;
    } else if (
            const faiss::IndexScalarQuantizer* sq_index =
                    dynamic_cast<const faiss::IndexScalarQuantizer*>(index)) {
        // sq_types values already start with "SQ"; prepending another "SQ"
        // produced strings such as "SQSQ8".
        return sq_name(sq_index->sq.qtype);
    }
    // Avoid runtime error, just return empty string for logging.
    return "";
}
151+
152+
} // namespace faiss

faiss/cppcontrib/factory_tools.h

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
// -*- c++ -*-
9+
10+
#pragma once
11+
12+
#include <faiss/IndexHNSW.h>
13+
#include <faiss/IndexIVFFlat.h>
14+
#include <faiss/IndexIVFPQFastScan.h>
15+
#include <faiss/IndexLSH.h>
16+
#include <faiss/IndexPQFastScan.h>
17+
#include <faiss/IndexPreTransform.h>
18+
#include <faiss/IndexRefine.h>
19+
20+
namespace faiss {
21+
22+
/// Best-effort reconstruction of the index_factory string describing `index`.
/// Returns an empty string when the index type or layout is not recognized.
/// @param index index to describe.
/// @return the factory string, or "" if it cannot be determined.
std::string reverse_index_factory(const faiss::Index* index);
23+
24+
} // namespace faiss

tests/test_factory_tools.cpp

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#include <faiss/cppcontrib/factory_tools.h>
4+
#include <faiss/index_factory.h>
5+
#include <gtest/gtest.h>
6+
7+
using namespace faiss;
8+
9+
// Round-trip check: build an index from a factory string and verify that
// reverse_index_factory reproduces the same string.
TEST(TestFactoryTools, TestReverseIndexFactory) {
    struct RoundTripCase {
        int dimension;
        const char* factory_string;
    };

    // Same cases, dimensions and order as the hand-unrolled version.
    const RoundTripCase round_trips[] = {
            {64, "Flat"},
            {32, "IMI2x5,PQ8x8"},
            {64, "IVF32_HNSW32,SQ8"},
            {64, "IVF8,Flat"},
            {64, "IVF8,SQ4"},
            {64, "IVF8,PQ4x8"},
            {64, "LSHrt"},
            {64, "PQ4x8"},
            {64, "HNSW32"},
    };

    for (const RoundTripCase& entry : round_trips) {
        auto built =
                faiss::index_factory(entry.dimension, entry.factory_string);
        EXPECT_EQ(entry.factory_string, reverse_index_factory(built));
        delete built;
    }
}

0 commit comments

Comments
 (0)