37
37
// / \param[in] device_id
38
38
// / the number of the device executing the operation
39
39
// /
40
+ // / \param[in] device_id
41
+ // / true if using hipBLASLt
42
+ // /
40
43
DeviceBatchedGemm::DeviceBatchedGemm (TypeConstant compute_type,
41
44
hipblasOperation_t op_a,
42
45
hipblasOperation_t op_b,
@@ -47,14 +50,16 @@ DeviceBatchedGemm::DeviceBatchedGemm(TypeConstant compute_type,
47
50
void const * alpha,
48
51
void const * beta,
49
52
double operations,
50
- int device_id)
53
+ int device_id,
54
+ bool lt)
51
55
: BatchedGemm(compute_type,
52
56
op_a, op_b,
53
57
a, b, c,
54
58
batch_count,
55
59
alpha, beta,
56
60
operations),
57
- device_id_(device_id)
61
+ device_id_(device_id),
62
+ lt_(lt)
58
63
{
59
64
// Set device, create stream.
60
65
HIP_CALL (hipSetDevice (device_id_));
@@ -64,12 +69,14 @@ DeviceBatchedGemm::DeviceBatchedGemm(TypeConstant compute_type,
64
69
HIPBLAS_CALL (hipblasCreate (&hipblas_handle_));
65
70
HIPBLAS_CALL (hipblasSetStream (hipblas_handle_, hip_stream_));
66
71
67
- // Create hipBLASLt handle and matmul descriptor.
68
- HIPBLASLT_CALL (hipblasLtCreate (&hipblaslt_handle_));
69
- HIPBLASLT_CALL (hipblasLtMatmulDescCreate (
70
- &hipblaslt_matmul_desc_,
71
- compute_type.compute_ ,
72
- c->type ().hip_ ));
72
+ if (lt) {
73
+ // Create hipBLASLt handle and matmul descriptor.
74
+ HIPBLASLT_CALL (hipblasLtCreate (&hipblaslt_handle_));
75
+ HIPBLASLT_CALL (hipblasLtMatmulDescCreate (
76
+ &hipblaslt_matmul_desc_,
77
+ compute_type.compute_ ,
78
+ c->type ().hip_ ));
79
+ }
73
80
74
81
// Create hipRAND generator, assign stream.
75
82
HIPRAND_CALL (hiprandCreateGenerator (&hiprand_generator_,
@@ -92,9 +99,11 @@ DeviceBatchedGemm::~DeviceBatchedGemm()
92
99
(void )hipSetDevice (device_id_);
93
100
94
101
// Destroy all the handles.
102
+ if (lt_) {
103
+ (void )hipblasLtMatmulDescDestroy (hipblaslt_matmul_desc_);
104
+ (void )hipblasLtDestroy (hipblaslt_handle_);
105
+ }
95
106
(void )hiprandDestroyGenerator (hiprand_generator_);
96
- (void )hipblasLtMatmulDescDestroy (hipblaslt_matmul_desc_);
97
- (void )hipblasLtDestroy (hipblaslt_handle_);
98
107
(void )hipblasDestroy (hipblas_handle_);
99
108
(void )hipStreamDestroy (hip_stream_);
100
109
0 commit comments