Skip to content

Commit 729e929

Browse files
committed
Properly implement getBitfield and GET_BITFIELD_U32/64 on ROCM.
Also put runtime errors in setBitfield for safeguard.
1 parent a6700a9 commit 729e929

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

faiss/gpu/utils/PtxUtils.cuh

+26-14
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,41 @@
77

88
#pragma once
99

10-
#include <cuda.h>
11-
#include <device_functions.h>
10+
#include <hip/hip_runtime.h>
11+
#include <hip/device_functions.h>
1212

1313
namespace faiss {
1414
namespace gpu {
1515

1616
#ifdef USE_ROCM
1717

18-
#define GET_BITFIELD_U32(OUT, VAL, POS, LEN)
19-
20-
#define GET_BITFIELD_U64(OUT, VAL, POS, LEN)
21-
22-
__device__ __forceinline__ unsigned int getBitfield(
23-
unsigned int val,
24-
int pos,
25-
int len) {
26-
unsigned int ret{0};
27-
return ret;
18+
#define GET_BITFIELD_U32(OUT, VAL, POS, LEN) \
19+
do { \
20+
OUT = getBitfield((uint32_t)VAL, POS, LEN); \
21+
} while (0)
22+
23+
#define GET_BITFIELD_U64(OUT, VAL, POS, LEN) \
24+
do { \
25+
OUT = getBitfield((uint64_t)VAL, POS, LEN); \
26+
} while (0)
27+
28+
// Taken from https://github.com/GPUOpen-ProfessionalCompute-Libraries/MIVisionX/blob/rocm-5.5.0/amd_openvx/openvx/ago/ago_util_opencl.cpp#L1563
29+
__device__ __forceinline__ uint32_t
30+
getBitfield(uint32_t val, int pos, int len) {
31+
if (len == 0)
32+
return 0;
33+
if (pos + len < 32)
34+
return (val << (32 - pos - len)) >> (32 - len);
35+
return val >> pos;
2836
}
2937

3038
__device__ __forceinline__ uint64_t
3139
getBitfield(uint64_t val, int pos, int len) {
32-
uint64_t ret{0};
33-
return ret;
40+
if (len == 0)
41+
return 0;
42+
if (pos + len < 64)
43+
return (val << (64 - pos - len)) >> (64 - len);
44+
return val >> pos;
3445
}
3546

3647
__device__ __forceinline__ unsigned int setBitfield(
@@ -39,6 +50,7 @@ __device__ __forceinline__ unsigned int setBitfield(
3950
int pos,
4051
int len) {
4152
unsigned int ret{0};
53+
printf("Runtime Error of %s: Unimplemented\n", __PRETTY_FUNCTION__);
4254
return ret;
4355
}
4456

0 commit comments

Comments
 (0)