|
| 1 | +/** |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * |
| 4 | + * This source code is licensed under the MIT license found in the |
| 5 | + * LICENSE file in the root directory of this source tree. |
| 6 | + */ |
| 7 | + |
| 8 | +// -*- c++ -*- |
| 9 | + |
| 10 | +#include <faiss/cppcontrib/factory_tools.h> |
| 11 | +#include <map> |
| 12 | + |
| 13 | +namespace faiss { |
| 14 | + |
| 15 | +namespace { |
| 16 | + |
| 17 | +const std::map<faiss::ScalarQuantizer::QuantizerType, std::string> sq_types = { |
| 18 | + {faiss::ScalarQuantizer::QT_8bit, "SQ8"}, |
| 19 | + {faiss::ScalarQuantizer::QT_4bit, "SQ4"}, |
| 20 | + {faiss::ScalarQuantizer::QT_6bit, "SQ6"}, |
| 21 | + {faiss::ScalarQuantizer::QT_fp16, "SQfp16"}, |
| 22 | + {faiss::ScalarQuantizer::QT_bf16, "SQbf16"}, |
| 23 | + {faiss::ScalarQuantizer::QT_8bit_direct_signed, "SQ8_direct_signed"}, |
| 24 | + {faiss::ScalarQuantizer::QT_8bit_direct, "SQ8_direct"}, |
| 25 | +}; |
| 26 | + |
| 27 | +int get_hnsw_M(const faiss::IndexHNSW* index) { |
| 28 | + if (index->hnsw.cum_nneighbor_per_level.size() >= 1) { |
| 29 | + return index->hnsw.cum_nneighbor_per_level[1] / 2; |
| 30 | + } |
| 31 | + // Avoid runtime error, just return 0. |
| 32 | + return 0; |
| 33 | +} |
| 34 | + |
| 35 | +} // namespace |
| 36 | + |
| 37 | +// Reference for reverse_index_factory: |
| 38 | +// https://github.com/facebookresearch/faiss/blob/838612c9d7f2f619811434ec9209c020f44107cb/contrib/factory_tools.py#L81 |
| 39 | +std::string reverse_index_factory(const faiss::Index* index) { |
| 40 | + std::string prefix; |
| 41 | + if (dynamic_cast<const faiss::IndexFlat*>(index)) { |
| 42 | + return "Flat"; |
| 43 | + } else if ( |
| 44 | + const faiss::IndexIVF* ivf_index = |
| 45 | + dynamic_cast<const faiss::IndexIVF*>(index)) { |
| 46 | + const faiss::Index* quantizer = ivf_index->quantizer; |
| 47 | + |
| 48 | + if (dynamic_cast<const faiss::IndexFlat*>(quantizer)) { |
| 49 | + prefix = "IVF" + std::to_string(ivf_index->nlist); |
| 50 | + } else if ( |
| 51 | + const faiss::MultiIndexQuantizer* miq = |
| 52 | + dynamic_cast<const faiss::MultiIndexQuantizer*>( |
| 53 | + quantizer)) { |
| 54 | + prefix = "IMI" + std::to_string(miq->pq.M) + "x" + |
| 55 | + std::to_string(miq->pq.nbits); |
| 56 | + } else if ( |
| 57 | + const faiss::IndexHNSW* hnsw_index = |
| 58 | + dynamic_cast<const faiss::IndexHNSW*>(quantizer)) { |
| 59 | + prefix = "IVF" + std::to_string(ivf_index->nlist) + "_HNSW" + |
| 60 | + std::to_string(get_hnsw_M(hnsw_index)); |
| 61 | + } else { |
| 62 | + prefix = "IVF" + std::to_string(ivf_index->nlist) + "(" + |
| 63 | + reverse_index_factory(quantizer) + ")"; |
| 64 | + } |
| 65 | + |
| 66 | + if (dynamic_cast<const faiss::IndexIVFFlat*>(ivf_index)) { |
| 67 | + return prefix + ",Flat"; |
| 68 | + } else if ( |
| 69 | + auto sq_index = |
| 70 | + dynamic_cast<const faiss::IndexIVFScalarQuantizer*>( |
| 71 | + ivf_index)) { |
| 72 | + return prefix + "," + sq_types.at(sq_index->sq.qtype); |
| 73 | + } else if ( |
| 74 | + const faiss::IndexIVFPQ* ivfpq_index = |
| 75 | + dynamic_cast<const faiss::IndexIVFPQ*>(ivf_index)) { |
| 76 | + return prefix + ",PQ" + std::to_string(ivfpq_index->pq.M) + "x" + |
| 77 | + std::to_string(ivfpq_index->pq.nbits); |
| 78 | + } else if ( |
| 79 | + const faiss::IndexIVFPQFastScan* ivfpqfs_index = |
| 80 | + dynamic_cast<const faiss::IndexIVFPQFastScan*>( |
| 81 | + ivf_index)) { |
| 82 | + return prefix + ",PQ" + std::to_string(ivfpqfs_index->pq.M) + "x" + |
| 83 | + std::to_string(ivfpqfs_index->pq.nbits) + "fs"; |
| 84 | + } |
| 85 | + } else if ( |
| 86 | + const faiss::IndexPreTransform* pretransform_index = |
| 87 | + dynamic_cast<const faiss::IndexPreTransform*>(index)) { |
| 88 | + if (pretransform_index->chain.size() != 1) { |
| 89 | + // Avoid runtime error, just return empty string for logging. |
| 90 | + return ""; |
| 91 | + } |
| 92 | + const faiss::VectorTransform* vt = pretransform_index->chain.at(0); |
| 93 | + if (const faiss::OPQMatrix* opq_matrix = |
| 94 | + dynamic_cast<const faiss::OPQMatrix*>(vt)) { |
| 95 | + prefix = "OPQ" + std::to_string(opq_matrix->M) + "_" + |
| 96 | + std::to_string(opq_matrix->d_out); |
| 97 | + } else if ( |
| 98 | + const faiss::ITQTransform* itq_transform = |
| 99 | + dynamic_cast<const faiss::ITQTransform*>(vt)) { |
| 100 | + prefix = "ITQ" + std::to_string(itq_transform->itq.d_out); |
| 101 | + } else if ( |
| 102 | + const faiss::PCAMatrix* pca_matrix = |
| 103 | + dynamic_cast<const faiss::PCAMatrix*>(vt)) { |
| 104 | + assert(pca_matrix->eigen_power == 0); |
| 105 | + prefix = "PCA" + |
| 106 | + std::string(pca_matrix->random_rotation ? "R" : "") + |
| 107 | + std::to_string(pca_matrix->d_out); |
| 108 | + } else { |
| 109 | + // Avoid runtime error, just return empty string for logging. |
| 110 | + return ""; |
| 111 | + } |
| 112 | + return prefix + "," + reverse_index_factory(pretransform_index->index); |
| 113 | + } else if ( |
| 114 | + const faiss::IndexHNSW* hnsw_index = |
| 115 | + dynamic_cast<const faiss::IndexHNSW*>(index)) { |
| 116 | + return "HNSW" + std::to_string(get_hnsw_M(hnsw_index)); |
| 117 | + } else if ( |
| 118 | + const faiss::IndexRefine* refine_index = |
| 119 | + dynamic_cast<const faiss::IndexRefine*>(index)) { |
| 120 | + return reverse_index_factory(refine_index->base_index) + ",Refine(" + |
| 121 | + reverse_index_factory(refine_index->refine_index) + ")"; |
| 122 | + } else if ( |
| 123 | + const faiss::IndexPQFastScan* pqfs_index = |
| 124 | + dynamic_cast<const faiss::IndexPQFastScan*>(index)) { |
| 125 | + return std::string("PQ") + std::to_string(pqfs_index->pq.M) + "x" + |
| 126 | + std::to_string(pqfs_index->pq.nbits) + "fs"; |
| 127 | + } else if ( |
| 128 | + const faiss::IndexPQ* pq_index = |
| 129 | + dynamic_cast<const faiss::IndexPQ*>(index)) { |
| 130 | + return std::string("PQ") + std::to_string(pq_index->pq.M) + "x" + |
| 131 | + std::to_string(pq_index->pq.nbits); |
| 132 | + } else if ( |
| 133 | + const faiss::IndexLSH* lsh_index = |
| 134 | + dynamic_cast<const faiss::IndexLSH*>(index)) { |
| 135 | + std::string result = "LSH"; |
| 136 | + if (lsh_index->rotate_data) { |
| 137 | + result += "r"; |
| 138 | + } |
| 139 | + if (lsh_index->train_thresholds) { |
| 140 | + result += "t"; |
| 141 | + } |
| 142 | + return result; |
| 143 | + } else if ( |
| 144 | + const faiss::IndexScalarQuantizer* sq_index = |
| 145 | + dynamic_cast<const faiss::IndexScalarQuantizer*>(index)) { |
| 146 | + return std::string("SQ") + sq_types.at(sq_index->sq.qtype); |
| 147 | + } |
| 148 | + // Avoid runtime error, just return empty string for logging. |
| 149 | + return ""; |
| 150 | +} |
| 151 | + |
| 152 | +} // namespace faiss |
0 commit comments