Batched gemm (#633)
* Use cblas_sgemm_batch when available
* Merge with master, add comments and describe contribution
XapaJIaMnu authored and ugermann committed May 20, 2020
1 parent 3f7b459 commit 9ae1951
Showing 2 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]

### Added
- Use *cblas_sgemm_batch* instead of a loop of *cblas_sgemm* calls on CPU as the batched_gemm implementation
- Supporting relative paths in shortlist and sqlite options
- Training and scoring from STDIN
- Support for reading from TSV files from STDIN and other sources during training
61 changes: 61 additions & 0 deletions src/tensors/cpu/prod.cpp
@@ -134,6 +134,66 @@ void ProdBatched(marian::Tensor C,
auto strideC = n * m;

auto batchC = std::max(batchA, batchB);
#if MKL_FOUND
CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;

if(transA)
transA_forarr = CblasTrans;

if(transB)
transB_forarr = CblasTrans;

/* cblas_sgemm_batch allows us to group all the small GEMMs that are otherwise done in a for loop with sgemm
 * and compute them in only one MKL call. For the API documentation refer to
 * https://software.intel.com/content/www/us/en/develop/documentation/mkl-developer-reference-c/top/blas-and-sparse-blas-routines/blas-like-extensions/cblas-gemm-batch.html
 * The API supports dependencies, where you can specify one "group" of GEMMs to be computed after another
 * (this is controlled by the group_count parameter). In our case the operations are not dependent on one
 * another, so we hardcode a single group. The rest of the arguments (with the exception of group_size) are
 * the same as the ones cblas_sgemm expects, with the difference that each has to be provided as an array
 * pointer (one element per group). Weirdly enough, all of the integer arguments must be provided as the
 * MKL_INT datatype.
 */

static constexpr size_t group_count = 1; // We have one group
const std::vector<CBLAS_TRANSPOSE> transa_arr(group_count, transA_forarr);
const std::vector<CBLAS_TRANSPOSE> transb_arr(group_count, transB_forarr);
const std::vector<MKL_INT> m_arr(group_count, (MKL_INT)m);
const std::vector<MKL_INT> n_arr(group_count, (MKL_INT)n);
const std::vector<MKL_INT> k_arr(group_count, (MKL_INT)k);
const std::vector<float> alpha_arr(group_count, alpha);
const std::vector<float> beta_arr(group_count, beta);
const std::vector<MKL_INT> lda_arr(group_count, (MKL_INT)lda);
const std::vector<MKL_INT> ldb_arr(group_count, (MKL_INT)ldb);
const std::vector<MKL_INT> ldc_arr(group_count, (MKL_INT)ldc);
const std::vector<MKL_INT> group_size(group_count, (MKL_INT)batchC); // Group size specifies the number of GEMM operations per group (which is batchC)

std::vector<const float *> a_array(batchC, nullptr);
std::vector<const float *> b_array(batchC, nullptr);
std::vector<float *> c_array(batchC, nullptr);

// This loop initializes the array pointers in the same way as the for loop
// in the normal sgemm version a few lines below
for(size_t i = 0; i < batchC; ++i) {
a_array[i] = A->data() + (i % batchA) * strideA;
b_array[i] = B->data() + (i % batchB) * strideB;
c_array[i] = C->data() + i * strideC;
}
cblas_sgemm_batch(CblasRowMajor,
                  &transa_arr[0],
                  &transb_arr[0],
                  &m_arr[0],
                  &n_arr[0],
                  &k_arr[0],
                  &alpha_arr[0],
                  &a_array[0],
                  &lda_arr[0],
                  &b_array[0],
                  &ldb_arr[0],
                  &beta_arr[0],
                  &c_array[0],
                  &ldc_arr[0],
                  group_count,
                  &group_size[0]);
#else
for(size_t i = 0; i < batchC; ++i) {
sgemm(transA,
transB,
@@ -149,6 +209,7 @@ void ProdBatched(marian::Tensor C,
C->data() + i * strideC,
(int)ldc);
}
#endif
#else
C; A; B; transA; transB; beta; scalar;
ABORT("You need to compile with MKL in order to use the CPU version");
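To make the equivalence concrete, here is a minimal, self-contained sketch of the same pattern: a single cblas_sgemm_batch call with one group standing in for a loop of cblas_sgemm calls. It is not part of this commit; it assumes MKL is installed and linked, and the matrix sizes, batch count, and fill values are illustrative only.

```cpp
#include <mkl.h>
#include <vector>

int main() {
  const MKL_INT m = 2, n = 2, k = 2;  // illustrative GEMM dimensions
  const size_t batch = 3;             // number of independent GEMMs

  // Contiguous row-major storage for all batch items.
  std::vector<float> A(batch * m * k, 1.0f);
  std::vector<float> B(batch * k * n, 2.0f);
  std::vector<float> C(batch * m * n, 0.0f);

  // One group: every GEMM shares transpose flags, sizes, and scalars,
  // so each per-group argument is an array of length one.
  const CBLAS_TRANSPOSE trans = CblasNoTrans;
  const float alpha = 1.0f, beta = 0.0f;
  const MKL_INT lda = k, ldb = n, ldc = n;
  const MKL_INT group_size = (MKL_INT)batch;  // GEMMs in the one group

  // Per-GEMM pointer arrays, as the API expects.
  std::vector<const float*> a_ptrs(batch);
  std::vector<const float*> b_ptrs(batch);
  std::vector<float*> c_ptrs(batch);
  for(size_t i = 0; i < batch; ++i) {
    a_ptrs[i] = A.data() + i * m * k;
    b_ptrs[i] = B.data() + i * k * n;
    c_ptrs[i] = C.data() + i * m * n;
  }

  // Equivalent to calling cblas_sgemm once per batch item,
  // but dispatched to MKL in a single call.
  cblas_sgemm_batch(CblasRowMajor,
                    &trans, &trans,
                    &m, &n, &k,
                    &alpha,
                    a_ptrs.data(), &lda,
                    b_ptrs.data(), &ldb,
                    &beta,
                    c_ptrs.data(), &ldc,
                    /*group_count=*/1, &group_size);
  return 0;
}
```

The group mechanism only matters when GEMMs with different shapes or scalars are mixed in one call; since ProdBatched issues identically shaped GEMMs, one group with group_size equal to the batch count is enough.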
