
Commit 09c7aac

Jeff Johnson authored and facebook-github-bot committed
Faiss GPU CUDA 12 fix: warp synchronous behavior
Summary: This diff fixes the bug associated with moving Faiss GPU to CUDA 12. The following tests succeeded in CUDA 11.x but failed in CUDA 12:

```
✗ faiss/gpu/test:test_gpu_basics_py - test_input_types (faiss.gpu.test.test_gpu_basics.TestKnn)
✗ faiss/gpu/test:test_gpu_basics_py - test_dist (faiss.gpu.test.test_gpu_basics.TestAllPairwiseDistance)
✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.Add_L2
✗ faiss/gpu/test:test_gpu_basics_py - test_input_types_tiling (faiss.gpu.test.test_gpu_basics.TestKnn)
✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.Add_IP
✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.Float16Coarse
✗ faiss/gpu/test:test_gpu_index_ivfpq - TestGpuIndexIVFPQ.LargeBatch
```

It took a long while to track down, but the issue presented itself whenever the number of dimensions was not divisible by 32 in cases where we needed to calculate an L2 norm for vectors, which occurs in brute-force L2 distance computation as well as in certain L2 IVFPQ operations. It surfaced because some of these tests use 33 as the dimensionality of the vectors.

The number of threads given to the L2 norm kernel was effectively `min(dims, 1024)`, where 1024 is the standard maximum number of CUDA threads per CTA on all devices at present. Whenever that result was not a multiple of 32, the kernel was launched with a partial warp (with non-participating lanes having no side effects).

The change in CUDA 12 appears to be a change in compiler behavior for warp-synchronous shuffle instructions (such as `__shfl_up_sync`). For the partial warp, we were passing `0xffffffff` as the active lane mask, which asserts that all lanes of the warp are present. With dims = 33 we would have one full warp with all lanes present and one partial warp with only 1 active thread, so `0xffffffff` is a lie in that case. Prior to CUDA 12, these shuffle instructions appear to have passed 0 around for lanes not present (or perhaps stalled?), so the result was still calculated correctly. With CUDA 12, the compiler and/or device firmware (or something) interprets this differently, and the warp lanes not present provided garbage. The shuffle instructions were used to perform in-warp reductions (e.g., summing a set of floating-point numbers), namely those needed to sum up the L2 vector norm value. So for dims = 32 or dims = 64 (and, bizarrely, dims = 40 and some other choices) the result was still correct, but for dims = 33 garbage was added in, producing erroneous results.

This diff makes two changes. First, it removes the no-loop specialization of the L2 norm kernel (where we could statically avoid a for loop over the dimensions when the threadblock was sized exactly to the number of dimensions) and always uses the general-purpose looping fallback. Second, we now always provide a whole number of warps when launching the L2 norm kernel, so the warp-synchronous instructions never see a partial warp.

This bug has been present since the code was written in 2016 and was technically wrong all along, but it only surfaced as an actual failure with the change to CUDA 12.

tl;dr: if you use any kind of `_sync` instruction involving warp synchronization, always have a whole number of warps present, k thx.

Reviewed By: mdouze

Differential Revision: D51335172

fbshipit-source-id: 97da88a8dcbe6b4d8963083abc01d5d2121478bf
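To make the failure mode concrete, here is a minimal self-contained sketch. This is not Faiss code: the kernel `sumRow`, the helper `warpSum`, and the launch sizing in `main` are invented for illustration. It shows a warp-shuffle sum reduction whose `0xffffffff` mask asserts that every lane of each warp is resident, alongside the buggy `min(dims, 1024)` block sizing and the fixed sizing that rounds up to a whole number of warps:

```cuda
#include <algorithm>
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

// In-warp sum using warp-synchronous shuffles. The 0xffffffff member mask
// asserts that all 32 lanes of the warp are resident, which is only
// guaranteed when the block size is a whole multiple of kWarpSize.
__device__ float warpSum(float v) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        v += __shfl_down_sync(0xffffffff, v, offset);
    }
    return v;
}

// Sums one row of `dim` floats with a single block. Threads whose index is
// >= dim skip the strided loop entirely and contribute a harmless 0.
__global__ void sumRow(const float* x, int dim, float* out) {
    __shared__ float partials[32]; // one slot per warp (blocks <= 1024 threads)

    float v = 0.0f;
    for (int col = threadIdx.x; col < dim; col += blockDim.x) {
        v += x[col];
    }

    v = warpSum(v); // well-defined only if this warp is full

    int lane = threadIdx.x % kWarpSize;
    int warp = threadIdx.x / kWarpSize;
    if (lane == 0) {
        partials[warp] = v;
    }
    __syncthreads();

    // Thread 0 serially combines the per-warp partial sums.
    if (threadIdx.x == 0) {
        int numWarps = (blockDim.x + kWarpSize - 1) / kWarpSize;
        float total = 0.0f;
        for (int w = 0; w < numWarps; ++w) {
            total += partials[w];
        }
        *out = total;
    }
}

int main() {
    const int dim = 33; // the dimensionality that exposed the bug

    // Buggy block sizing: min(dim, 1024) = 33 threads -> one full warp plus
    // a 1-lane partial warp, even though every __shfl_*_sync call above
    // claims all 32 lanes via the 0xffffffff mask.
    int badThreads = std::min(dim, 1024);

    // Fixed sizing (the approach this commit takes via utils::roundUp):
    // round the thread count up to a whole number of warps.
    int goodThreads =
            std::min((dim + kWarpSize - 1) / kWarpSize * kWarpSize, 1024);
    std::printf("bad = %d threads, good = %d threads\n", badThreads, goodThreads);

    std::vector<float> host(dim, 1.0f); // expected sum: 33.0
    float *dX, *dOut;
    cudaMalloc(&dX, dim * sizeof(float));
    cudaMalloc(&dOut, sizeof(float));
    cudaMemcpy(dX, host.data(), dim * sizeof(float), cudaMemcpyHostToDevice);

    sumRow<<<1, goodThreads>>>(dX, dim, dOut);

    float result = 0.0f;
    cudaMemcpy(&result, dOut, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("sum = %f (expected %d)\n", result, dim);

    cudaFree(dX);
    cudaFree(dOut);
    return 0;
}
```

The same reasoning applies to the Faiss kernel in the diff below: lanes whose index is past the last column never enter the strided loop, so their partial value stays 0, and the warp-wide sum is correct once every lane named by the mask is actually resident.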
1 parent 0c2243c commit 09c7aac

File tree

1 file changed: +33 -76 lines changed


faiss/gpu/impl/L2Norm.cu

```diff
@@ -31,12 +31,7 @@ namespace gpu {
 // T: the type we are doing the math in (e.g., float, half)
 // TVec: the potentially vectorized type we are loading in (e.g.,
 // float4, half2)
-template <
-        typename T,
-        typename TVec,
-        int RowTileSize,
-        bool NormLoop,
-        bool NormSquared>
+template <typename T, typename TVec, int RowTileSize, bool NormSquared>
 __global__ void l2NormRowMajor(
         Tensor<TVec, 2, true> input,
         Tensor<float, 1, true> output) {
@@ -56,19 +51,13 @@ __global__ void l2NormRowMajor(
     if (lastRowTile) {
         // We are handling the very end of the input matrix rows
         for (idx_t row = 0; row < input.getSize(0) - rowStart; ++row) {
-            if (NormLoop) {
-                rowNorm[0] = 0;
-
-                for (idx_t col = threadIdx.x; col < input.getSize(1);
-                     col += blockDim.x) {
-                    TVec val = input[rowStart + row][col];
-                    val = Math<TVec>::mul(val, val);
-                    rowNorm[0] = rowNorm[0] + Math<TVec>::reduceAdd(val);
-                }
-            } else {
-                TVec val = input[rowStart + row][threadIdx.x];
+            rowNorm[0] = 0;
+
+            for (idx_t col = threadIdx.x; col < input.getSize(1);
+                 col += blockDim.x) {
+                TVec val = input[rowStart + row][col];
                 val = Math<TVec>::mul(val, val);
-                rowNorm[0] = Math<TVec>::reduceAdd(val);
+                rowNorm[0] = rowNorm[0] + Math<TVec>::reduceAdd(val);
             }
 
             rowNorm[0] = warpReduceAllSum(rowNorm[0]);
@@ -79,42 +68,18 @@ __global__ void l2NormRowMajor(
     } else {
         // We are guaranteed that all RowTileSize rows are available in
         // [rowStart, rowStart + RowTileSize)
-
-        if (NormLoop) {
-            // A single block of threads is not big enough to span each
-            // vector
-            TVec tmp[RowTileSize];
-
-#pragma unroll
-            for (int row = 0; row < RowTileSize; ++row) {
-                rowNorm[row] = 0;
-            }
-
-            for (idx_t col = threadIdx.x; col < input.getSize(1);
-                 col += blockDim.x) {
-#pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    tmp[row] = input[rowStart + row][col];
-                }
-
-#pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
-                }
+        TVec tmp[RowTileSize];
 
 #pragma unroll
-                for (int row = 0; row < RowTileSize; ++row) {
-                    rowNorm[row] =
-                            rowNorm[row] + Math<TVec>::reduceAdd(tmp[row]);
-                }
-            }
-        } else {
-            TVec tmp[RowTileSize];
+        for (int row = 0; row < RowTileSize; ++row) {
+            rowNorm[row] = 0;
+        }
 
-            // A block of threads is the exact size of the vector
+        for (idx_t col = threadIdx.x; col < input.getSize(1);
+             col += blockDim.x) {
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                tmp[row] = input[rowStart + row][threadIdx.x];
+                tmp[row] = input[rowStart + row][col];
             }
 
 #pragma unroll
@@ -124,7 +89,7 @@ __global__ void l2NormRowMajor(
 
 #pragma unroll
             for (int row = 0; row < RowTileSize; ++row) {
-                rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
+                rowNorm[row] = rowNorm[row] + Math<TVec>::reduceAdd(tmp[row]);
             }
         }
 
@@ -161,7 +126,7 @@ __global__ void l2NormRowMajor(
     if (laneId == 0) {
 #pragma unroll
         for (int row = 0; row < RowTileSize; ++row) {
-            int outCol = rowStart + row;
+            idx_t outCol = rowStart + row;
 
             if (lastRowTile) {
                 if (outCol < output.getSize(0)) {
@@ -218,25 +183,15 @@ void runL2Norm(
     idx_t maxThreads = (idx_t)getMaxThreadsCurrentDevice();
     constexpr int rowTileSize = 8;
 
-#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                           \
-    do {                                                                     \
-        if (normLoop) {                                                      \
-            if (normSquared) {                                               \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true, true>   \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            } else {                                                         \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true, false>  \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            }                                                                \
-        } else {                                                             \
-            if (normSquared) {                                               \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false, true>  \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            } else {                                                         \
-                l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false, false> \
-                        <<<grid, block, smem, stream>>>(INPUT, output);      \
-            }                                                                \
-        }                                                                    \
+#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                  \
+    do {                                                            \
+        if (normSquared) {                                          \
+            l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, true>    \
+                    <<<grid, block, smem, stream>>>(INPUT, output); \
+        } else {                                                    \
+            l2NormRowMajor<TYPE_T, TYPE_TVEC, rowTileSize, false>   \
+                    <<<grid, block, smem, stream>>>(INPUT, output); \
+        }                                                           \
     } while (0)
 
     if (inputRowMajor) {
@@ -247,10 +202,11 @@ void runL2Norm(
         if (input.template canCastResize<TVec>()) {
            // Can load using the vectorized type
            auto inputV = input.template castResize<TVec>();
-
            auto dim = inputV.getSize(1);
-            bool normLoop = dim > maxThreads;
-            auto numThreads = std::min(dim, maxThreads);
+
+            // We must always have full warps present
+            auto numThreads =
+                    std::min(utils::roundUp(dim, kWarpSize), maxThreads);
 
             auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
             auto block = dim3(numThreads);
@@ -261,10 +217,11 @@ void runL2Norm(
             RUN_L2_ROW_MAJOR(T, TVec, inputV);
         } else {
             // Can't load using the vectorized type
-
             auto dim = input.getSize(1);
-            bool normLoop = dim > maxThreads;
-            auto numThreads = std::min(dim, maxThreads);
+
+            // We must always have full warps present
+            auto numThreads =
+                    std::min(utils::roundUp(dim, kWarpSize), maxThreads);
 
             auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
             auto block = dim3(numThreads);
```
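For reference, the sizing fix hinges on the two small utilities used above. The sketch below shows their assumed semantics (a hypothetical reimplementation for illustration; the actual definitions live in the Faiss GPU utility headers):

```cuda
// Assumed behavior of the helpers referenced in the diff (hypothetical
// reimplementation for illustration; not the Faiss source).
template <typename T>
constexpr T divUp(T a, T b) {
    return (a + b - 1) / b; // ceiling division
}

template <typename T>
constexpr T roundUp(T a, T b) {
    return divUp(a, b) * b; // smallest multiple of b that is >= a
}

// With dim = 33 and kWarpSize = 32:
//   roundUp(33, 32) = 64  -> two full warps instead of 33 threads.
// Threads 33..63 fail the `col < dim` test in the strided loop, keep a
// rowNorm of 0, and participate harmlessly in warpReduceAllSum, which is
// now well-defined because every lane named by 0xffffffff is resident.
```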
