Skip to content

Commit 3d228c8

Browse files
Jeff Johnson, facebook-github-bot
Jeff Johnson
authored and committed
Faiss GPU: bfloat16 brute-force kNN support (#4014)
Summary: This diff adds support for bfloat16 vector/query data types with the GPU brute-force k-nearest neighbor function (`bfKnn`). The change is largely just plumbing the new data type through the template hierarchy (so distances can be computed in bfloat16). Of note, by design, all final distance results are produced in float32 regardless of input data type (float32, float16, bfloat16). This is because the true nearest neighbors in many data sets can often differ by only ~1000 float32 ULPs in terms of distance which will result in possible false equivalency. This seems to be one area where lossy compression/quantization thoughout does not work as well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`. However, given that there is native bf16 x bf16 = fp32 tensor core support on Ampere+ architectures, the matrix multiplication itself should WARNING: The one thing this diff does not yet handle properly is header inclusion / compilation for GPUs older than Ampere. This will need to be fixed before landing (so that compiling with an older CUDA SDK or compiling for the Volta architecture will simply error out at runtime properly with lack of support, instead of failing to compile (?) Differential Revision: D65459723
1 parent cfd4804 commit 3d228c8

13 files changed

+698
-46
lines changed

faiss/gpu/GpuDistance.cu

+15-5
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
242242
FAISS_THROW_IF_NOT_MSG(
243243
args.vectorType == args.queryType,
244244
"limitation: both vectorType and queryType must currently "
245-
"be the same (F32 or F16");
245+
"be the same (F32 / F16 / BF16");
246246

247247
#if defined USE_NVIDIA_RAFT
248248
// Note: For now, RAFT bfknn requires queries and vectors to be same layout
@@ -374,6 +374,12 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
374374
bfKnnConvert<float>(prov, args);
375375
} else if (args.vectorType == DistanceDataType::F16) {
376376
bfKnnConvert<half>(prov, args);
377+
} else if (args.vectorType == DistanceDataType::BF16) {
378+
#ifdef FAISS_USE_FULL_BFLOAT16
379+
bfKnnConvert<__nv_bfloat16>(prov, args);
380+
#else
381+
FAISS_THROW_MSG("not compiled with bfloat16 support");
382+
#endif
377383
} else {
378384
FAISS_THROW_MSG("unknown vectorType");
379385
}
@@ -440,8 +446,10 @@ void bfKnn_single_query_shard(
440446
args.k > 0,
441447
"bfKnn_tiling: tiling vectors is only supported for k > 0");
442448
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
443-
: args.vectorType == DistanceDataType::F16 ? 2
444-
: 0;
449+
: (args.vectorType == DistanceDataType::F16 ||
450+
args.vectorType == DistanceDataType::BF16)
451+
? 2
452+
: 0;
445453
FAISS_THROW_IF_NOT_MSG(
446454
distance_size > 0, "bfKnn_tiling: unknown vectorType");
447455
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
@@ -498,8 +506,10 @@ void bfKnn_tiling(
498506
args.k > 0,
499507
"bfKnn_tiling: tiling queries is only supported for k > 0");
500508
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
501-
: args.queryType == DistanceDataType::F16 ? 2
502-
: 0;
509+
: (args.queryType == DistanceDataType::F16 ||
510+
args.queryType == DistanceDataType::BF16)
511+
? 2
512+
: 0;
503513
FAISS_THROW_IF_NOT_MSG(
504514
distance_size > 0, "bfKnn_tiling: unknown queryType");
505515
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8

faiss/gpu/GpuDistance.h

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class GpuResourcesProvider;
1919
enum class DistanceDataType {
2020
F32 = 1,
2121
F16,
22+
BF16,
2223
};
2324

2425
// Scalar type of the indices data

faiss/gpu/impl/Distance.cu

+97
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,29 @@ void runAllPairwiseL2Distance(
504504
outDistances);
505505
}
506506

507+
#ifdef FAISS_USE_FULL_BFLOAT16
508+
void runAllPairwiseL2Distance(
509+
GpuResources* res,
510+
cudaStream_t stream,
511+
Tensor<__nv_bfloat16, 2, true>& vectors,
512+
bool vectorsRowMajor,
513+
Tensor<float, 1, true>* vectorNorms,
514+
Tensor<__nv_bfloat16, 2, true>& queries,
515+
bool queriesRowMajor,
516+
Tensor<float, 2, true>& outDistances) {
517+
runAllPairwiseDistance<__nv_bfloat16>(
518+
true,
519+
res,
520+
stream,
521+
vectors,
522+
vectorsRowMajor,
523+
vectorNorms,
524+
queries,
525+
queriesRowMajor,
526+
outDistances);
527+
}
528+
#endif // FAISS_USE_FULL_BFLOAT16
529+
507530
void runAllPairwiseIPDistance(
508531
GpuResources* res,
509532
cudaStream_t stream,
@@ -544,6 +567,28 @@ void runAllPairwiseIPDistance(
544567
outDistances);
545568
}
546569

570+
#ifdef FAISS_USE_FULL_BFLOAT16
571+
void runAllPairwiseIPDistance(
572+
GpuResources* res,
573+
cudaStream_t stream,
574+
Tensor<__nv_bfloat16, 2, true>& vectors,
575+
bool vectorsRowMajor,
576+
Tensor<__nv_bfloat16, 2, true>& queries,
577+
bool queriesRowMajor,
578+
Tensor<float, 2, true>& outDistances) {
579+
runAllPairwiseDistance<__nv_bfloat16>(
580+
false,
581+
res,
582+
stream,
583+
vectors,
584+
vectorsRowMajor,
585+
nullptr,
586+
queries,
587+
queriesRowMajor,
588+
outDistances);
589+
}
590+
#endif // FAISS_USE_FULL_BFLOAT16
591+
547592
void runL2Distance(
548593
GpuResources* res,
549594
cudaStream_t stream,
@@ -596,6 +641,34 @@ void runL2Distance(
596641
ignoreOutDistances);
597642
}
598643

644+
#ifdef FAISS_USE_FULL_BFLOAT16
645+
void runL2Distance(
646+
GpuResources* res,
647+
cudaStream_t stream,
648+
Tensor<__nv_bfloat16, 2, true>& vectors,
649+
bool vectorsRowMajor,
650+
Tensor<float, 1, true>* vectorNorms,
651+
Tensor<__nv_bfloat16, 2, true>& queries,
652+
bool queriesRowMajor,
653+
int k,
654+
Tensor<float, 2, true>& outDistances,
655+
Tensor<idx_t, 2, true>& outIndices,
656+
bool ignoreOutDistances) {
657+
runL2Distance<__nv_bfloat16>(
658+
res,
659+
stream,
660+
vectors,
661+
vectorsRowMajor,
662+
vectorNorms,
663+
queries,
664+
queriesRowMajor,
665+
k,
666+
outDistances,
667+
outIndices,
668+
ignoreOutDistances);
669+
}
670+
#endif // FAISS_USE_FULL_BFLOAT16
671+
599672
void runIPDistance(
600673
GpuResources* res,
601674
cudaStream_t stream,
@@ -640,5 +713,29 @@ void runIPDistance(
640713
outIndices);
641714
}
642715

716+
#ifdef FAISS_USE_FULL_BFLOAT16
717+
void runIPDistance(
718+
GpuResources* res,
719+
cudaStream_t stream,
720+
Tensor<__nv_bfloat16, 2, true>& vectors,
721+
bool vectorsRowMajor,
722+
Tensor<__nv_bfloat16, 2, true>& queries,
723+
bool queriesRowMajor,
724+
int k,
725+
Tensor<float, 2, true>& outDistances,
726+
Tensor<idx_t, 2, true>& outIndices) {
727+
runIPDistance<__nv_bfloat16>(
728+
res,
729+
stream,
730+
vectors,
731+
vectorsRowMajor,
732+
queries,
733+
queriesRowMajor,
734+
k,
735+
outDistances,
736+
outIndices);
737+
}
738+
#endif // FAISS_USE_FULL_BFLOAT16
739+
643740
} // namespace gpu
644741
} // namespace faiss

faiss/gpu/impl/Distance.cuh

+51
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ void runAllPairwiseL2Distance(
4141
bool queriesRowMajor,
4242
Tensor<float, 2, true>& outDistances);
4343

44+
#ifdef FAISS_USE_FULL_BFLOAT16
45+
void runAllPairwiseL2Distance(
46+
GpuResources* res,
47+
cudaStream_t stream,
48+
Tensor<__nv_bfloat16, 2, true>& vectors,
49+
bool vectorsRowMajor,
50+
Tensor<float, 1, true>* vectorNorms,
51+
Tensor<__nv_bfloat16, 2, true>& queries,
52+
bool queriesRowMajor,
53+
Tensor<float, 2, true>& outDistances);
54+
#endif // FAISS_USE_FULL_BFLOAT16
55+
4456
void runAllPairwiseIPDistance(
4557
GpuResources* res,
4658
cudaStream_t stream,
@@ -59,6 +71,17 @@ void runAllPairwiseIPDistance(
5971
bool queriesRowMajor,
6072
Tensor<float, 2, true>& outDistances);
6173

74+
#ifdef FAISS_USE_FULL_BFLOAT16
75+
void runAllPairwiseIPDistance(
76+
GpuResources* res,
77+
cudaStream_t stream,
78+
Tensor<__nv_bfloat16, 2, true>& vectors,
79+
bool vectorsRowMajor,
80+
Tensor<__nv_bfloat16, 2, true>& queries,
81+
bool queriesRowMajor,
82+
Tensor<float, 2, true>& outDistances);
83+
#endif // FAISS_USE_FULL_BFLOAT16
84+
6285
/// Calculates brute-force L2 distance between `vectors` and
6386
/// `queries`, returning the k closest results seen
6487
void runL2Distance(
@@ -91,6 +114,21 @@ void runL2Distance(
91114
Tensor<idx_t, 2, true>& outIndices,
92115
bool ignoreOutDistances = false);
93116

117+
#ifdef FAISS_USE_FULL_BFLOAT16
118+
void runL2Distance(
119+
GpuResources* resources,
120+
cudaStream_t stream,
121+
Tensor<__nv_bfloat16, 2, true>& vectors,
122+
bool vectorsRowMajor,
123+
Tensor<float, 1, true>* vectorNorms,
124+
Tensor<__nv_bfloat16, 2, true>& queries,
125+
bool queriesRowMajor,
126+
int k,
127+
Tensor<float, 2, true>& outDistances,
128+
Tensor<idx_t, 2, true>& outIndices,
129+
bool ignoreOutDistances = false);
130+
#endif // FAISS_USE_FULL_BFLOAT16
131+
94132
/// Calculates brute-force inner product distance between `vectors`
95133
/// and `queries`, returning the k closest results seen
96134
void runIPDistance(
@@ -115,6 +153,19 @@ void runIPDistance(
115153
Tensor<float, 2, true>& outDistances,
116154
Tensor<idx_t, 2, true>& outIndices);
117155

156+
#ifdef FAISS_USE_FULL_BFLOAT16
157+
void runIPDistance(
158+
GpuResources* resources,
159+
cudaStream_t stream,
160+
Tensor<__nv_bfloat16, 2, true>& vectors,
161+
bool vectorsRowMajor,
162+
Tensor<__nv_bfloat16, 2, true>& queries,
163+
bool queriesRowMajor,
164+
int k,
165+
Tensor<float, 2, true>& outDistances,
166+
Tensor<idx_t, 2, true>& outIndices);
167+
#endif // FAISS_USE_FULL_BFLOAT16
168+
118169
//
119170
// General distance implementation, assumes that all arguments are on the
120171
// device. This is the top-level internal distance function to call to dispatch

faiss/gpu/impl/GpuScalarQuantizer.cuh

+4-4
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
154154
inline __device__ void decode(void* data, idx_t vec, int d, float* out)
155155
const {
156156
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
157-
out[0] = Convert<half, float>()(p[d]);
157+
out[0] = ConvertTo<float>::to(p[d]);
158158
}
159159

160160
inline __device__ float decodePartial(
@@ -172,7 +172,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
172172
int d,
173173
float v[kDimPerIter]) const {
174174
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
175-
p[d] = Convert<float, half>()(v[0]);
175+
p[d] = ConvertTo<half>::to(v[0]);
176176
}
177177

178178
inline __device__ void encodePartial(
@@ -191,11 +191,11 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
191191
static constexpr int kEncodeBits = 16;
192192

193193
inline __device__ EncodeT encodeNew(int dim, float v) const {
194-
return Convert<float, half>()(v);
194+
return ConvertTo<half>::to(v);
195195
}
196196

197197
inline __device__ float decodeNew(int dim, EncodeT v) const {
198-
return Convert<half, float>()(v);
198+
return ConvertTo<float>::to(v);
199199
}
200200

201201
int bytesPerVec;

faiss/gpu/impl/L2Norm.cu

+14-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <faiss/gpu/impl/L2Norm.cuh>
1212
#include <faiss/gpu/utils/ConversionOperators.cuh>
1313
#include <faiss/gpu/utils/DeviceDefs.cuh>
14-
#include <faiss/gpu/utils/Float16.cuh>
1514
#include <faiss/gpu/utils/MathOperators.cuh>
1615
#include <faiss/gpu/utils/PtxUtils.cuh>
1716
#include <faiss/gpu/utils/Reductions.cuh>
@@ -276,5 +275,19 @@ void runL2Norm(
276275
runL2Norm<half, half2>(input, inputRowMajor, output, normSquared, stream);
277276
}
278277

278+
#ifdef FAISS_USE_FULL_BFLOAT16
279+
280+
void runL2Norm(
281+
Tensor<__nv_bfloat16, 2, true>& input,
282+
bool inputRowMajor,
283+
Tensor<float, 1, true>& output,
284+
bool normSquared,
285+
cudaStream_t stream) {
286+
runL2Norm<__nv_bfloat16, __nv_bfloat162>(
287+
input, inputRowMajor, output, normSquared, stream);
288+
}
289+
290+
#endif // FAISS_USE_FULL_BFLOAT16
291+
279292
} // namespace gpu
280293
} // namespace faiss

faiss/gpu/impl/L2Norm.cuh

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#pragma once
99

10-
#include <cuda_fp16.h>
10+
#include <faiss/gpu/utils/Float16.cuh>
1111
#include <faiss/gpu/utils/Tensor.cuh>
1212

1313
namespace faiss {
@@ -27,5 +27,16 @@ void runL2Norm(
2727
bool normSquared,
2828
cudaStream_t stream);
2929

30+
#ifdef FAISS_USE_FULL_BFLOAT16
31+
32+
void runL2Norm(
33+
Tensor<__nv_bfloat16, 2, true>& input,
34+
bool inputRowMajor,
35+
Tensor<float, 1, true>& output,
36+
bool normSquared,
37+
cudaStream_t stream);
38+
39+
#endif // FAISS_USE_FULL_BFLOAT16
40+
3041
} // namespace gpu
3142
} // namespace faiss

faiss/gpu/impl/VectorResidual.cu

+2-6
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,8 @@ __global__ void gatherReconstructByIds(
114114
auto vec = vecs[id];
115115
auto outVec = out[blockIdx.x];
116116

117-
Convert<T, float> conv;
118-
119117
for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
120-
outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
118+
outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
121119
}
122120
}
123121

@@ -131,10 +129,8 @@ __global__ void gatherReconstructByRange(
131129
auto vec = vecs[id];
132130
auto outVec = out[blockIdx.x];
133131

134-
Convert<T, float> conv;
135-
136132
for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
137-
outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
133+
outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
138134
}
139135
}
140136

0 commit comments

Comments
 (0)