
Commit 15651f2

Faiss GPU: bfloat16 brute-force kNN support (facebookresearch#4018)

Jeff Johnson authored and facebook-github-bot committed
Summary:

This diff adds support for the bfloat16 vector/query data type in the GPU brute-force k-nearest-neighbor function (`bfKnn`). The change is largely just plumbing the new data type through the template hierarchy, so that distances can be computed in bfloat16.

Of note, by design all final distance results are produced in float32 regardless of the input data type (float32, float16, or bfloat16). This is because the true nearest neighbors in many data sets can differ by only ~1000 float32 ULPs in distance, which could otherwise produce false ties. This is one area where lossy compression/quantization throughout the computation does not work well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`). However, since Ampere+ architectures have native bf16 x bf16 = fp32 tensor core support, the matrix multiplication itself should use it.

Because bfloat16 support is quite lacking on AMD/ROCm (see [here](https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Device_API_supported_by_HIP.html); very few bf16 functions are implemented), bf16 functionality is completely disabled / not compiled for AMD ROCm.

Reviewed By: mdouze

Differential Revision: D65459723
1 parent 3c25a68 commit 15651f2

20 files changed: +913 −90 lines
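Before the per-file diffs, a quick end-to-end illustration of what this enables through the PyTorch bridge. This is a minimal sketch, assuming a CUDA build of faiss-gpu and an Ampere-or-newer GPU; the sizes are arbitrary.

```python
# Minimal sketch: brute-force kNN on bfloat16 tensors via the contrib torch bridge.
# Assumes faiss-gpu built with CUDA and a device with native bf16 (Ampere+).
import torch
import faiss
import faiss.contrib.torch_utils  # patches faiss.knn_gpu to accept torch tensors

res = faiss.StandardGpuResources()

d, nb, nq, k = 128, 100_000, 1_000, 10
xb = torch.randn(nb, d, device="cuda", dtype=torch.bfloat16)
xq = torch.randn(nq, d, device="cuda", dtype=torch.bfloat16)

D, I = faiss.knn_gpu(res, xq, xb, k)

# Inputs are bf16, but per the design note above the distances are always float32.
print(D.dtype, I.dtype)  # torch.float32 torch.int64
```

The same call accepts float32 and float16 inputs; the dtype of `xb`/`xq` only changes how the distance matrix multiplication is performed, not the precision of the returned distances.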

contrib/torch_utils.py (+15 −2)

@@ -56,6 +56,13 @@ def swig_ptr_from_FloatTensor(x):
     return faiss.cast_integer_to_float_ptr(
         x.untyped_storage().data_ptr() + x.storage_offset() * 4)
 
+def swig_ptr_from_BFloat16Tensor(x):
+    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
+    assert x.is_contiguous()
+    assert x.dtype == torch.bfloat16
+    return faiss.cast_integer_to_void_ptr(
+        x.untyped_storage().data_ptr() + x.storage_offset() * 2)
+
 
 def swig_ptr_from_IntTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
@@ -606,8 +613,11 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
     elif xb.dtype == torch.float16:
         xb_type = faiss.DistanceDataType_F16
         xb_ptr = swig_ptr_from_HalfTensor(xb)
+    elif xb.dtype == torch.bfloat16:
+        xb_type = faiss.DistanceDataType_BF16
+        xb_ptr = swig_ptr_from_BFloat16Tensor(xb)
     else:
-        raise TypeError('xb must be f32 or f16')
+        raise TypeError('xb must be float32, float16 or bfloat16')
 
     nq, d2 = xq.size()
     assert d2 == d

@@ -625,8 +635,11 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
     elif xq.dtype == torch.float16:
         xq_type = faiss.DistanceDataType_F16
         xq_ptr = swig_ptr_from_HalfTensor(xq)
+    elif xq.dtype == torch.bfloat16:
+        xq_type = faiss.DistanceDataType_BF16
+        xq_ptr = swig_ptr_from_BFloat16Tensor(xq)
     else:
-        raise TypeError('xq must be f32 or f16')
+        raise TypeError('xq must be float32, float16 or bfloat16')
 
     if D is None:
         D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
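A brief note on the new helper: bfloat16 elements are two bytes wide, which is why the storage offset above is scaled by 2, and the helper insists on a contiguous tensor before handing Faiss a raw pointer. A small sketch of that contract (illustrative only; it assumes a faiss build that ships the contrib torch bridge, and works with CPU tensors as well as GPU ones):

```python
# Illustrative check of the swig_ptr_from_BFloat16Tensor contract.
import torch
import faiss.contrib.torch_utils as torch_utils

x = torch.zeros(4, 8, dtype=torch.bfloat16)
ptr = torch_utils.swig_ptr_from_BFloat16Tensor(x)  # fine: contiguous bf16 tensor

view = x[:, ::2]  # a strided, non-contiguous view of the same storage
try:
    torch_utils.swig_ptr_from_BFloat16Tensor(view)
except AssertionError:
    print("the helper only accepts contiguous bfloat16 tensors")
```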

faiss/gpu/GpuDistance.cu (+21 −5)

@@ -30,6 +30,7 @@
 #include <faiss/gpu/utils/ConversionOperators.cuh>
 #include <faiss/gpu/utils/CopyUtils.cuh>
 #include <faiss/gpu/utils/DeviceTensor.cuh>
+#include <faiss/gpu/utils/Float16.cuh>
 #include <optional>
 
 #if defined USE_NVIDIA_CUVS

@@ -231,7 +232,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
     FAISS_THROW_IF_NOT_MSG(
             args.vectorType == args.queryType,
             "limitation: both vectorType and queryType must currently "
-            "be the same (F32 or F16");
+            "be the same (F32 / F16 / BF16");
 
 #if defined USE_NVIDIA_CUVS
     // Note: For now, cuVS bfknn requires queries and vectors to be same layout

@@ -400,6 +401,17 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
         bfKnnConvert<float>(prov, args);
     } else if (args.vectorType == DistanceDataType::F16) {
         bfKnnConvert<half>(prov, args);
+    } else if (args.vectorType == DistanceDataType::BF16) {
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+        if (prov->getResources()->supportsBFloat16CurrentDevice()) {
+            bfKnnConvert<__nv_bfloat16>(prov, args);
+        } else {
+            FAISS_THROW_MSG("not compiled with bfloat16 support");
+        }
+#else
+        FAISS_THROW_MSG("no AMD bfloat16 support");
+#endif
     } else {
         FAISS_THROW_MSG("unknown vectorType");
     }

@@ -466,8 +478,10 @@ void bfKnn_single_query_shard(
             args.k > 0,
             "bfKnn_tiling: tiling vectors is only supported for k > 0");
     size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
-            : args.vectorType == DistanceDataType::F16 ? 2
-                                                       : 0;
+            : (args.vectorType == DistanceDataType::F16 ||
+               args.vectorType == DistanceDataType::BF16)
+                    ? 2
+                    : 0;
     FAISS_THROW_IF_NOT_MSG(
             distance_size > 0, "bfKnn_tiling: unknown vectorType");
     size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);

@@ -524,8 +538,10 @@ void bfKnn_tiling(
             args.k > 0,
             "bfKnn_tiling: tiling queries is only supported for k > 0");
     size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
-            : args.queryType == DistanceDataType::F16 ? 2
-                                                      : 0;
+            : (args.queryType == DistanceDataType::F16 ||
+               args.queryType == DistanceDataType::BF16)
+                    ? 2
+                    : 0;
     FAISS_THROW_IF_NOT_MSG(
             distance_size > 0, "bfKnn_tiling: unknown queryType");
     size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
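For callers that want to skip the torch bridge, the dispatch above can be exercised directly: fill a `GpuDistanceParams` with `BF16`-typed pointers and call `bfKnn`. The sketch below is a hedged illustration of that path from Python, reusing the pointer helpers from `contrib/torch_utils.py`; it assumes a CUDA build, an Ampere+ device, the default row-major layout, and the default `I64` index type.

```python
# Hedged sketch: calling bfKnn directly with bf16 inputs (assumes faiss-gpu + CUDA
# and a device where supportsBFloat16CurrentDevice() is true).
import torch
import faiss
from faiss.contrib import torch_utils

res = faiss.StandardGpuResources()
d, nb, nq, k = 64, 10_000, 100, 5

xb = torch.randn(nb, d, device="cuda", dtype=torch.bfloat16)
xq = torch.randn(nq, d, device="cuda", dtype=torch.bfloat16)
D = torch.empty(nq, k, device="cuda", dtype=torch.float32)  # distances are always fp32
I = torch.empty(nq, k, device="cuda", dtype=torch.int64)

args = faiss.GpuDistanceParams()
args.metric = faiss.METRIC_L2
args.k = k
args.dims = d
args.numVectors = nb
args.vectors = torch_utils.swig_ptr_from_BFloat16Tensor(xb)
args.vectorType = faiss.DistanceDataType_BF16
args.numQueries = nq
args.queries = torch_utils.swig_ptr_from_BFloat16Tensor(xq)
args.queryType = faiss.DistanceDataType_BF16  # must match vectorType (see check above)
args.outDistances = torch_utils.swig_ptr_from_FloatTensor(D)
args.outIndices = torch_utils.swig_ptr_from_IndicesTensor(I)

# Run Faiss on the current PyTorch stream so the tensors above are visible to it.
with torch_utils.using_stream(res):
    faiss.bfKnn(res, args)
```

On a ROCm build, or on a device without bf16 support, this call throws with the messages added above rather than silently falling back to another dtype.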

faiss/gpu/GpuDistance.h (+1)

@@ -19,6 +19,7 @@ class GpuResourcesProvider;
 enum class DistanceDataType {
     F32 = 1,
     F16,
+    BF16,
 };
 
 // Scalar type of the indices data

faiss/gpu/GpuResources.cpp (+4)

@@ -161,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
 
 GpuResources::~GpuResources() = default;
 
+bool GpuResources::supportsBFloat16CurrentDevice() {
+    return supportsBFloat16(getCurrentDevice());
+}
+
 cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
     return getBlasHandle(getCurrentDevice());
 }

faiss/gpu/GpuResources.h (+6)

@@ -205,6 +205,9 @@ class GpuResources {
     /// of demand
     virtual void initializeForDevice(int device) = 0;
 
+    /// Does the given GPU support bfloat16?
+    virtual bool supportsBFloat16(int device) = 0;
+
     /// Returns the cuBLAS handle that we use for the given device
     virtual cublasHandle_t getBlasHandle(int device) = 0;

@@ -252,6 +255,9 @@ class GpuResources {
     /// Functions provided by default
     ///
 
+    /// Does the current GPU support bfloat16?
+    bool supportsBFloat16CurrentDevice();
+
     /// Calls getBlasHandle with the current device
     cublasHandle_t getBlasHandleCurrentDevice();

faiss/gpu/StandardGpuResources.cpp (+15)

@@ -206,6 +206,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
     return requested;
 }
 
+/// Does the given GPU support bfloat16?
+bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
+    initializeForDevice(device);
+    auto& prop = getDeviceProperties(device);
+    return prop.major >= 8;
+}
+
 void StandardGpuResourcesImpl::noTempMemory() {
     setTempMemory(0);
 }

@@ -701,6 +708,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
     return res_;
 }
 
+bool StandardGpuResources::supportsBFloat16(int device) {
+    return res_->supportsBFloat16(device);
+}
+
+bool StandardGpuResources::supportsBFloat16CurrentDevice() {
+    return res_->supportsBFloat16CurrentDevice();
+}
+
 void StandardGpuResources::noTempMemory() {
     res_->noTempMemory();
 }

faiss/gpu/StandardGpuResources.h (+9)

@@ -48,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
 
     ~StandardGpuResourcesImpl() override;
 
+    /// Does the given GPU support bfloat16?
+    bool supportsBFloat16(int device) override;
+
     /// Disable allocation of temporary memory; all temporary memory
     /// requests will call cudaMalloc / cudaFree at the point of use
     void noTempMemory();

@@ -199,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
 
     std::shared_ptr<GpuResources> getResources() override;
 
+    /// Whether or not the given device supports native bfloat16 arithmetic
+    bool supportsBFloat16(int device);
+
+    /// Whether or not the current device supports native bfloat16 arithmetic
+    bool supportsBFloat16CurrentDevice();
+
     /// Disable allocation of temporary memory; all temporary memory
     /// requests will call cudaMalloc / cudaFree at the point of use
     void noTempMemory();
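The support check bottoms out at the device's compute capability (`prop.major >= 8`, i.e. Ampere and newer), and `StandardGpuResources` now forwards it so callers can probe before picking a dtype. A small hedged sketch, assuming the new methods are exposed through the Python bindings like the other `StandardGpuResources` methods:

```python
# Hedged sketch: choose the query/vector dtype based on device bf16 support.
import torch
import faiss

res = faiss.StandardGpuResources()

if res.supportsBFloat16CurrentDevice():
    dtype = torch.bfloat16  # native bf16 x bf16 -> fp32 tensor cores (Ampere+)
else:
    dtype = torch.float16   # fall back to fp16 (or float32) on older devices

print("storing vectors as", dtype)
```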

faiss/gpu/impl/Distance.cu (+101)

@@ -504,6 +504,30 @@ void runAllPairwiseL2Distance(
             outDistances);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseL2Distance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances) {
+    runAllPairwiseDistance<__nv_bfloat16>(
+            true,
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            vectorNorms,
+            queries,
+            queriesRowMajor,
+            outDistances);
+}
+#endif // USE_AMD_ROCM
+
 void runAllPairwiseIPDistance(
         GpuResources* res,
         cudaStream_t stream,

@@ -544,6 +568,29 @@ void runAllPairwiseIPDistance(
             outDistances);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseIPDistance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances) {
+    runAllPairwiseDistance<__nv_bfloat16>(
+            false,
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            nullptr,
+            queries,
+            queriesRowMajor,
+            outDistances);
+}
+#endif // USE_AMD_ROCM
+
 void runL2Distance(
         GpuResources* res,
         cudaStream_t stream,

@@ -596,6 +643,35 @@ void runL2Distance(
             ignoreOutDistances);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Distance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices,
+        bool ignoreOutDistances) {
+    runL2Distance<__nv_bfloat16>(
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            vectorNorms,
+            queries,
+            queriesRowMajor,
+            k,
+            outDistances,
+            outIndices,
+            ignoreOutDistances);
+}
+#endif // USE_AMD_ROCM
+
 void runIPDistance(
         GpuResources* res,
         cudaStream_t stream,

@@ -640,5 +716,30 @@ void runIPDistance(
             outIndices);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runIPDistance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices) {
+    runIPDistance<__nv_bfloat16>(
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            queries,
+            queriesRowMajor,
+            k,
+            outDistances,
+            outIndices);
+}
+#endif // USE_AMD_ROCM
+
 } // namespace gpu
 } // namespace faiss
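Both the L2 and inner-product kernels gain bf16 entry points above, so the metric choice is orthogonal to the input dtype. A short sketch via the torch bridge, under the same assumptions as the earlier examples:

```python
# Sketch: metric selection is independent of the bf16 input dtype.
import torch
import faiss
import faiss.contrib.torch_utils

res = faiss.StandardGpuResources()
xb = torch.randn(50_000, 96, device="cuda", dtype=torch.bfloat16)
xq = torch.randn(256, 96, device="cuda", dtype=torch.bfloat16)

D_l2, I_l2 = faiss.knn_gpu(res, xq, xb, 10, metric=faiss.METRIC_L2)
D_ip, I_ip = faiss.knn_gpu(res, xq, xb, 10, metric=faiss.METRIC_INNER_PRODUCT)
```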
