Skip to content

Commit c7c8620

Browse files
Jeff Johnson authored and facebook-github-bot committed
Faiss GPU: bfloat16 brute-force kNN support (facebookresearch#4018)
Summary: This diff adds support for bfloat16 vector/query data types with the GPU brute-force k-nearest neighbor function (`bfKnn`). The change is largely just plumbing the new data type through the template hierarchy (so distances can be computed in bfloat16). Of note, by design, all final distance results are produced in float32 regardless of input data type (float32, float16, bfloat16). This is because the true nearest neighbors in many data sets can often differ by only ~1000 float32 ULPs in terms of distance, which would result in possible false equivalency. This seems to be one area where lossy compression/quantization throughout does not work as well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`). However, given that there is native bf16 x bf16 = fp32 tensor core support on Ampere+ architectures, the matrix multiplication itself should be fast. WARNING: The one thing this diff does not yet handle properly is header inclusion / compilation for GPUs older than Ampere. This will need to be fixed before landing, so that compiling with an older CUDA SDK or compiling for the Volta architecture will simply error out at runtime with a proper lack-of-support message, instead of failing to compile. Differential Revision: D65459723
1 parent adb1884 commit c7c8620

17 files changed

+742
-50
lines changed

faiss/gpu/GpuDistance.cu

+16-5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <faiss/gpu/utils/ConversionOperators.cuh>
3131
#include <faiss/gpu/utils/CopyUtils.cuh>
3232
#include <faiss/gpu/utils/DeviceTensor.cuh>
33+
#include <faiss/gpu/utils/Float16.cuh>
3334

3435
#if defined USE_NVIDIA_RAFT
3536
#include <faiss/gpu/utils/RaftUtils.h>
@@ -242,7 +243,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
242243
FAISS_THROW_IF_NOT_MSG(
243244
args.vectorType == args.queryType,
244245
"limitation: both vectorType and queryType must currently "
245-
"be the same (F32 or F16");
246+
"be the same (F32 / F16 / BF16");
246247

247248
#if defined USE_NVIDIA_RAFT
248249
// Note: For now, RAFT bfknn requires queries and vectors to be same layout
@@ -374,6 +375,12 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
374375
bfKnnConvert<float>(prov, args);
375376
} else if (args.vectorType == DistanceDataType::F16) {
376377
bfKnnConvert<half>(prov, args);
378+
} else if (args.vectorType == DistanceDataType::BF16) {
379+
if (prov->getResources()->supportsBFloat16CurrentDevice()) {
380+
bfKnnConvert<__nv_bfloat16>(prov, args);
381+
} else {
382+
FAISS_THROW_MSG("not compiled with bfloat16 support");
383+
}
377384
} else {
378385
FAISS_THROW_MSG("unknown vectorType");
379386
}
@@ -440,8 +447,10 @@ void bfKnn_single_query_shard(
440447
args.k > 0,
441448
"bfKnn_tiling: tiling vectors is only supported for k > 0");
442449
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
443-
: args.vectorType == DistanceDataType::F16 ? 2
444-
: 0;
450+
: (args.vectorType == DistanceDataType::F16 ||
451+
args.vectorType == DistanceDataType::BF16)
452+
? 2
453+
: 0;
445454
FAISS_THROW_IF_NOT_MSG(
446455
distance_size > 0, "bfKnn_tiling: unknown vectorType");
447456
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
@@ -498,8 +507,10 @@ void bfKnn_tiling(
498507
args.k > 0,
499508
"bfKnn_tiling: tiling queries is only supported for k > 0");
500509
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
501-
: args.queryType == DistanceDataType::F16 ? 2
502-
: 0;
510+
: (args.queryType == DistanceDataType::F16 ||
511+
args.queryType == DistanceDataType::BF16)
512+
? 2
513+
: 0;
503514
FAISS_THROW_IF_NOT_MSG(
504515
distance_size > 0, "bfKnn_tiling: unknown queryType");
505516
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8

faiss/gpu/GpuDistance.h

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class GpuResourcesProvider;
1919
enum class DistanceDataType {
2020
F32 = 1,
2121
F16,
22+
BF16,
2223
};
2324

2425
// Scalar type of the indices data

faiss/gpu/GpuResources.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
161161

162162
GpuResources::~GpuResources() = default;
163163

164+
bool GpuResources::supportsBFloat16CurrentDevice() {
165+
return supportsBFloat16(getCurrentDevice());
166+
}
167+
164168
cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
165169
return getBlasHandle(getCurrentDevice());
166170
}

faiss/gpu/GpuResources.h

+6
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ class GpuResources {
205205
/// of demand
206206
virtual void initializeForDevice(int device) = 0;
207207

208+
/// Does the given GPU support bfloat16?
209+
virtual bool supportsBFloat16(int device) = 0;
210+
208211
/// Returns the cuBLAS handle that we use for the given device
209212
virtual cublasHandle_t getBlasHandle(int device) = 0;
210213

@@ -252,6 +255,9 @@ class GpuResources {
252255
/// Functions provided by default
253256
///
254257

258+
/// Does the current GPU support bfloat16?
259+
bool supportsBFloat16CurrentDevice();
260+
255261
/// Calls getBlasHandle with the current device
256262
cublasHandle_t getBlasHandleCurrentDevice();
257263

faiss/gpu/StandardGpuResources.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
202202
return requested;
203203
}
204204

205+
/// Does the given GPU support bfloat16?
206+
bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
207+
initializeForDevice(device);
208+
auto& prop = getDeviceProperties(device);
209+
return prop.major >= 8;
210+
}
211+
205212
void StandardGpuResourcesImpl::noTempMemory() {
206213
setTempMemory(0);
207214
}
@@ -687,6 +694,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
687694
return res_;
688695
}
689696

697+
bool StandardGpuResources::supportsBFloat16(int device) {
698+
return res_->supportsBFloat16(device);
699+
}
700+
701+
bool StandardGpuResources::supportsBFloat16CurrentDevice() {
702+
return res_->supportsBFloat16CurrentDevice();
703+
}
704+
690705
void StandardGpuResources::noTempMemory() {
691706
res_->noTempMemory();
692707
}

faiss/gpu/StandardGpuResources.h

+9
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
4848

4949
~StandardGpuResourcesImpl() override;
5050

51+
/// Does the given GPU support bfloat16?
52+
bool supportsBFloat16(int device) override;
53+
5154
/// Disable allocation of temporary memory; all temporary memory
5255
/// requests will call cudaMalloc / cudaFree at the point of use
5356
void noTempMemory();
@@ -199,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
199202

200203
std::shared_ptr<GpuResources> getResources() override;
201204

205+
/// Whether or not the given device supports native bfloat16 arithmetic
206+
bool supportsBFloat16(int device);
207+
208+
/// Whether or not the current device supports native bfloat16 arithmetic
209+
bool supportsBFloat16CurrentDevice();
210+
202211
/// Disable allocation of temporary memory; all temporary memory
203212
/// requests will call cudaMalloc / cudaFree at the point of use
204213
void noTempMemory();

faiss/gpu/impl/Distance.cu

+89
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,27 @@ void runAllPairwiseL2Distance(
504504
outDistances);
505505
}
506506

507+
void runAllPairwiseL2Distance(
508+
GpuResources* res,
509+
cudaStream_t stream,
510+
Tensor<__nv_bfloat16, 2, true>& vectors,
511+
bool vectorsRowMajor,
512+
Tensor<float, 1, true>* vectorNorms,
513+
Tensor<__nv_bfloat16, 2, true>& queries,
514+
bool queriesRowMajor,
515+
Tensor<float, 2, true>& outDistances) {
516+
runAllPairwiseDistance<__nv_bfloat16>(
517+
true,
518+
res,
519+
stream,
520+
vectors,
521+
vectorsRowMajor,
522+
vectorNorms,
523+
queries,
524+
queriesRowMajor,
525+
outDistances);
526+
}
527+
507528
void runAllPairwiseIPDistance(
508529
GpuResources* res,
509530
cudaStream_t stream,
@@ -544,6 +565,26 @@ void runAllPairwiseIPDistance(
544565
outDistances);
545566
}
546567

568+
void runAllPairwiseIPDistance(
569+
GpuResources* res,
570+
cudaStream_t stream,
571+
Tensor<__nv_bfloat16, 2, true>& vectors,
572+
bool vectorsRowMajor,
573+
Tensor<__nv_bfloat16, 2, true>& queries,
574+
bool queriesRowMajor,
575+
Tensor<float, 2, true>& outDistances) {
576+
runAllPairwiseDistance<__nv_bfloat16>(
577+
false,
578+
res,
579+
stream,
580+
vectors,
581+
vectorsRowMajor,
582+
nullptr,
583+
queries,
584+
queriesRowMajor,
585+
outDistances);
586+
}
587+
547588
void runL2Distance(
548589
GpuResources* res,
549590
cudaStream_t stream,
@@ -596,6 +637,32 @@ void runL2Distance(
596637
ignoreOutDistances);
597638
}
598639

640+
void runL2Distance(
641+
GpuResources* res,
642+
cudaStream_t stream,
643+
Tensor<__nv_bfloat16, 2, true>& vectors,
644+
bool vectorsRowMajor,
645+
Tensor<float, 1, true>* vectorNorms,
646+
Tensor<__nv_bfloat16, 2, true>& queries,
647+
bool queriesRowMajor,
648+
int k,
649+
Tensor<float, 2, true>& outDistances,
650+
Tensor<idx_t, 2, true>& outIndices,
651+
bool ignoreOutDistances) {
652+
runL2Distance<__nv_bfloat16>(
653+
res,
654+
stream,
655+
vectors,
656+
vectorsRowMajor,
657+
vectorNorms,
658+
queries,
659+
queriesRowMajor,
660+
k,
661+
outDistances,
662+
outIndices,
663+
ignoreOutDistances);
664+
}
665+
599666
void runIPDistance(
600667
GpuResources* res,
601668
cudaStream_t stream,
@@ -640,5 +707,27 @@ void runIPDistance(
640707
outIndices);
641708
}
642709

710+
void runIPDistance(
711+
GpuResources* res,
712+
cudaStream_t stream,
713+
Tensor<__nv_bfloat16, 2, true>& vectors,
714+
bool vectorsRowMajor,
715+
Tensor<__nv_bfloat16, 2, true>& queries,
716+
bool queriesRowMajor,
717+
int k,
718+
Tensor<float, 2, true>& outDistances,
719+
Tensor<idx_t, 2, true>& outIndices) {
720+
runIPDistance<__nv_bfloat16>(
721+
res,
722+
stream,
723+
vectors,
724+
vectorsRowMajor,
725+
queries,
726+
queriesRowMajor,
727+
k,
728+
outDistances,
729+
outIndices);
730+
}
731+
643732
} // namespace gpu
644733
} // namespace faiss

faiss/gpu/impl/Distance.cuh

+43
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ void runAllPairwiseL2Distance(
4141
bool queriesRowMajor,
4242
Tensor<float, 2, true>& outDistances);
4343

44+
void runAllPairwiseL2Distance(
45+
GpuResources* res,
46+
cudaStream_t stream,
47+
Tensor<__nv_bfloat16, 2, true>& vectors,
48+
bool vectorsRowMajor,
49+
Tensor<float, 1, true>* vectorNorms,
50+
Tensor<__nv_bfloat16, 2, true>& queries,
51+
bool queriesRowMajor,
52+
Tensor<float, 2, true>& outDistances);
53+
4454
void runAllPairwiseIPDistance(
4555
GpuResources* res,
4656
cudaStream_t stream,
@@ -59,6 +69,15 @@ void runAllPairwiseIPDistance(
5969
bool queriesRowMajor,
6070
Tensor<float, 2, true>& outDistances);
6171

72+
void runAllPairwiseIPDistance(
73+
GpuResources* res,
74+
cudaStream_t stream,
75+
Tensor<__nv_bfloat16, 2, true>& vectors,
76+
bool vectorsRowMajor,
77+
Tensor<__nv_bfloat16, 2, true>& queries,
78+
bool queriesRowMajor,
79+
Tensor<float, 2, true>& outDistances);
80+
6281
/// Calculates brute-force L2 distance between `vectors` and
6382
/// `queries`, returning the k closest results seen
6483
void runL2Distance(
@@ -91,6 +110,19 @@ void runL2Distance(
91110
Tensor<idx_t, 2, true>& outIndices,
92111
bool ignoreOutDistances = false);
93112

113+
void runL2Distance(
114+
GpuResources* resources,
115+
cudaStream_t stream,
116+
Tensor<__nv_bfloat16, 2, true>& vectors,
117+
bool vectorsRowMajor,
118+
Tensor<float, 1, true>* vectorNorms,
119+
Tensor<__nv_bfloat16, 2, true>& queries,
120+
bool queriesRowMajor,
121+
int k,
122+
Tensor<float, 2, true>& outDistances,
123+
Tensor<idx_t, 2, true>& outIndices,
124+
bool ignoreOutDistances = false);
125+
94126
/// Calculates brute-force inner product distance between `vectors`
95127
/// and `queries`, returning the k closest results seen
96128
void runIPDistance(
@@ -115,6 +147,17 @@ void runIPDistance(
115147
Tensor<float, 2, true>& outDistances,
116148
Tensor<idx_t, 2, true>& outIndices);
117149

150+
void runIPDistance(
151+
GpuResources* resources,
152+
cudaStream_t stream,
153+
Tensor<__nv_bfloat16, 2, true>& vectors,
154+
bool vectorsRowMajor,
155+
Tensor<__nv_bfloat16, 2, true>& queries,
156+
bool queriesRowMajor,
157+
int k,
158+
Tensor<float, 2, true>& outDistances,
159+
Tensor<idx_t, 2, true>& outIndices);
160+
118161
//
119162
// General distance implementation, assumes that all arguments are on the
120163
// device. This is the top-level internal distance function to call to dispatch

faiss/gpu/impl/GpuScalarQuantizer.cuh

+4-4
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
154154
inline __device__ void decode(void* data, idx_t vec, int d, float* out)
155155
const {
156156
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
157-
out[0] = Convert<half, float>()(p[d]);
157+
out[0] = ConvertTo<float>::to(p[d]);
158158
}
159159

160160
inline __device__ float decodePartial(
@@ -172,7 +172,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
172172
int d,
173173
float v[kDimPerIter]) const {
174174
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
175-
p[d] = Convert<float, half>()(v[0]);
175+
p[d] = ConvertTo<half>::to(v[0]);
176176
}
177177

178178
inline __device__ void encodePartial(
@@ -191,11 +191,11 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
191191
static constexpr int kEncodeBits = 16;
192192

193193
inline __device__ EncodeT encodeNew(int dim, float v) const {
194-
return Convert<float, half>()(v);
194+
return ConvertTo<half>::to(v);
195195
}
196196

197197
inline __device__ float decodeNew(int dim, EncodeT v) const {
198-
return Convert<half, float>()(v);
198+
return ConvertTo<float>::to(v);
199199
}
200200

201201
int bytesPerVec;

faiss/gpu/impl/L2Norm.cu

+10-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <faiss/gpu/impl/L2Norm.cuh>
1212
#include <faiss/gpu/utils/ConversionOperators.cuh>
1313
#include <faiss/gpu/utils/DeviceDefs.cuh>
14-
#include <faiss/gpu/utils/Float16.cuh>
1514
#include <faiss/gpu/utils/MathOperators.cuh>
1615
#include <faiss/gpu/utils/PtxUtils.cuh>
1716
#include <faiss/gpu/utils/Reductions.cuh>
@@ -276,5 +275,15 @@ void runL2Norm(
276275
runL2Norm<half, half2>(input, inputRowMajor, output, normSquared, stream);
277276
}
278277

278+
void runL2Norm(
279+
Tensor<__nv_bfloat16, 2, true>& input,
280+
bool inputRowMajor,
281+
Tensor<float, 1, true>& output,
282+
bool normSquared,
283+
cudaStream_t stream) {
284+
runL2Norm<__nv_bfloat16, __nv_bfloat162>(
285+
input, inputRowMajor, output, normSquared, stream);
286+
}
287+
279288
} // namespace gpu
280289
} // namespace faiss

0 commit comments

Comments
 (0)