
Commit 37f52dc

ItsPitt authored and facebook-github-bot committed
ROCm support for bfloat16 (#4039)
Summary: Updated the hipify script to handle the bfloat16 conversion and removed the USE_AMD_ROCM guards that disabled bfloat16 for ROCm.

Pull Request resolved: #4039
Reviewed By: junjieqi
Differential Revision: D66413190
Pulled By: asadoughi
fbshipit-source-id: d1564f87e3c3466ff929dfd639bd544318371148
1 parent d1ae64e commit 37f52dc

10 files changed: +51, -67 lines

faiss/gpu/GpuDistance.cu (-5 lines)

@@ -402,16 +402,11 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
     } else if (args.vectorType == DistanceDataType::F16) {
         bfKnnConvert<half>(prov, args);
     } else if (args.vectorType == DistanceDataType::BF16) {
-        // no bf16 support for AMD
-#ifndef USE_AMD_ROCM
         if (prov->getResources()->supportsBFloat16CurrentDevice()) {
             bfKnnConvert<__nv_bfloat16>(prov, args);
         } else {
             FAISS_THROW_MSG("not compiled with bfloat16 support");
         }
-#else
-        FAISS_THROW_MSG("no AMD bfloat16 support");
-#endif
     } else {
         FAISS_THROW_MSG("unknown vectorType");
     }
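With the #ifndef USE_AMD_ROCM guard removed, the BF16 branch above is now reachable in ROCm builds as well. A minimal caller sketch (illustrative only, not part of this commit; the GpuDistanceParams field names follow the current faiss/gpu/GpuDistance.h, and xb/xq are hypothetical buffers of bf16 data prepared by the caller):

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

// Brute-force k-NN over bfloat16 vectors. Throws if the build or device
// cannot handle bf16; after this commit the same call works under ROCm.
void bf16Knn(
        const void* xb,           // numVectors * dims bf16 values
        const void* xq,           // numQueries * dims bf16 values
        int dims,
        faiss::idx_t numVectors,
        faiss::idx_t numQueries,
        int k,
        float* outDistances,      // numQueries * k
        faiss::idx_t* outIndices) // numQueries * k
{
    faiss::gpu::StandardGpuResources res;

    faiss::gpu::GpuDistanceParams args;
    args.metric = faiss::METRIC_L2;
    args.k = k;
    args.dims = dims;
    args.vectors = xb;
    args.vectorType = faiss::gpu::DistanceDataType::BF16;
    args.numVectors = numVectors;
    args.queries = xq;
    args.queryType = faiss::gpu::DistanceDataType::BF16;
    args.numQueries = numQueries;
    args.outDistances = outDistances;
    args.outIndices = outIndices;

    // Dispatches into the BF16 branch above; throws
    // "not compiled with bfloat16 support" on devices/builds without bf16.
    faiss::gpu::bfKnn(&res, args);
}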

faiss/gpu/hipify.sh (+39, -4 lines)

@@ -3,17 +3,46 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+#
+# Usage: ./gpu/hipify.sh
+#

 function hipify_dir()
 {
     # print dir name
     cd "$1" || exit
     echo "Hipifying $(pwd)"

+    if [ -d ./gpu-tmp ]; then
+        #Clearing out any leftover files and directories
+        echo "Removing old ./gpu-tmp"
+        rm -rf ./gpu-tmp
+    fi
+
+    if [ -d ./gpu ]; then
+        #Making a temp directory to implement pre hipify rules
+        echo "Creating ./gpu-tmp"
+        cp -r ./gpu ./gpu-tmp
+
+        # adjust __nv_bfloat162 before hipify because of inaccurate conversions
+        # adjust __nv_bfloat16 before hipify because of inaccurate conversions
+        for ext in hip cuh h cpp c cu cuh
+        do
+            while IFS= read -r -d '' src
+            do
+                sed -i 's@__nv_bfloat162@__hip_bfloat162@' "$src"
+                sed -i 's@__nv_bfloat16@__hip_bfloat16@' "$src"
+            done < <(find ./gpu-tmp -name "*.$ext" -print0)
+        done
+    else
+        echo "Can't find the gpu/ dir"
+        exit
+    fi
+
     # create all destination directories for hipified files into sibling 'gpu-rocm' directory
     while IFS= read -r -d '' src
     do
-        dst="${src//gpu/gpu-rocm}"
+        dst="${src//gpu-tmp/gpu-rocm}"

         if [ -d $dst ]; then
             #Clearing out any leftover files and directories
@@ -24,17 +53,17 @@ function hipify_dir()
         #Making directories
         echo "Creating $dst"
         mkdir -p "$dst"
-    done < <(find ./gpu -type d -print0)
+    done < <(find ./gpu-tmp -type d -print0)

     # run hipify-perl against all *.cu *.cuh *.h *.cpp files, no renaming
     # run all files in parallel to speed up
     for ext in cu cuh h cpp c
     do
         while IFS= read -r -d '' src
         do
-            dst="${src//\.\/gpu/\.\/gpu-rocm}"
+            dst="${src//\.\/gpu-tmp/\.\/gpu-rocm}"
             hipify-perl -o="$dst.tmp" "$src" &
-        done < <(find ./gpu -name "*.$ext" -print0)
+        done < <(find ./gpu-tmp -name "*.$ext" -print0)
     done
     wait

@@ -45,6 +74,12 @@ function hipify_dir()
         mv "$src" "$dst"
     done < <(find ./gpu-rocm -name "*.cu.tmp" -print0)

+    if [ -d ./gpu-tmp ]; then
+        #Clearing out any leftover files and directories
+        echo "Removing ./gpu-tmp"
+        rm -rf ./gpu-tmp
+    fi
+
     # replace header include statements "<faiss/gpu/" with "<faiss/gpu-rocm"
     # replace thrust::cuda::par with thrust::hip::par
     # adjust header path location for hipblas.h to avoid unnecessary deprecation warnings

faiss/gpu/impl/Distance.cu (-12 lines)

@@ -504,8 +504,6 @@ void runAllPairwiseL2Distance(
             outDistances);
 }

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runAllPairwiseL2Distance(
         GpuResources* res,
         cudaStream_t stream,
@@ -526,7 +524,6 @@ void runAllPairwiseL2Distance(
             queriesRowMajor,
             outDistances);
 }
-#endif // USE_AMD_ROCM

 void runAllPairwiseIPDistance(
         GpuResources* res,
@@ -568,8 +565,6 @@ void runAllPairwiseIPDistance(
             outDistances);
 }

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runAllPairwiseIPDistance(
         GpuResources* res,
         cudaStream_t stream,
@@ -589,7 +584,6 @@ void runAllPairwiseIPDistance(
             queriesRowMajor,
             outDistances);
 }
-#endif // USE_AMD_ROCM

 void runL2Distance(
         GpuResources* res,
@@ -643,8 +637,6 @@ void runL2Distance(
             ignoreOutDistances);
 }

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runL2Distance(
         GpuResources* res,
         cudaStream_t stream,
@@ -670,7 +662,6 @@ void runL2Distance(
             outIndices,
             ignoreOutDistances);
 }
-#endif // USE_AMD_ROCM

 void runIPDistance(
         GpuResources* res,
@@ -716,8 +707,6 @@ void runIPDistance(
             outIndices);
 }

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runIPDistance(
         GpuResources* res,
         cudaStream_t stream,
@@ -739,7 +728,6 @@ void runIPDistance(
             outDistances,
             outIndices);
 }
-#endif // USE_AMD_ROCM

 } // namespace gpu
 } // namespace faiss

faiss/gpu/impl/Distance.cuh (-12 lines)

@@ -41,8 +41,6 @@ void runAllPairwiseL2Distance(
         bool queriesRowMajor,
         Tensor<float, 2, true>& outDistances);

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runAllPairwiseL2Distance(
         GpuResources* res,
         cudaStream_t stream,
@@ -52,7 +50,6 @@ void runAllPairwiseL2Distance(
         Tensor<__nv_bfloat16, 2, true>& queries,
         bool queriesRowMajor,
         Tensor<float, 2, true>& outDistances);
-#endif // USE_AMD_ROCM

 void runAllPairwiseIPDistance(
         GpuResources* res,
@@ -72,8 +69,6 @@ void runAllPairwiseIPDistance(
         bool queriesRowMajor,
         Tensor<float, 2, true>& outDistances);

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runAllPairwiseIPDistance(
         GpuResources* res,
         cudaStream_t stream,
@@ -82,7 +77,6 @@ void runAllPairwiseIPDistance(
         Tensor<__nv_bfloat16, 2, true>& queries,
         bool queriesRowMajor,
         Tensor<float, 2, true>& outDistances);
-#endif // USE_AMD_ROCM

 /// Calculates brute-force L2 distance between `vectors` and
 /// `queries`, returning the k closest results seen
@@ -116,8 +110,6 @@ void runL2Distance(
         Tensor<idx_t, 2, true>& outIndices,
         bool ignoreOutDistances = false);

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runL2Distance(
         GpuResources* resources,
         cudaStream_t stream,
@@ -130,7 +122,6 @@ void runL2Distance(
         Tensor<float, 2, true>& outDistances,
         Tensor<idx_t, 2, true>& outIndices,
         bool ignoreOutDistances = false);
-#endif // USE_AMD_ROCM

 /// Calculates brute-force inner product distance between `vectors`
 /// and `queries`, returning the k closest results seen
@@ -156,8 +147,6 @@ void runIPDistance(
         Tensor<float, 2, true>& outDistances,
         Tensor<idx_t, 2, true>& outIndices);

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runIPDistance(
         GpuResources* resources,
         cudaStream_t stream,
@@ -168,7 +157,6 @@ void runIPDistance(
         int k,
         Tensor<float, 2, true>& outDistances,
         Tensor<idx_t, 2, true>& outIndices);
-#endif // USE_AMD_ROCM

 //
 // General distance implementation, assumes that all arguments are on the

faiss/gpu/impl/L2Norm.cu (-3 lines)

@@ -275,8 +275,6 @@ void runL2Norm(
     runL2Norm<half, half2>(input, inputRowMajor, output, normSquared, stream);
 }

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runL2Norm(
         Tensor<__nv_bfloat16, 2, true>& input,
         bool inputRowMajor,
@@ -286,7 +284,6 @@ void runL2Norm(
     runL2Norm<__nv_bfloat16, __nv_bfloat162>(
             input, inputRowMajor, output, normSquared, stream);
 }
-#endif

 } // namespace gpu
 } // namespace faiss

faiss/gpu/impl/L2Norm.cuh (-3 lines)

@@ -27,15 +27,12 @@ void runL2Norm(
         bool normSquared,
         cudaStream_t stream);

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
 void runL2Norm(
         Tensor<__nv_bfloat16, 2, true>& input,
         bool inputRowMajor,
         Tensor<float, 1, true>& output,
         bool normSquared,
         cudaStream_t stream);
-#endif

 } // namespace gpu
 } // namespace faiss

faiss/gpu/utils/ConversionOperators.cuh (-8 lines)

@@ -38,12 +38,9 @@ struct ConvertTo<float> {
     static inline __device__ float to(half v) {
         return __half2float(v);
     }
-
-#ifndef USE_AMD_ROCM
     static inline __device__ float to(__nv_bfloat16 v) {
         return __bfloat162float(v);
     }
-#endif // !USE_AMD_ROCM
 };

 template <>
@@ -96,9 +93,6 @@ struct ConvertTo<Half4> {
     }
 };

-// no bf16 support for AMD
-#ifndef USE_AMD_ROCM
-
 template <>
 struct ConvertTo<__nv_bfloat16> {
     static inline __device__ __nv_bfloat16 to(float v) {
@@ -112,8 +106,6 @@ struct ConvertTo<__nv_bfloat16> {
     }
 };

-#endif // USE_AMD_ROCM
-
 template <typename From, typename To>
 struct Convert {
     inline __device__ To operator()(From v) const {
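With the guards gone, ConvertTo<float>::to(__nv_bfloat16) and ConvertTo<__nv_bfloat16> compile identically for CUDA and ROCm (the hipify step rewrites the type name to __hip_bfloat16). A small illustrative kernel, not part of faiss, showing the usual pattern of reading bf16 storage through ConvertTo<float> and accumulating in float:

#include <faiss/gpu/utils/ConversionOperators.cuh>

#include <cstdint>

namespace faiss {
namespace gpu {

// Illustrative only: one thread per row; works for T = float, half, or
// __nv_bfloat16 thanks to the ConvertTo<float> overloads above.
template <typename T>
__global__ void squaredNormPerRow(
        const T* __restrict__ x,
        float* __restrict__ out,
        int dim,
        int numRows) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= numRows) {
        return;
    }

    float sum = 0.0f;
    for (int i = 0; i < dim; ++i) {
        float v = ConvertTo<float>::to(x[int64_t(row) * dim + i]);
        sum += v * v;
    }
    out[row] = sum;
}

} // namespace gpu
} // namespace faiss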

faiss/gpu/utils/Float16.cuh (+7, -10 lines)

@@ -12,25 +12,22 @@
 #include <faiss/gpu/utils/DeviceUtils.h>

 // Some compute capabilities have full float16 ALUs.
-#if __CUDA_ARCH__ >= 530 || defined(USE_AMD_ROCM)
+#if __CUDA_ARCH__ >= 530
 #define FAISS_USE_FULL_FLOAT16 1
 #endif // __CUDA_ARCH__ types

 // Some compute capabilities have full bfloat16 ALUs.
-// FIXME: no support in ROCm yet
-#if __CUDA_ARCH__ >= 800 // || defined(USE_AMD_ROCM)
+#if __CUDA_ARCH__ >= 800 || defined(USE_AMD_ROCM)
 #define FAISS_USE_FULL_BFLOAT16 1
 #endif // __CUDA_ARCH__ types

-#include <cuda_fp16.h>
 #if !defined(USE_AMD_ROCM)
 #include <cuda_bf16.h>
-#endif
-// #else
-// FIXME: no support in ROCm yet
-// #include <amd_hip_bf16.h>
-// #include <amd_hip_fp16.h>
-// #endif // !defined(USE_AMD_ROCM)
+#include <cuda_fp16.h>
+#else
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#endif // !defined(USE_AMD_ROCM)

 namespace faiss {
 namespace gpu {
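The include block now pulls in hip/hip_bf16.h and hip/hip_fp16.h on ROCm, and FAISS_USE_FULL_BFLOAT16 is defined there as well, so bf16 device code can take a native-ALU path. A sketch of the guard pattern this enables (illustrative; bf16Add is not a faiss function, and __hadd/__float2bfloat16/__bfloat162float are assumed to come from cuda_bf16.h or its hip_bf16.h counterpart):

#include <faiss/gpu/utils/Float16.cuh>

// Illustrative only: add two bf16 values, using native bf16 arithmetic where
// the target supports it (CUDA sm_80+, or ROCm after this change) and
// falling back to float math elsewhere. Note that hipify rewrites
// __nv_bfloat16 to __hip_bfloat16 for the ROCm build.
static inline __device__ __nv_bfloat16 bf16Add(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
    return __hadd(a, b);
#else
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
#endif
}

This is the same pattern the Math<__nv_bfloat16> specializations in MathOperators.cuh below rely on.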

faiss/gpu/utils/MathOperators.cuh (+1, -5 lines)

@@ -556,8 +556,6 @@ struct Math<Half8> {
     }
 };

-#ifndef USE_AMD_ROCM
-
 template <>
 struct Math<__nv_bfloat16> {
     typedef __nv_bfloat16 ScalarType;
@@ -626,7 +624,6 @@ struct Math<__nv_bfloat16> {
     }

     static inline __device__ __nv_bfloat16 zero() {
-#if CUDA_VERSION >= 9000
+#if CUDA_VERSION >= 9000 || defined(USE_AMD_ROCM)
         return 0.0f;
 #else
         __nv_bfloat16 h;
@@ -789,7 +787,5 @@ struct Math<__nv_bfloat162> {
     }
 };

-#endif // !USE_AMD_ROCM
-
 } // namespace gpu
 } // namespace faiss

faiss/gpu/utils/MatrixMult-inl.cuh (+4, -5 lines)

@@ -32,11 +32,10 @@ struct GetCudaType<half> {
     static constexpr hipblasDatatype_t Type = HIPBLAS_R_16F;
 };

-// FIXME: no AMD support for bf16
-// template <>
-// struct GetCudaType<__nv_bfloat16> {
-//     static constexpr hipblasDatatype_t Type = HIPBLAS_R_16B;
-// };
+template <>
+struct GetCudaType<__hip_bfloat16> {
+    static constexpr hipblasDatatype_t Type = HIPBLAS_R_16B;
+};

 #else
