Skip to content

Commit c2b4a61

Browse files
Jeff Johnson and facebook-github-bot
Jeff Johnson
authored and committed
Faiss GPU: bfloat16 brute-force kNN support (facebookresearch#4018)
Summary: This diff adds support for bfloat16 vector/query data types with the GPU brute-force k-nearest neighbor function (`bfKnn`). The change is largely just plumbing the new data type through the template hierarchy (so distances can be computed in bfloat16). Of note, by design, all final distance results are produced in float32 regardless of input data type (float32, float16, bfloat16). This is because the true nearest neighbors in many data sets can often differ by only ~1000 float32 ULPs in terms of distance which will result in possible false equivalency. This seems to be one area where lossy compression/quantization thoughout does not work as well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`. However, given that there is native bf16 x bf16 = fp32 tensor core support on Ampere+ architectures, the matrix multiplication itself should WARNING: The one thing this diff does not yet handle properly is header inclusion / compilation for GPUs older than Ampere. This will need to be fixed before landing (so that compiling with an older CUDA SDK or compiling for the Volta architecture will simply error out at runtime properly with lack of support, instead of failing to compile (?) Differential Revision: D65459723
1 parent adb1884 commit c2b4a61

18 files changed

+690
-56
lines changed

faiss/gpu/GpuDistance.cu

+21-5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <faiss/gpu/utils/ConversionOperators.cuh>
3131
#include <faiss/gpu/utils/CopyUtils.cuh>
3232
#include <faiss/gpu/utils/DeviceTensor.cuh>
33+
#include <faiss/gpu/utils/Float16.cuh>
3334

3435
#if defined USE_NVIDIA_RAFT
3536
#include <faiss/gpu/utils/RaftUtils.h>
@@ -242,7 +243,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
242243
FAISS_THROW_IF_NOT_MSG(
243244
args.vectorType == args.queryType,
244245
"limitation: both vectorType and queryType must currently "
245-
"be the same (F32 or F16");
246+
"be the same (F32 / F16 / BF16");
246247

247248
#if defined USE_NVIDIA_RAFT
248249
// Note: For now, RAFT bfknn requires queries and vectors to be same layout
@@ -374,6 +375,17 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
374375
bfKnnConvert<float>(prov, args);
375376
} else if (args.vectorType == DistanceDataType::F16) {
376377
bfKnnConvert<half>(prov, args);
378+
} else if (args.vectorType == DistanceDataType::BF16) {
379+
// no bf16 support for AMD
380+
#ifndef USE_AMD_ROCM
381+
if (prov->getResources()->supportsBFloat16CurrentDevice()) {
382+
bfKnnConvert<__nv_bfloat16>(prov, args);
383+
} else {
384+
FAISS_THROW_MSG("not compiled with bfloat16 support");
385+
}
386+
#else
387+
FAISS_THROW_MSG("no AMD bfloat16 support");
388+
#endif
377389
} else {
378390
FAISS_THROW_MSG("unknown vectorType");
379391
}
@@ -440,8 +452,10 @@ void bfKnn_single_query_shard(
440452
args.k > 0,
441453
"bfKnn_tiling: tiling vectors is only supported for k > 0");
442454
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
443-
: args.vectorType == DistanceDataType::F16 ? 2
444-
: 0;
455+
: (args.vectorType == DistanceDataType::F16 ||
456+
args.vectorType == DistanceDataType::BF16)
457+
? 2
458+
: 0;
445459
FAISS_THROW_IF_NOT_MSG(
446460
distance_size > 0, "bfKnn_tiling: unknown vectorType");
447461
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
@@ -498,8 +512,10 @@ void bfKnn_tiling(
498512
args.k > 0,
499513
"bfKnn_tiling: tiling queries is only supported for k > 0");
500514
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
501-
: args.queryType == DistanceDataType::F16 ? 2
502-
: 0;
515+
: (args.queryType == DistanceDataType::F16 ||
516+
args.queryType == DistanceDataType::BF16)
517+
? 2
518+
: 0;
503519
FAISS_THROW_IF_NOT_MSG(
504520
distance_size > 0, "bfKnn_tiling: unknown queryType");
505521
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8

faiss/gpu/GpuDistance.h

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class GpuResourcesProvider;
1919
enum class DistanceDataType {
2020
F32 = 1,
2121
F16,
22+
BF16,
2223
};
2324

2425
// Scalar type of the indices data

faiss/gpu/GpuResources.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
161161

162162
GpuResources::~GpuResources() = default;
163163

164+
bool GpuResources::supportsBFloat16CurrentDevice() {
165+
return supportsBFloat16(getCurrentDevice());
166+
}
167+
164168
cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
165169
return getBlasHandle(getCurrentDevice());
166170
}

faiss/gpu/GpuResources.h

+6
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ class GpuResources {
205205
/// of demand
206206
virtual void initializeForDevice(int device) = 0;
207207

208+
/// Does the given GPU support bfloat16?
209+
virtual bool supportsBFloat16(int device) = 0;
210+
208211
/// Returns the cuBLAS handle that we use for the given device
209212
virtual cublasHandle_t getBlasHandle(int device) = 0;
210213

@@ -252,6 +255,9 @@ class GpuResources {
252255
/// Functions provided by default
253256
///
254257

258+
/// Does the current GPU support bfloat16?
259+
bool supportsBFloat16CurrentDevice();
260+
255261
/// Calls getBlasHandle with the current device
256262
cublasHandle_t getBlasHandleCurrentDevice();
257263

faiss/gpu/StandardGpuResources.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
202202
return requested;
203203
}
204204

205+
/// Does the given GPU support bfloat16?
206+
bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
207+
initializeForDevice(device);
208+
auto& prop = getDeviceProperties(device);
209+
return prop.major >= 8;
210+
}
211+
205212
void StandardGpuResourcesImpl::noTempMemory() {
206213
setTempMemory(0);
207214
}
@@ -687,6 +694,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
687694
return res_;
688695
}
689696

697+
bool StandardGpuResources::supportsBFloat16(int device) {
698+
return res_->supportsBFloat16(device);
699+
}
700+
701+
bool StandardGpuResources::supportsBFloat16CurrentDevice() {
702+
return res_->supportsBFloat16CurrentDevice();
703+
}
704+
690705
void StandardGpuResources::noTempMemory() {
691706
res_->noTempMemory();
692707
}

faiss/gpu/StandardGpuResources.h

+9
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
4848

4949
~StandardGpuResourcesImpl() override;
5050

51+
/// Does the given GPU support bfloat16?
52+
bool supportsBFloat16(int device) override;
53+
5154
/// Disable allocation of temporary memory; all temporary memory
5255
/// requests will call cudaMalloc / cudaFree at the point of use
5356
void noTempMemory();
@@ -199,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
199202

200203
std::shared_ptr<GpuResources> getResources() override;
201204

205+
/// Whether or not the given device supports native bfloat16 arithmetic
206+
bool supportsBFloat16(int device);
207+
208+
/// Whether or not the current device supports native bfloat16 arithmetic
209+
bool supportsBFloat16CurrentDevice();
210+
202211
/// Disable allocation of temporary memory; all temporary memory
203212
/// requests will call cudaMalloc / cudaFree at the point of use
204213
void noTempMemory();

faiss/gpu/impl/Distance.cu

+101
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,30 @@ void runAllPairwiseL2Distance(
504504
outDistances);
505505
}
506506

507+
// no bf16 support for AMD
508+
#ifndef USE_AMD_ROCM
509+
void runAllPairwiseL2Distance(
510+
GpuResources* res,
511+
cudaStream_t stream,
512+
Tensor<__nv_bfloat16, 2, true>& vectors,
513+
bool vectorsRowMajor,
514+
Tensor<float, 1, true>* vectorNorms,
515+
Tensor<__nv_bfloat16, 2, true>& queries,
516+
bool queriesRowMajor,
517+
Tensor<float, 2, true>& outDistances) {
518+
runAllPairwiseDistance<__nv_bfloat16>(
519+
true,
520+
res,
521+
stream,
522+
vectors,
523+
vectorsRowMajor,
524+
vectorNorms,
525+
queries,
526+
queriesRowMajor,
527+
outDistances);
528+
}
529+
#endif // USE_AMD_ROCM
530+
507531
void runAllPairwiseIPDistance(
508532
GpuResources* res,
509533
cudaStream_t stream,
@@ -544,6 +568,29 @@ void runAllPairwiseIPDistance(
544568
outDistances);
545569
}
546570

571+
// no bf16 support for AMD
572+
#ifndef USE_AMD_ROCM
573+
void runAllPairwiseIPDistance(
574+
GpuResources* res,
575+
cudaStream_t stream,
576+
Tensor<__nv_bfloat16, 2, true>& vectors,
577+
bool vectorsRowMajor,
578+
Tensor<__nv_bfloat16, 2, true>& queries,
579+
bool queriesRowMajor,
580+
Tensor<float, 2, true>& outDistances) {
581+
runAllPairwiseDistance<__nv_bfloat16>(
582+
false,
583+
res,
584+
stream,
585+
vectors,
586+
vectorsRowMajor,
587+
nullptr,
588+
queries,
589+
queriesRowMajor,
590+
outDistances);
591+
}
592+
#endif // USE_AMD_ROCM
593+
547594
void runL2Distance(
548595
GpuResources* res,
549596
cudaStream_t stream,
@@ -596,6 +643,35 @@ void runL2Distance(
596643
ignoreOutDistances);
597644
}
598645

646+
// no bf16 support for AMD
647+
#ifndef USE_AMD_ROCM
648+
void runL2Distance(
649+
GpuResources* res,
650+
cudaStream_t stream,
651+
Tensor<__nv_bfloat16, 2, true>& vectors,
652+
bool vectorsRowMajor,
653+
Tensor<float, 1, true>* vectorNorms,
654+
Tensor<__nv_bfloat16, 2, true>& queries,
655+
bool queriesRowMajor,
656+
int k,
657+
Tensor<float, 2, true>& outDistances,
658+
Tensor<idx_t, 2, true>& outIndices,
659+
bool ignoreOutDistances) {
660+
runL2Distance<__nv_bfloat16>(
661+
res,
662+
stream,
663+
vectors,
664+
vectorsRowMajor,
665+
vectorNorms,
666+
queries,
667+
queriesRowMajor,
668+
k,
669+
outDistances,
670+
outIndices,
671+
ignoreOutDistances);
672+
}
673+
#endif // USE_AMD_ROCM
674+
599675
void runIPDistance(
600676
GpuResources* res,
601677
cudaStream_t stream,
@@ -640,5 +716,30 @@ void runIPDistance(
640716
outIndices);
641717
}
642718

719+
// no bf16 support for AMD
720+
#ifndef USE_AMD_ROCM
721+
void runIPDistance(
722+
GpuResources* res,
723+
cudaStream_t stream,
724+
Tensor<__nv_bfloat16, 2, true>& vectors,
725+
bool vectorsRowMajor,
726+
Tensor<__nv_bfloat16, 2, true>& queries,
727+
bool queriesRowMajor,
728+
int k,
729+
Tensor<float, 2, true>& outDistances,
730+
Tensor<idx_t, 2, true>& outIndices) {
731+
runIPDistance<__nv_bfloat16>(
732+
res,
733+
stream,
734+
vectors,
735+
vectorsRowMajor,
736+
queries,
737+
queriesRowMajor,
738+
k,
739+
outDistances,
740+
outIndices);
741+
}
742+
#endif // USE_AMD_ROCM
743+
643744
} // namespace gpu
644745
} // namespace faiss

faiss/gpu/impl/Distance.cuh

+55
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ void runAllPairwiseL2Distance(
4141
bool queriesRowMajor,
4242
Tensor<float, 2, true>& outDistances);
4343

44+
// no bf16 support for AMD
45+
#ifndef USE_AMD_ROCM
46+
void runAllPairwiseL2Distance(
47+
GpuResources* res,
48+
cudaStream_t stream,
49+
Tensor<__nv_bfloat16, 2, true>& vectors,
50+
bool vectorsRowMajor,
51+
Tensor<float, 1, true>* vectorNorms,
52+
Tensor<__nv_bfloat16, 2, true>& queries,
53+
bool queriesRowMajor,
54+
Tensor<float, 2, true>& outDistances);
55+
#endif // USE_AMD_ROCM
56+
4457
void runAllPairwiseIPDistance(
4558
GpuResources* res,
4659
cudaStream_t stream,
@@ -59,6 +72,18 @@ void runAllPairwiseIPDistance(
5972
bool queriesRowMajor,
6073
Tensor<float, 2, true>& outDistances);
6174

75+
// no bf16 support for AMD
76+
#ifndef USE_AMD_ROCM
77+
void runAllPairwiseIPDistance(
78+
GpuResources* res,
79+
cudaStream_t stream,
80+
Tensor<__nv_bfloat16, 2, true>& vectors,
81+
bool vectorsRowMajor,
82+
Tensor<__nv_bfloat16, 2, true>& queries,
83+
bool queriesRowMajor,
84+
Tensor<float, 2, true>& outDistances);
85+
#endif // USE_AMD_ROCM
86+
6287
/// Calculates brute-force L2 distance between `vectors` and
6388
/// `queries`, returning the k closest results seen
6489
void runL2Distance(
@@ -91,6 +116,22 @@ void runL2Distance(
91116
Tensor<idx_t, 2, true>& outIndices,
92117
bool ignoreOutDistances = false);
93118

119+
// no bf16 support for AMD
120+
#ifndef USE_AMD_ROCM
121+
void runL2Distance(
122+
GpuResources* resources,
123+
cudaStream_t stream,
124+
Tensor<__nv_bfloat16, 2, true>& vectors,
125+
bool vectorsRowMajor,
126+
Tensor<float, 1, true>* vectorNorms,
127+
Tensor<__nv_bfloat16, 2, true>& queries,
128+
bool queriesRowMajor,
129+
int k,
130+
Tensor<float, 2, true>& outDistances,
131+
Tensor<idx_t, 2, true>& outIndices,
132+
bool ignoreOutDistances = false);
133+
#endif // USE_AMD_ROCM
134+
94135
/// Calculates brute-force inner product distance between `vectors`
95136
/// and `queries`, returning the k closest results seen
96137
void runIPDistance(
@@ -115,6 +156,20 @@ void runIPDistance(
115156
Tensor<float, 2, true>& outDistances,
116157
Tensor<idx_t, 2, true>& outIndices);
117158

159+
// no bf16 support for AMD
160+
#ifndef USE_AMD_ROCM
161+
void runIPDistance(
162+
GpuResources* resources,
163+
cudaStream_t stream,
164+
Tensor<__nv_bfloat16, 2, true>& vectors,
165+
bool vectorsRowMajor,
166+
Tensor<__nv_bfloat16, 2, true>& queries,
167+
bool queriesRowMajor,
168+
int k,
169+
Tensor<float, 2, true>& outDistances,
170+
Tensor<idx_t, 2, true>& outIndices);
171+
#endif // USE_AMD_ROCM
172+
118173
//
119174
// General distance implementation, assumes that all arguments are on the
120175
// device. This is the top-level internal distance function to call to dispatch

0 commit comments

Comments
 (0)