Skip to content

Commit 4bfdd43

Browse files
r-barnes and facebook-github-bot
authored and committed
Parallelize kernel compilation in FAISS (facebookresearch#2922)
Summary: Pull Request resolved: facebookresearch#2922 This parallelizes kernel compilation by taking a template function from much deeper in the stack than was previously the case and generating 128 compilation units rather than the original 8. Reviewed By: mdouze Differential Revision: D46674315 fbshipit-source-id: 830eeaf43dee2c081f735be47c809b28aa3a05f6
1 parent a91a288 commit 4bfdd43

13 files changed

+179
-279
lines changed

faiss/gpu/CMakeLists.txt

+68-8
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,6 @@ set(FAISS_GPU_SRC
5252
impl/PQScanMultiPassPrecomputed.cu
5353
impl/RemapIndices.cpp
5454
impl/VectorResidual.cu
55-
impl/scan/IVFInterleaved1.cu
56-
impl/scan/IVFInterleaved32.cu
57-
impl/scan/IVFInterleaved64.cu
58-
impl/scan/IVFInterleaved128.cu
59-
impl/scan/IVFInterleaved256.cu
60-
impl/scan/IVFInterleaved512.cu
61-
impl/scan/IVFInterleaved1024.cu
62-
impl/scan/IVFInterleaved2048.cu
6355
impl/IcmEncoder.cu
6456
utils/BlockSelectFloat.cu
6557
utils/DeviceUtils.cu
@@ -176,6 +168,74 @@ set(FAISS_GPU_HEADERS
176168
utils/warpselect/WarpSelectImpl.cuh
177169
)
178170

171+
# generate_ivf_interleaved_code()
#
# Instantiates impl/scan/IVFInterleavedScanKernelTemplate.cu once for every
# combination of codec type, metric type and (threads, warp-queue,
# thread-queue) configuration listed below: 8 codecs x 2 metrics x 8
# configurations = 128 generated .cu files. Each generated file is written to
# ${CMAKE_CURRENT_BINARY_DIR} and appended to FAISS_GPU_SRC in the caller's
# scope, so every instantiation is compiled as its own translation unit.
#
# Takes no arguments. Reads and extends the caller's FAISS_GPU_SRC variable.
#
# Example:
#   generate_ivf_interleaved_code()
function(generate_ivf_interleaved_code)
  set(SUB_CODEC_TYPE
    "faiss::gpu::Codec<0, 1>"
    "faiss::gpu::Codec<1, 1>"
    "faiss::gpu::Codec<2, 1>"
    "faiss::gpu::Codec<3, 1>"
    "faiss::gpu::Codec<4, 1>"
    "faiss::gpu::Codec<5, 1>"
    "faiss::gpu::Codec<6, 1>"
    "faiss::gpu::CodecFloat"
  )

  set(SUB_METRIC_TYPE
    "faiss::gpu::IPDistance"
    "faiss::gpu::L2Distance"
  )

  # Each entry is "SUB_THREADS|SUB_NUM_WARP_Q|SUB_NUM_THREAD_Q".
  # NOTE(review): the "64|2048|8" configuration is generated unconditionally,
  # while the dispatch in impl/IVFInterleaved.cu only uses it when
  # GPU_MAX_SELECTION_K >= 2048 -- confirm the template guards itself.
  set(THREADS_AND_WARPS
    "128|1024|8"
    "128|1|1"
    "128|128|3"
    "128|256|4"
    "128|32|2"
    "128|512|8"
    "128|64|3"
    "64|2048|8"
  )

  set(template_file
    "${CMAKE_CURRENT_SOURCE_DIR}/impl/scan/IVFInterleavedScanKernelTemplate.cu")

  # Re-run the configure step (and therefore this generator) whenever the
  # template changes; otherwise the generated sources would silently go stale.
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${template_file}")

  # The template is identical for every instantiation, so read it only once.
  file(READ "${template_file}" template_content)

  # Traverse the Cartesian product of codec x metric x threads/warps settings.
  foreach(sub_codec ${SUB_CODEC_TYPE})
    foreach(metric_type ${SUB_METRIC_TYPE})
      foreach(threads_and_warps_str ${THREADS_AND_WARPS})
        # Split "a|b|c" into a CMake list and pick out the three fields.
        string(REPLACE "|" ";" threads_and_warps "${threads_and_warps_str}")
        list(GET threads_and_warps 0 sub_threads)
        list(GET threads_and_warps 1 sub_num_warp_q)
        list(GET threads_and_warps 2 sub_num_thread_q)

        # Define the output file name and strip every character that is not
        # alphanumeric or an underscore (e.g. "::", "<", ">", ",", " ").
        set(filename
          "template_${sub_codec}_${metric_type}_${sub_threads}_${sub_num_warp_q}_${sub_num_thread_q}")
        string(REGEX REPLACE "[^A-Za-z0-9_]" "" filename "${filename}")
        set(output_file "${CMAKE_CURRENT_BINARY_DIR}/${filename}.cu")

        # Substitute the placeholders into a fresh copy of the template.
        set(generated_content "${template_content}")
        string(REPLACE "SUB_CODEC_TYPE" "${sub_codec}" generated_content "${generated_content}")
        string(REPLACE "SUB_METRIC_TYPE" "${metric_type}" generated_content "${generated_content}")
        string(REPLACE "SUB_THREADS" "${sub_threads}" generated_content "${generated_content}")
        string(REPLACE "SUB_NUM_WARP_Q" "${sub_num_warp_q}" generated_content "${generated_content}")
        string(REPLACE "SUB_NUM_THREAD_Q" "${sub_num_thread_q}" generated_content "${generated_content}")

        # Only touch the file when its content actually changes. Rewriting it
        # unconditionally would bump its timestamp on every configure and
        # force all generated translation units to be recompiled.
        set(needs_write TRUE)
        if(EXISTS "${output_file}")
          file(READ "${output_file}" existing_content)
          if("${existing_content}" STREQUAL "${generated_content}")
            set(needs_write FALSE)
          endif()
        endif()
        if(needs_write)
          file(WRITE "${output_file}" "${generated_content}")
        endif()

        # Add the file to the sources
        list(APPEND FAISS_GPU_SRC "${output_file}")
      endforeach()
    endforeach()
  endforeach()

  # Propagate modified variable to the parent scope
  set(FAISS_GPU_SRC "${FAISS_GPU_SRC}" PARENT_SCOPE)
endfunction()

generate_ivf_interleaved_code()
238+
179239
if(FAISS_ENABLE_RAFT)
180240
list(APPEND FAISS_GPU_HEADERS
181241
impl/RaftFlatIndex.cuh)

faiss/gpu/impl/IVFInterleaved.cu

+8-10
Original file line numberDiff line numberDiff line change
@@ -210,25 +210,23 @@ void runIVFInterleavedScan(
210210
};
211211

212212
if (k == 1) {
213-
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_1_PARAMS>);
213+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 1, 1>);
214214
} else if (k <= 32) {
215-
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_32_PARAMS>);
215+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 32, 2>);
216216
} else if (k <= 64) {
217-
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_64_PARAMS>);
217+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 64, 3>);
218218
} else if (k <= 128) {
219-
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_128_PARAMS>);
219+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 128, 3>);
220220
} else if (k <= 256) {
221-
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_256_PARAMS>);
221+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 256, 4>);
222222
} else if (k <= 512) {
223-
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_512_PARAMS>);
223+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 512, 8>);
224224
} else if (k <= 1024) {
225-
ivf_interleaved_call(
226-
ivfInterleavedScanImpl<IVFINTERLEAVED_1024_PARAMS>);
225+
ivf_interleaved_call(ivfInterleavedScanImpl<128, 1024, 8>);
227226
}
228227
#if GPU_MAX_SELECTION_K >= 2048
229228
else if (k <= 2048) {
230-
ivf_interleaved_call(
231-
ivfInterleavedScanImpl<IVFINTERLEAVED_2048_PARAMS>);
229+
ivf_interleaved_call(ivfInterleavedScanImpl<64, 2048, 8>);
232230
}
233231
#endif
234232
}

