Skip to content

Commit f6bcd02

Browse files
committed
Fix handling of hipBLASLt
1 parent 8e78d45 commit f6bcd02

5 files changed

+39
-15
lines changed

src/BatchedGemm.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class BatchedGemm {
9494
int batch_count,
9595
void const* alpha,
9696
void const* beta,
97-
int device_id = 0);
97+
int device_id,
98+
bool lt);
9899

99100
static void makeDevices(std::string type_a_name,
100101
std::string type_b_name,
@@ -107,6 +108,7 @@ class BatchedGemm {
107108
int batch_count,
108109
void const* alpha,
109110
void const* beta,
111+
bool lt,
110112
bool host_a,
111113
bool host_b,
112114
bool host_c,

src/BatchedGemm.hip.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
/// \param[in] device_id
3939
/// the number of the GPU executing the matrix multiplication
4040
///
41+
/// \param[in] lt
42+
/// If true, indicates the use of hipBLASLt.
43+
///
4144
BatchedGemm*
4245
BatchedGemm::make(std::string type_a_name,
4346
std::string type_b_name,
@@ -50,7 +53,8 @@ BatchedGemm::make(std::string type_a_name,
5053
int batch_count,
5154
void const* alpha,
5255
void const* beta,
53-
int device_id)
56+
int device_id,
57+
bool lt)
5458
{
5559
hipblasOperation_t op_a = stringToOp(op_a_name);
5660
hipblasOperation_t op_b = stringToOp(op_b_name);
@@ -83,7 +87,8 @@ BatchedGemm::make(std::string type_a_name,
8387
batch_count,
8488
alpha, beta,
8589
operations(stringToType(type_c_name), m, n, k),
86-
device_id);
90+
device_id,
91+
lt);
8792
}
8893

8994
//------------------------------------------------------------------------------
@@ -93,6 +98,9 @@ BatchedGemm::make(std::string type_a_name,
9398
/// \remark
9499
/// Not listing parameters common with BatchedGemm::make().
95100
///
101+
/// \param[in] lt
102+
/// If true, indicates the use of hipBLASLt.
103+
///
96104
/// \param[in] host_a, host_b, host_c
97105
/// If true, indicates that the array is stored in host memory.
98106
///
@@ -117,6 +125,7 @@ BatchedGemm::makeDevices(std::string type_a_name,
117125
int batch_count,
118126
void const* alpha,
119127
void const* beta,
128+
bool lt,
120129
bool host_a,
121130
bool host_b,
122131
bool host_c,
@@ -179,6 +188,7 @@ BatchedGemm::makeDevices(std::string type_a_name,
179188
batch_count,
180189
alpha, beta,
181190
operations(stringToType(type_c_name), m, n, k),
182-
device_id);
191+
device_id,
192+
lt);
183193
}
184194
}

src/DeviceBatchedGemm.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ class DeviceBatchedGemm: public BatchedGemm {
5050
void const* alpha,
5151
void const* beta,
5252
double operations,
53-
int device_id = 0);
53+
int device_id,
54+
bool lt);
5455
~DeviceBatchedGemm();
5556

5657
/// Populates the batch with random data.
@@ -88,6 +89,7 @@ class DeviceBatchedGemm: public BatchedGemm {
8889
void runBatchedGemmLt();
8990

9091
int device_id_; ///< the number of the device executing the operation
92+
bool lt_; ///< true if using hipBLASLt
9193

9294
hipStream_t hip_stream_; ///< stream
9395
hipblasHandle_t hipblas_handle_; ///< hipBLAS handle

src/DeviceBatchedGemm.hip.cpp

+19-10
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
/// \param[in] device_id
3838
/// the number of the device executing the operation
3939
///
40+
/// \param[in] lt
41+
/// true if using hipBLASLt
42+
///
4043
DeviceBatchedGemm::DeviceBatchedGemm(TypeConstant compute_type,
4144
hipblasOperation_t op_a,
4245
hipblasOperation_t op_b,
@@ -47,14 +50,16 @@ DeviceBatchedGemm::DeviceBatchedGemm(TypeConstant compute_type,
4750
void const* alpha,
4851
void const* beta,
4952
double operations,
50-
int device_id)
53+
int device_id,
54+
bool lt)
5155
: BatchedGemm(compute_type,
5256
op_a, op_b,
5357
a, b, c,
5458
batch_count,
5559
alpha, beta,
5660
operations),
57-
device_id_(device_id)
61+
device_id_(device_id),
62+
lt_(lt)
5863
{
5964
// Set device, create stream.
6065
HIP_CALL(hipSetDevice(device_id_));
@@ -64,12 +69,14 @@ DeviceBatchedGemm::DeviceBatchedGemm(TypeConstant compute_type,
6469
HIPBLAS_CALL(hipblasCreate(&hipblas_handle_));
6570
HIPBLAS_CALL(hipblasSetStream(hipblas_handle_, hip_stream_));
6671

67-
// Create hipBLASLt handle and matmul descriptor.
68-
HIPBLASLT_CALL(hipblasLtCreate(&hipblaslt_handle_));
69-
HIPBLASLT_CALL(hipblasLtMatmulDescCreate(
70-
&hipblaslt_matmul_desc_,
71-
compute_type.compute_,
72-
c->type().hip_));
72+
if (lt) {
73+
// Create hipBLASLt handle and matmul descriptor.
74+
HIPBLASLT_CALL(hipblasLtCreate(&hipblaslt_handle_));
75+
HIPBLASLT_CALL(hipblasLtMatmulDescCreate(
76+
&hipblaslt_matmul_desc_,
77+
compute_type.compute_,
78+
c->type().hip_));
79+
}
7380

7481
// Create hipRAND generator, assign stream.
7582
HIPRAND_CALL(hiprandCreateGenerator(&hiprand_generator_,
@@ -92,9 +99,11 @@ DeviceBatchedGemm::~DeviceBatchedGemm()
9299
(void)hipSetDevice(device_id_);
93100

94101
// Destroy all the handles.
102+
if (lt_) {
103+
(void)hipblasLtMatmulDescDestroy(hipblaslt_matmul_desc_);
104+
(void)hipblasLtDestroy(hipblaslt_handle_);
105+
}
95106
(void)hiprandDestroyGenerator(hiprand_generator_);
96-
(void)hipblasLtMatmulDescDestroy(hipblaslt_matmul_desc_);
97-
(void)hipblasLtDestroy(hipblaslt_handle_);
98107
(void)hipblasDestroy(hipblas_handle_);
99108
(void)hipStreamDestroy(hip_stream_);
100109

src/gemm.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ void run(int argc, char** argv)
169169
std::atoi(argv[12]), // ldc
170170
std::atoi(argv[13]), // batch count
171171
alpha, beta,
172+
lt, // hipBLASLt?
172173
host_a, host_b, host_c,
173174
coherent_a, coherent_b, coherent_c,
174175
shared_a, shared_b,

0 commit comments

Comments
 (0)