faiss/gpu/impl/IVFInterleaved.cuh

+9-9
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ template <
3535
typename Metric,
3636
int ThreadsPerBlock,
3737
int NumWarpQ,
38-
int NumThreadQ,
39-
bool Residual>
38+
int NumThreadQ>
4039
__global__ void ivfInterleavedScan(
4140
Tensor<float, 2, true> queries,
4241
Tensor<float, 3, true> residualBase,
@@ -48,7 +47,8 @@ __global__ void ivfInterleavedScan(
4847
int k,
4948
// [query][probe][k]
5049
Tensor<float, 3, true> distanceOut,
51-
Tensor<idx_t, 3, true> indicesOut) {
50+
Tensor<idx_t, 3, true> indicesOut,
51+
const bool Residual) {
5252
extern __shared__ float smem[];
5353

5454
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
@@ -124,7 +124,7 @@ __global__ void ivfInterleavedScan(
124124
for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) {
125125
const int loadDim = dBase + laneId;
126126
const float queryReg = query[loadDim];
127-
[[maybe_unused]] const float residualReg =
127+
const float residualReg =
128128
Residual ? residualBaseSlice[loadDim] : 0;
129129

130130
constexpr int kUnroll = 4;
@@ -152,7 +152,7 @@ __global__ void ivfInterleavedScan(
152152
decV[j] = codec.decodeNew(dBase + d, encV[j]);
153153
}
154154

155-
if constexpr (Residual) {
155+
if (Residual) {
156156
#pragma unroll
157157
for (int j = 0; j < kUnroll; ++j) {
158158
int d = i * kUnroll + j;
@@ -174,9 +174,9 @@ __global__ void ivfInterleavedScan(
174174
const bool loadDimInBounds = loadDim < dim;
175175

176176
const float queryReg = loadDimInBounds ? query[loadDim] : 0;
177-
[[maybe_unused]] const float residualReg =
178-
Residual && loadDimInBounds ? residualBaseSlice[loadDim]
179-
: 0;
177+
const float residualReg = Residual && loadDimInBounds
178+
? residualBaseSlice[loadDim]
179+
: 0;
180180

181181
for (int d = 0; d < dim - dimBlocks;
182182
++d, data += wordsPerVectorBlockDim) {
@@ -187,7 +187,7 @@ __global__ void ivfInterleavedScan(
187187
enc = WarpPackedBits<EncodeT, Codec::kEncodeBits>::postRead(
188188
laneId, enc);
189189
float dec = codec.decodeNew(dimBlocks + d, enc);
190-
if constexpr (Residual) {
190+
if (Residual) {
191191
dec += SHFL_SYNC(residualReg, d, kWarpSize);
192192
}
193193

faiss/gpu/impl/scan/IVFInterleaved1.cu

-16
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved1024.cu

-16
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved128.cu

-16
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved2048.cu

-18
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved256.cu

-16
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved32.cu

-16
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved512.cu

-16
This file was deleted.

faiss/gpu/impl/scan/IVFInterleaved64.cu

-16
This file was deleted.

0 commit comments

Comments
 (0)