From 0c2c6eea6556b208d1a8711197efc94899e754e1 Mon Sep 17 00:00:00 2001 From: Nan Zheng <80790206+nanz-nv@users.noreply.github.com> Date: Sat, 17 Jul 2021 08:53:59 +0800 Subject: [PATCH] Added more fusion and vectorized kernel for transducer (#1125) * Added support for fused ReLU and dropout into transducer joint * Reorganized code selection path in transducer joint fwd * Added support for fused ReLU+dropout into transducer joint * Vectorize transducer loss backward with fused softmax (#3) * Nanz/transducer loss (#4) * Vectorize transducer loss backward with fused softmax * Added a predicate to avoid potential IMA * Nanz/transducer loss (#5) * Vectorize transducer loss backward with fused softmax * Added a predicate to avoid potentional IMA * Added more predicates to avoid IMAs * Updated documentations for newly added features. * Fixed a error in transducer.py --- .../csrc/transducer/transducer_joint.cpp | 32 +- .../transducer/transducer_joint_kernel.cu | 465 +++++++++++++----- .../csrc/transducer/transducer_loss_kernel.cu | 169 ++++++- .../test/transducer/test_transducer_joint.py | 78 ++- .../test/transducer/test_transducer_loss.py | 28 +- .../contrib/test/transducer/transducer_ref.py | 12 +- apex/contrib/transducer/transducer.py | 60 ++- setup.py | 3 +- 8 files changed, 662 insertions(+), 185 deletions(-) diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp index 0c0029db..351e7cab 100755 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -5,7 +5,7 @@ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) -torch::Tensor transducer_joint_cuda_forward( +std::vector transducer_joint_cuda_forward( torch::Tensor f, torch::Tensor g, torch::Tensor fLen, @@ -14,19 +14,23 @@ torch::Tensor transducer_joint_cuda_forward( int64_t packedBatch, int opt, bool packOutput, + bool relu, + bool dropout, + float dropoutProb, int tileSize); std::vector transducer_joint_cuda_backward( - torch::Tensor grad, + std::vector in, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, int maxGLen, - bool packOutput); + bool packOutput, + float scale); -torch::Tensor transducer_joint_forward( +std::vector transducer_joint_forward( torch::Tensor f, torch::Tensor g, torch::Tensor fLen, @@ -35,6 +39,9 @@ torch::Tensor transducer_joint_forward( int64_t packedBatch, int opt, bool packOutput, + bool relu, + bool dropout, + float dropoutProb, int tileSize) { CHECK_INPUT(f); CHECK_INPUT(g); @@ -51,30 +58,37 @@ torch::Tensor transducer_joint_forward( packedBatch, opt, packOutput, + relu, + dropout, + dropoutProb, tileSize); } std::vector transducer_joint_backward( - torch::Tensor grad, + std::vector in, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, int maxGLen, - bool packOutput) { - CHECK_INPUT(grad); + bool packOutput, + float scale) { + for (auto t : in){ + CHECK_INPUT(t); + } CHECK_INPUT(fLen); CHECK_INPUT(gLen); if (packOutput) CHECK_INPUT(batchOffset); return transducer_joint_cuda_backward( - grad, + in, fLen, gLen, batchOffset, maxFLen, maxGLen, - packOutput); + packOutput, + scale); } diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index 7cb55d55..a264e865 100755 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ 
b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -5,6 +5,10 @@ #include #include #include +#include +#include +#include +#include "philox.h" // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // width should be a power of 2 and should be less than warpSize. @@ -23,6 +27,21 @@ inline int largestPowerOfTwo(int x){ return y >> 1; } +/* +Figure out vectorization type for masks. +Similar to how PyTorch figures out acc_t here: +aten/src/ATen/AccumulateType.h +*/ +template +struct MaskVecType { }; + +template <> struct MaskVecType<1> { using type = uint8_t; }; +template <> struct MaskVecType<2> { using type = uint16_t; }; +template <> struct MaskVecType<4> { using type = uint32_t; }; + +template +using mvec_type = typename MaskVecType::type; + // Helper class to calculate pointer offset that can be shared by different flavors of kernels. // For fwd, batch offset and stride are different for packing and non-packing mode. struct OffsetCalFwd{ @@ -192,23 +211,31 @@ __global__ void transducer_joint_forward( } } -// Tiled version of the joint forward kernel -// Detail of this joint function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. - -// f is a tensor of shape [batch, T, H] -// g is a tensor of shape [batch, U, H] -// the transducer joint does -// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1) -// The resultant tensor is of shape [batch, T, U, H] -// Each thread is working on a tile of the shape of tileF x tileG in the result tensor. -// The input for the tile is first loaded in the register and is reused tileG and tileF times. - -// This joint function can optionally pack the output where the output tensor with a shape of -// [B, T, U, H] is packed into [B_packed, H]. -// Don't-care region (t > fLen) or (u > gLen) is removed. -// To enable packing, the starting offset for each batch need to be specified with batchOffset. -template +/* +Tiled version of the joint forward kernel +Detail of this joint function can be found in: +[1] Sequence Transduction with Recurrent Neural Networks. + +f is a tensor of shape [batch, T, H] +g is a tensor of shape [batch, U, H] +the transducer joint does +sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1) +The resultant tensor is of shape [batch, T, U, H] +Each thread is working on a tile of the shape of tileF x tileG in the result tensor. +The input for the tile is first loaded in the register and is reused tileG and tileF times. + +This joint function can optionally pack the output where the output tensor with a shape of +[B, T, U, H] is packed into [B_packed, H]. +Don't-care region (t > fLen) or (u > gLen) is removed. +To enable packing, the starting offset for each batch need to be specified with batchOffset. + +Optionally this joint function performs ReLU and/or dropout on the joint output, which is +controlled by arguments relu and dropout, respectively. philoxArgs is argument used for generating +pseudorandom number. When at least one of operations in ReLU and dropout is activated, the joint +function is a masked operation, which is controlled by the template argument masked. In this case, +masks are saved to backward. 
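For reference, here is a minimal PyTorch sketch of the computation this tiled kernel fuses (the function name is illustrative, `torch.rand` stands in for the kernel's Philox stream, and packing is omitted):

```python
import torch

def joint_fwd_reference(f, g, relu=True, dropout=True, dropout_prob=0.25):
    # f: [B, T, H], g: [B, U, H] -> out: [B, T, U, H]
    out = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
    keep_prob = 1.0 - dropout_prob                  # the kernel receives p = 1 - dropoutProb
    mask = torch.ones_like(out, dtype=torch.uint8)
    if relu:
        mask &= (out > 0).to(torch.uint8)           # ReLU contributes (out > 0) to the mask
    if dropout:
        rand = torch.rand(out.shape, device=out.device)
        mask &= (rand < keep_prob).to(torch.uint8)
    if relu or dropout:
        scale = 1.0 / keep_prob if dropout else 1.0
        out = out * mask * scale                    # mask is saved and replayed in backward
    return out, mask
```

The kernel applies the same per-element logic inside each tileF x tileG tile, drawing four uniforms at a time from Philox and writing the combined ReLU/dropout mask next to the sum.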
+*/ +template __global__ void transducer_joint_tiled_forward( const scalar_t *f, const scalar_t *g, @@ -220,8 +247,14 @@ __global__ void transducer_joint_tiled_forward( int64_t hiddenSize, int64_t hiddenPerBlock, bool packOutput, - scalar_t *sum) { + bool relu, + bool dropout, + float p, + at::PhiloxCudaState philoxArgs, + scalar_t *sum, + uint8_t *mask) { + static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4"); const int batch = blockIdx.z; const int t = blockIdx.y * tileF; @@ -239,6 +272,17 @@ __global__ void transducer_joint_tiled_forward( scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset; scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset; scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset; + uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset; + + // The following code is only needed for dropout. We try to bypass them as much as possible. + auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) + : std::make_tuple(static_cast(0), static_cast(0)); + uint64_t tid = masked ? (static_cast(blockIdx.z)*gridDim.y*gridDim.x + + blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x + : 0; + Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); + scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0; + bool dropoutMask[U]; if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){ // register buffers for tiled input reuse @@ -256,8 +300,28 @@ __global__ void transducer_joint_tiled_forward( if (t + i < myFLen){ #pragma unroll for (int j = 0; j < tileG; ++j){ - if (u + j < myGLen) - mySum[i*strideF + j*hiddenSize + h] = fBuffer[i] + gBuffer[j]; + int idx = i*tileG + j; + if (masked and dropout and idx % U == 0){ + // For performance, generate 4 random numbers in one shot + // auto rand4 = curand_uniform4(&state); + auto rand4 = uniform4(ph()); + dropoutMask[0] = rand4.x < p; + dropoutMask[1] = rand4.y < p; + dropoutMask[2] = rand4.z < p; + dropoutMask[3] = rand4.w < p; + } + + if (u + j < myGLen){ + scalar_t out = fBuffer[i] + gBuffer[j]; + if (masked){ + // Apply ReLU here when relu is True + bool localMask = relu ? (out>0) : 1; + localMask = dropout ? localMask & dropoutMask[idx%U] : localMask; + out = dropout ? out*localMask*scale : out*localMask; + myMask[i*strideF + j*hiddenSize + h] = static_cast(localMask); + } + mySum[i*strideF + j*hiddenSize + h] = out; + } else if (packOutput == false and u + j < maxGLen) mySum[i*strideF + j*hiddenSize + h] = -1; } @@ -287,15 +351,21 @@ __global__ void transducer_joint_tiled_forward( } } -// Bwd operation (reduction) on one input tensor. Since the operation performed for the two input -// tensors are exactly the same, only one kernel is needed, and the different indexing offsets -// and strides are handled by OffsetCalBwd. +/* +Bwd operation (reduction) on one input tensor. Since the operation performed for the two input +tensors are exactly the same, only one kernel is needed, and the different indexing offsets +and strides are handled by OffsetCalBwd. + +When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a +non-packed form. -// When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a -// non-packed form. -template +When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, +and mask contains the mask information. 
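Before the kernel itself, a hedged PyTorch sketch of the reduction it implements for the unpacked case (names are illustrative; the real kernel also handles packed layouts via batchOffset):

```python
import torch

def joint_bwd_reference(out_grad, mask, dropout=True, dropout_prob=0.25):
    # out_grad, mask: [B, T, U, H]; grads in don't-care regions are assumed to be zero.
    scale = 1.0 / (1.0 - dropout_prob) if dropout else 1.0
    masked_grad = out_grad * mask * scale   # replay the fwd mask with the same 1/keep_prob scale
    f_grad = masked_grad.sum(dim=2)         # reduce over U -> [B, T, H]
    g_grad = masked_grad.sum(dim=1)         # reduce over T -> [B, U, H]
    return f_grad, g_grad
```

In the CUDA version both reductions run in a single launch: blocks with blockIdx.y < maxFLen reduce for f and the remaining blocks reduce for g, with bwdFasterDim selecting the strides.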
+*/ +template __device__ void transducer_joint_single_backward( const scalar_t *grad, + const uint8_t *mask, const int *fLen, const int *gLen, const int64_t *batchOffset, @@ -304,6 +374,7 @@ __device__ void transducer_joint_single_backward( int64_t hiddenSize, bool packOutput, bool bwdFasterDim, // whether bwd on the faster moving dimension (u) + float scale, scalar_t *inGrad, int yBlockOffset=0) { @@ -331,15 +402,20 @@ __device__ void transducer_joint_single_backward( const auto myBatchOffset = offsetCal.getBatchOffset(); const auto strideX = offsetCal.getStrideX(); const auto strideY = offsetCal.getStrideY(); - scalar_t const *myGrad = grad + myBatchOffset + x*strideX + hOffset; + const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; + const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr; // Each warp reduces numYPerWarp "y" first acc_t warpSum = 0; auto numYPerWarp = (myYLen+numWarp-1)/numWarp; + #pragma unroll for (int warpY = 0; warpY < numYPerWarp; ++warpY){ auto y = wid*numYPerWarp + warpY; if (y < myYLen and (hOffset+lid) < hiddenSize) - warpSum += myGrad[y*strideY + lid]; + if (masked) + warpSum += static_cast(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale; + else + warpSum += myGrad[y*strideY + lid]; } // transpose partial sum in SMEM and reduce further using warpReduce @@ -366,13 +442,18 @@ __device__ void transducer_joint_single_backward( } } -// Actual bwd (reduction) kernel get launched. -// Call transducer_joint_single_backward twice on two input tensors. -// The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op -// uses the rest. -template +/* +Actual bwd (reduction) kernel get launched. +Call transducer_joint_single_backward twice on two input tensors. +The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op +uses the rest. +When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, +and mask contains the mask information. +*/ +template __global__ void transducer_joint_combined_backward( const scalar_t *grad, + const uint8_t *mask, const int *fLen, const int *gLen, const int64_t *batchOffset, @@ -380,11 +461,13 @@ __global__ void transducer_joint_combined_backward( int64_t maxGLen, int64_t hiddenSize, bool packOutput, + float scale, scalar_t *fGrad, scalar_t *gGrad) { if (blockIdx.y < maxFLen){ - transducer_joint_single_backward( + transducer_joint_single_backward( grad, + mask, fLen, gLen, batchOffset, @@ -393,11 +476,13 @@ __global__ void transducer_joint_combined_backward( hiddenSize, packOutput, false, + scale, fGrad); } else{ - transducer_joint_single_backward( + transducer_joint_single_backward( grad, + mask, fLen, gLen, batchOffset, @@ -406,19 +491,25 @@ __global__ void transducer_joint_combined_backward( hiddenSize, packOutput, true, + scale, gGrad, maxFLen); } } -// Vectorized version of transducer_joint_single_backward -// Doing exact same operation as transducer_joint_single_backward except the load and store are -// vectorized. -// When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a -// non-packed form. -template +/* +Vectorized version of transducer_joint_single_backward +Doing exact same operation as transducer_joint_single_backward except the load and store are +vectorized. +When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a +non-packed form. 
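The in-block reduction pattern used above can be pictured with this small NumPy emulation (a sketch only; the warp shuffles and the shared-memory transpose are abstracted away):

```python
import numpy as np

def block_reduce_emulation(column, num_warps=4):
    # Each "warp" sums a contiguous chunk of ceil(y_len / num_warps) rows first;
    # the per-warp partial sums are then combined (warpReduce + SMEM transpose in CUDA).
    y_len = len(column)
    y_per_warp = (y_len + num_warps - 1) // num_warps
    partials = [column[w * y_per_warp:(w + 1) * y_per_warp].sum()
                for w in range(num_warps)]
    return float(np.sum(partials))

print(block_reduce_emulation(np.arange(10.0)))  # 45.0, same as a direct sum
```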
+When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, +and mask contains the mask information. +*/ +template __device__ void transducer_joint_single_vec_backward( const scalar_t *grad, + const uint8_t *mask, const int *fLen, const int *gLen, const int64_t *batchOffset, @@ -427,6 +518,7 @@ __device__ void transducer_joint_single_vec_backward( int64_t hiddenSize, bool packOutput, bool bwdFasterDim, + float scale, scalar_t *inGrad, int yBlockOffset=0){ @@ -437,6 +529,9 @@ __device__ void transducer_joint_single_vec_backward( const int lid = threadIdx.x; const int numWarp = blockDim.y; + // Figure out the vectorization type for mask + using mvec_t = mvec_type; + OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim); const auto maxXLen = offsetCal.getMaxXLen(); @@ -448,6 +543,7 @@ __device__ void transducer_joint_single_vec_backward( acc_t warpSum[V]; scalar_t inBuffer[V]; + uint8_t maskBuffer[V]; scalar_t outBuffer[V]; auto myInGradVec = reinterpret_cast(myInGrad); auto outBufferVec = reinterpret_cast(outBuffer); @@ -457,6 +553,8 @@ __device__ void transducer_joint_single_vec_backward( const auto strideX = offsetCal.getStrideX(); const auto strideY = offsetCal.getStrideY(); const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; + const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset + :nullptr; for (int i = 0; i < V; ++i) warpSum[i] = 0; @@ -466,12 +564,22 @@ __device__ void transducer_joint_single_vec_backward( for (int warpY = 0; warpY < numYPerWarp; ++warpY){ auto y = wid*numYPerWarp + warpY; auto myGradVec = reinterpret_cast(myGrad + y*strideY); + auto myMaskVec = masked ? reinterpret_cast(myMask + y*strideY) + : nullptr; auto inBufferVec = reinterpret_cast(inBuffer); + auto maskBufferVec = reinterpret_cast(maskBuffer); if (hOffset + lid*V < hiddenSize and y < myYLen){ *inBufferVec = myGradVec[lid]; // vectorized load - #pragma unroll - for (int i = 0; i < V; ++i){ - warpSum[i] += inBuffer[i]; + if (masked){ + *maskBufferVec = myMaskVec[lid]; + #pragma unroll + for (int i = 0; i < V; ++i) + warpSum[i] += static_cast(inBuffer[i]) * maskBuffer[i] * scale; + } + else{ + #pragma unroll + for (int i = 0; i < V; ++i) + warpSum[i] += inBuffer[i]; } } } @@ -506,13 +614,18 @@ __device__ void transducer_joint_single_vec_backward( } } -// Vecotrized version of transducer_joint_combined_backward -// Call transducer_joint_single_vec_backward twice on two input tensors. -// The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op -// uses the rest. -template +/* +Vecotrized version of transducer_joint_combined_backward +Call transducer_joint_single_vec_backward twice on two input tensors. +The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op +uses the rest. +When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, +and mask contains the mask information. 
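The MaskVecType machinery introduced earlier exists so that V consecutive uint8 mask bytes can be moved with one wider load; the idea in NumPy terms (illustrative only):

```python
import numpy as np

# Eight uint8 mask values ...
mask = np.array([1, 0, 1, 1, 0, 1, 0, 0], dtype=np.uint8)
# ... reinterpreted as 32-bit words: one load moves V = 4 mask bytes at once,
# which is what MaskVecType<4>::type == uint32_t enables in the kernel.
mask_vec = mask.view(np.uint32)
print(mask_vec.shape)  # (2,) -- two vector loads cover all eight mask elements
```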
+*/ +template __global__ void transducer_joint_combined_vec_backward( const scalar_t *grad, + const uint8_t *mask, const int *fLen, const int *gLen, const int64_t *batchOffset, @@ -520,11 +633,13 @@ __global__ void transducer_joint_combined_vec_backward( int64_t maxGLen, int64_t hiddenSize, bool packOutput, + float scale, scalar_t *fGrad, scalar_t *gGrad) { if (blockIdx.y < maxFLen){ - transducer_joint_single_vec_backward( + transducer_joint_single_vec_backward( grad, + mask, fLen, gLen, batchOffset, @@ -533,11 +648,13 @@ __global__ void transducer_joint_combined_vec_backward( hiddenSize, packOutput, false, + scale, fGrad); } else{ - transducer_joint_single_vec_backward( + transducer_joint_single_vec_backward( grad, + mask, fLen, gLen, batchOffset, @@ -546,6 +663,7 @@ __global__ void transducer_joint_combined_vec_backward( hiddenSize, packOutput, true, + scale, gGrad, maxFLen); } @@ -554,7 +672,7 @@ __global__ void transducer_joint_combined_vec_backward( -torch::Tensor transducer_joint_cuda_forward( +std::vector transducer_joint_cuda_forward( torch::Tensor f, torch::Tensor g, torch::Tensor fLen, @@ -563,6 +681,9 @@ torch::Tensor transducer_joint_cuda_forward( int64_t packedBatch, int opt, bool packOutput, + bool relu, + bool dropout, + float dropoutProb, int tileSize){ @@ -572,17 +693,24 @@ torch::Tensor transducer_joint_cuda_forward( const auto maxFLen = f.size(1); const auto maxGLen = g.size(1); const auto hiddenSize = f.size(2); + bool masked = dropout or relu; int64_t *batchOffsetPtr = nullptr; - torch::Tensor sum; + torch::Tensor sum, mask; + auto maskOpt = tensorOpt.dtype(torch::kUInt8); if (!packOutput){ sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); batchOffsetPtr = nullptr; + if (masked) + mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); } else{ sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); batchOffsetPtr = batchOffset.data_ptr(); + if (masked) + mask = torch::empty({packedBatch, hiddenSize}, maskOpt); } + uint8_t *maskPtr = masked ? mask.data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -590,12 +718,13 @@ torch::Tensor transducer_joint_cuda_forward( // Simple heuristics const int numThread = std::min(128, (static_cast(hiddenSize)+C10_WARP_SIZE-1) / C10_WARP_SIZE * C10_WARP_SIZE); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { - if (opt == 0){ - // vanilla kernel - const int threads = numThread; - const dim3 blocks(maxGLen, maxFLen, batchSize); + + if (opt == 0){ + // vanilla kernel + const int threads = numThread; + const dim3 blocks(maxGLen, maxFLen, batchSize); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { transducer_joint_forward <<>>( f.data_ptr(), @@ -608,54 +737,111 @@ torch::Tensor transducer_joint_cuda_forward( hiddenSize, packOutput, sum.data_ptr()); - } - if (opt == 1){ - // tiled version. For simplicity, assume tileF == tileG, even though the kernel can - // support more general cases. - const int threads = numThread; - const int hiddenPerBlock = numThread; - const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; - const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, - (maxFLen+tileSize-1)/tileSize, - batchSize); - - TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, + })); + } + if (opt == 1){ + // tiled version. For simplicity, assume tileF == tileG, even though the kernel can + // support more general cases. 
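As a reminder of the two output layouts the host code above allocates for both the sum and the mask, here is a small sketch (names and dtypes are illustrative):

```python
import torch

def allocate_joint_output(f_len, g_len, max_f, max_g, hidden, pack_output, masked):
    # Mirrors the host-side allocation: dense [B, maxF, maxG, H] vs packed [sum(fLen*gLen), H].
    if not pack_output:
        shape = (f_len.numel(), max_f, max_g, hidden)
    else:
        batch_offset = torch.cumsum(f_len * g_len, dim=0)  # same offsets the kernels consume
        shape = (int(batch_offset[-1]), hidden)
    out = torch.empty(shape, dtype=torch.float16, device=f_len.device)
    mask = torch.empty(shape, dtype=torch.uint8, device=f_len.device) if masked else None
    return out, mask
```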
+ const int threads = numThread; + const int hiddenPerBlock = numThread; + const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; + const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, + (maxFLen+tileSize-1)/tileSize, + batchSize); + + TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, "Expected tileSize to be in [1, 2, 4], but got ", tileSize); - switch (tileSize) { - #define LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(tile) case tile:\ - transducer_joint_tiled_forward\ - <<>>(\ - f.data_ptr(),\ - g.data_ptr(),\ - fLen.data_ptr(),\ - gLen.data_ptr(),\ - batchOffsetPtr,\ - maxFLen,\ - maxGLen,\ - hiddenSize,\ - hiddenPerBlock,\ - packOutput,\ - sum.data_ptr());\ - break; - LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(1); - LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(2); - LAUNCH_TRANSDUCER_JOINT_TILED_FORWARD(4); - } + at::PhiloxCudaState rng_engine_inputs; + if (masked){ + // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler + // for non-masked calls. + // Therefore no need to initialize. + c10::optional gen_; + auto gen = at::get_generator_or_default(gen_, + at::cuda::detail::getDefaultCUDAGenerator()); + // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, + // each thread processes tileF * tileG output elements. + int64_t counterOffset = tileSize * tileSize; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counterOffset); + } } - })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { + void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, + int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, + at::PhiloxCudaState, scalar_t*, uint8_t*); + if (masked){ + switch (tileSize){ + case 2: + kernel = &transducer_joint_tiled_forward; + break; + case 4: + kernel = &transducer_joint_tiled_forward; + break; + } + } + else{ + switch (tileSize){ + case 1: + kernel = &transducer_joint_tiled_forward; + break; + case 2: + kernel = &transducer_joint_tiled_forward; + break; + case 4: + kernel = &transducer_joint_tiled_forward; + break; + } + } + + kernel<<>>( + f.data_ptr(), + g.data_ptr(), + fLen.data_ptr(), + gLen.data_ptr(), + batchOffsetPtr, + maxFLen, + maxGLen, + hiddenSize, + hiddenPerBlock, + packOutput, + relu, + dropout, + 1.0f - dropoutProb, + rng_engine_inputs, + sum.data_ptr(), + maskPtr); + })); + } + THCudaCheck(cudaGetLastError()); - return sum; + if (masked) + return {sum, mask}; + else + return {sum}; } std::vector transducer_joint_cuda_backward( - torch::Tensor grad, + std::vector in, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, int maxGLen, - bool packOutput){ + bool packOutput, + float scale){ + + auto grad = in[0]; + bool masked = (in.size() == 2); + uint8_t *maskPtr = masked ? 
in[1].data_ptr() : nullptr; auto tensorOpt = grad.options(); auto dtype = grad.scalar_type(); @@ -709,35 +895,76 @@ std::vector transducer_joint_cuda_backward( const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), maxFLen+maxGLen, batchSize); - transducer_joint_combined_vec_backward - - <<>>( - gradPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - fGradPtr, - gGradPtr); + if (masked){ + transducer_joint_combined_vec_backward + + <<>>( + gradPtr, + maskPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, + maxFLen, + maxGLen, + hiddenSize, + packOutput, + scale, + fGradPtr, + gGradPtr); + } + else{ + transducer_joint_combined_vec_backward + + <<>>( + gradPtr, + maskPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, + maxFLen, + maxGLen, + hiddenSize, + packOutput, + scale, + fGradPtr, + gGradPtr); + } } else{ const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, maxFLen + maxGLen, batchSize); - transducer_joint_combined_backward - <<>>( - gradPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - fGradPtr, - gGradPtr); + if (masked){ + transducer_joint_combined_backward + <<>>( + gradPtr, + maskPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, + maxFLen, + maxGLen, + hiddenSize, + packOutput, + scale, + fGradPtr, + gGradPtr); + } + else{ + transducer_joint_combined_backward + <<>>( + gradPtr, + maskPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, + maxFLen, + maxGLen, + hiddenSize, + packOutput, + scale, + fGradPtr, + gGradPtr); + } } })); diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu index 694b2a33..1ebbd3ae 100755 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -408,7 +408,7 @@ __global__ void transducer_loss_fused_backward( : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen; - __shared__ acc_t commonFactor, myBetaTU; + __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; if (t < myFLen and u < myGLen){ @@ -421,6 +421,9 @@ __global__ void transducer_loss_fused_backward( if (tid == 0){ commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; myBetaTU = myBeta[t*maxGLen + u]; + myBetaTUp1 = myBeta[t*maxGLen + u + 1]; + myBetaTp1U = myBeta[(t+1)*maxGLen + u]; + myLabelShared = myLabel[u]; } __syncthreads(); @@ -429,14 +432,14 @@ __global__ void transducer_loss_fused_backward( // Do the update acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x)) acc_t myGrad = std::exp(grad + myBetaTU); - if (u != myGLen - 1 and h == myLabel[u]){ - myGrad -= std::exp(grad + myBeta[t*maxGLen + u + 1]); + if (u != myGLen - 1 and h == myLabelShared){ + myGrad -= std::exp(grad + myBetaTUp1); } else if (h == blankIdx){ if (t == myFLen - 1 and u == myGLen - 1) myGrad -= std::exp(grad); else if (t != myFLen - 1) - myGrad -= std::exp(grad + myBeta[(t+1)*maxGLen + u]); + myGrad -= std::exp(grad + myBetaTp1U); } myXGrad[h] = myGrad; } @@ -450,6 +453,104 @@ __global__ void transducer_loss_fused_backward( } +// Vectorized version of fused transudcer loss backward operation. +// Detail of this loss function can be found in: +// [1] Sequence Transduction with Recurrent Neural Networks. +// The bwd op of the preceding softmax layer is fused in this kernel. +// Each thread block works on [batch, t, u, :] of data. 
Each thread works on a specific h at a time + +// To support the packed input, the starting offsets for each batch need to be specified with +// batchOffset. +template +__global__ void transducer_loss_fused_vec_backward( + const scalar_t* x, + const scalar_t* lossGrad, + const int* audLen, + const int* txtLen, + const int* label, + const acc_t* alpha, + const acc_t* beta, + const int64_t* batchOffset, + int64_t dictSize, + int64_t blankIdx, + int64_t maxFLen, + int64_t maxGLen, + bool packedInput, + scalar_t* xGrad) { + + const int tid = threadIdx.x; + const int u = blockIdx.x; + const int t = blockIdx.y; + const int batch = blockIdx.z; + const int64_t myFLen = audLen[batch]; + const int64_t myGLen = txtLen[batch] + 1; + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) + : batch * maxFLen * maxGLen; + const int64_t myStrideT = packedInput ? myGLen : maxGLen; + + __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; + auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; + auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; + auto myAlpha = alpha + batch*maxFLen*maxGLen; + auto myBeta = beta + batch*maxFLen*maxGLen; + auto myLabel = label + batch*(maxGLen-1); + + // Variabels for vectorization + scalar_t myXBuffer[V], myXGradBuffer[V]; + auto myXVec = reinterpret_cast(myX); + auto myXGradVec = reinterpret_cast(myXGrad); + auto myXBufferVec = reinterpret_cast(myXBuffer); + auto myXGradBufferVec = reinterpret_cast(myXGradBuffer); + if (t < myFLen and u < myGLen){ + // load and store shared variables in SMEM + if (tid == 0){ + commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; + myBetaTU = myBeta[t*maxGLen + u]; + if (t != myFLen - 1) + myBetaTp1U = myBeta[(t+1)*maxGLen + u]; + if (u != myGLen - 1){ + myBetaTUp1 = myBeta[t*maxGLen + u + 1]; + myLabelShared = myLabel[u]; + } + } + + __syncthreads(); + + #pragma unroll + for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){ + // Load myX in a vector form + *myXBufferVec = myXVec[h0/V]; + // Do the update for a vector of input + #pragma unroll + for (int i = 0; i < V; ++i){ + auto h = h0 + i; + acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x)) + acc_t myGrad = std::exp(grad + myBetaTU); + if (u != myGLen - 1 and h == myLabelShared){ + myGrad -= std::exp(grad + myBetaTUp1); + } + else if (h == blankIdx){ + if (t == myFLen - 1 and u == myGLen - 1) + myGrad -= std::exp(grad); + else if (t != myFLen - 1) + myGrad -= std::exp(grad + myBetaTp1U); + } + myXGradBuffer[i] = myGrad; + } + + // Store myXGrad in a vector form + myXGradVec[h0/V] = *myXGradBufferVec; + + } + } + else if (!packedInput){ + // In non-pack mode, need to make sure the gradients for don't-care regions are zero. 
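The per-element update above is the usual RNN-T gradient with the softmax backward folded in; written out in PyTorch for a single (b, t, u) position (a reference sketch, not part of the patch; loss_grad is the incoming scalar dL/dloss for this batch element):

```python
import torch

def fused_loss_grad_reference(logp, alpha, beta, label, loss_grad, t, u, f_len, g_len, blank_idx):
    # logp: [H] log-probs at (b, t, u); alpha, beta: [T, U] from the forward pass.
    # Returns d(loss)/d(logits) at (b, t, u), i.e. the softmax backward is fused in.
    common = torch.log(loss_grad) + alpha[t, u] - beta[0, 0]
    grad = torch.empty_like(logp)
    for h in range(logp.numel()):
        base = common + logp[h]
        g = torch.exp(base + beta[t, u])
        if u != g_len - 1 and h == label[u]:
            g -= torch.exp(base + beta[t, u + 1])
        elif h == blank_idx:
            if t == f_len - 1 and u == g_len - 1:
                g -= torch.exp(base)
            elif t != f_len - 1:
                g -= torch.exp(base + beta[t + 1, u])
        grad[h] = g
    return grad
```

The vectorized kernel evaluates exactly this expression for V adjacent values of h per iteration, reusing the beta and label scalars cached in shared memory.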
+ for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){ + myXGradVec[h0/V] = 0; + } + } +} + std::vector transducer_loss_cuda_forward( torch::Tensor x, @@ -586,23 +687,51 @@ torch::Tensor transducer_loss_cuda_backward( const dim3 blocks(maxGLen, maxFLen, batchSize); AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { + using vec_t = uint64_t; using acc_t = at::acc_type; - transducer_loss_fused_backward<<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - + constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t); + constexpr int vecAlignment = std::alignment_of::value; + // if all input and output tensors meet the alignment requirement + bool memAlign = reinterpret_cast(x.data_ptr()) % vecAlignment == 0 + and reinterpret_cast(xGrad.data_ptr()) + % vecAlignment == 0; + + if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){ + transducer_loss_fused_vec_backward + <<>>( + x.data_ptr(), + lossGrad.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), + label.data_ptr(), + alpha.data_ptr(), + beta.data_ptr(), + batchOffsetPtr, + dictSize, + blankIdx, + maxFLen, + maxGLen, + packedInput, + xGrad.data_ptr()); + } + else{ + transducer_loss_fused_backward<<>>( + x.data_ptr(), + lossGrad.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), + label.data_ptr(), + alpha.data_ptr(), + beta.data_ptr(), + batchOffsetPtr, + dictSize, + blankIdx, + maxFLen, + maxGLen, + packedInput, + xGrad.data_ptr()); + + } })); } else{ diff --git a/apex/contrib/test/transducer/test_transducer_joint.py b/apex/contrib/test/transducer/test_transducer_joint.py index 619a6f33..c1c8dd1e 100755 --- a/apex/contrib/test/transducer/test_transducer_joint.py +++ b/apex/contrib/test/transducer/test_transducer_joint.py @@ -28,6 +28,7 @@ def gen_input(self, for_vector_kernel): self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device) self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max + self.dropout_prob = 0.5 # Make sure gradients from out-of-bound locations are zero. 
This should be guaranteed by # the loss function @@ -49,30 +50,38 @@ def _pack(self, x, f_len, g_len): batch_offset = torch.cumsum(f_len * g_len, dim=0) return x_packed + def _unpack(self, x, f_len, g_len): + batch_offset = torch.cumsum(f_len * g_len, dim=0) + x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8) + B = self.h_grad.size(0) + H = self.h_grad.size(-1) + for b in range(B): + my_batch_offset = 0 if b == 0 else batch_offset[b-1] + my_f_len = f_len[b] + my_g_len = g_len[b] + for t in range(my_f_len): + x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : + my_batch_offset + t*my_g_len + my_g_len] + return x_unpacked - def run_transducer_joint(self, for_vector_kernel, pack_output): + def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): self.gen_input(for_vector_kernel=for_vector_kernel) # Generate reference f_ref = self.f_tst.data.clone() g_ref = self.g_tst.data.clone() f_ref.requires_grad = True g_ref.requires_grad = True - - h_ref, f_grad_ref, g_grad_ref \ - = transducer_ref.transducer_joint_reference(f=f_ref, - g=g_ref, - h_grad=self.h_grad, - f_len=self.f_len, - g_len=self.g_len, - pack_output=pack_output) - my_joint= TransducerJoint(pack_output=pack_output) + my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, + dropout_prob=self.dropout_prob, probe_mask=True) if not pack_output: h_tst = my_joint( f=self.f_tst, g=self.g_tst, f_len=self.f_len, g_len=self.g_len) h_tst.backward(self.h_grad) + if dropout: + mask = my_joint.mask_probe[0] else: batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0) h_tst = my_joint( f=self.f_tst, @@ -82,6 +91,22 @@ def run_transducer_joint(self, for_vector_kernel, pack_output): batch_offset=batch_offset, packed_batch=batch_offset[-1]) h_tst.backward(self.h_grad_packed) + if dropout: + mask_packed = my_joint.mask_probe[0] + mask = self._unpack(mask_packed, self.f_len, self.g_len) + + # reference + h_ref, f_grad_ref, g_grad_ref \ + = transducer_ref.transducer_joint_reference(f=f_ref, + g=g_ref, + h_grad=self.h_grad, + f_len=self.f_len, + g_len=self.g_len, + pack_output=pack_output, + relu=relu, + dropout=dropout, + dropout_prob=self.dropout_prob, + mask=mask if dropout else None) f_grad_tst = self.f_tst.grad g_grad_tst = self.g_tst.grad @@ -91,16 +116,41 @@ def run_transducer_joint(self, for_vector_kernel, pack_output): self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)) def test_transducer_joint(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=False) + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) def test_transducer_joint_vec(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False) + self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False) def test_transducer_joint_pack(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=True) + self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False) def test_transducer_joint_vec_pack(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True) + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) + + def test_transducer_joint_relu(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) + + def test_transducer_joint_vec_relu(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, 
dropout=False) + + def test_transducer_joint_pack_relu(self): + self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False) + + def test_transducer_joint_vec_pack_relu(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) + + def test_transducer_joint_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) + + def test_transducer_joint_vec_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True) + + def test_transducer_joint_pack_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True) + + def test_transducer_joint_vec_pack_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) + if __name__ == '__main__': diff --git a/apex/contrib/test/transducer/test_transducer_loss.py b/apex/contrib/test/transducer/test_transducer_loss.py index 157bcca9..82f5bd33 100755 --- a/apex/contrib/test/transducer/test_transducer_loss.py +++ b/apex/contrib/test/transducer/test_transducer_loss.py @@ -8,13 +8,13 @@ def setUp(self, seed=1234): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - def gen_input(self, scalar_t): + def gen_input(self, scalar_t, for_vector_kernel): self.B = 5 T_min = 23 T_max = 51 U_min = 12 U_max = 25 - V = 16 + V = 16 if for_vector_kernel else 14 self.blank_idx = V - 1 device = "cuda" @@ -61,8 +61,8 @@ def _unpack(self, x): x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u] return x_unpacked - def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input): - self.gen_input(scalar_t) + def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel): + self.gen_input(scalar_t, for_vector_kernel) my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, packed_input=packed_input) if not packed_input: @@ -90,28 +90,40 @@ def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input): def test_transducer_loss_fp32(self): loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32, fuse_softmax_backward=False, - packed_input=False) + packed_input=False, + for_vector_kernel=False) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5)) def test_transducer_loss_fp16(self): loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, fuse_softmax_backward=False, - packed_input=False) + packed_input=False, + for_vector_kernel=False) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) def test_transducer_loss_fp16_backward_fusion(self): loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, fuse_softmax_backward=True, - packed_input=False) + packed_input=False, + for_vector_kernel=False) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) def test_transducer_loss_fp16_backward_fusion_packed(self): loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, fuse_softmax_backward=True, - packed_input=True) + packed_input=True, + for_vector_kernel=False) + self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) + 
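The for_vector_kernel switch above exists because the vectorized fused backward is only eligible when the vocabulary size divides the vector width; a sketch of the host-side check (names illustrative), which is why the test uses V=16 for the vectorized path and V=14 otherwise:

```python
import torch

def loss_bwd_can_vectorize(dict_size, dtype=torch.float16, vec_bytes=8):
    # Mirrors the host-side dispatch: a uint64_t load moves vec_bytes / element_size
    # values, so dictSize must be a multiple of that factor (pointer alignment is
    # also required, but not checked here).
    vect_factor = vec_bytes // torch.tensor([], dtype=dtype).element_size()
    return vect_factor > 1 and dict_size % vect_factor == 0

print(loss_bwd_can_vectorize(16))  # True  -> vectorized fused backward
print(loss_bwd_can_vectorize(14))  # False -> falls back to the scalar kernel
```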
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) + + def test_transducer_loss_fp16_backward_fusion_packed_vec(self): + loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, + fuse_softmax_backward=True, + packed_input=True, + for_vector_kernel=True) self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) diff --git a/apex/contrib/test/transducer/transducer_ref.py b/apex/contrib/test/transducer/transducer_ref.py index 58232e2b..de342798 100755 --- a/apex/contrib/test/transducer/transducer_ref.py +++ b/apex/contrib/test/transducer/transducer_ref.py @@ -76,12 +76,21 @@ def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx): return alpha, beta, x.grad, loss -def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output): +def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout, + dropout_prob=0, mask=None): + if dropout and mask == None: + raise NotImplementedError("mask needs to supplied to test dropout.") B, T, H = f.size() U = g.size(1) f_expand = f.unsqueeze(dim=2) g_expand = g.unsqueeze(dim=1) h = f_expand + g_expand + if relu: + h = torch.nn.functional.relu(h) + if dropout: + h *= mask + scale = 1/(1-dropout_prob) + h *= scale h.backward(h_grad) if pack_output == False: @@ -90,6 +99,7 @@ def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output): for b in range(B): h[b, f_len[b]:] = -1 h[b, :, g_len[b]:] = -1 + return h, f.grad, g.grad # packing diff --git a/apex/contrib/transducer/transducer.py b/apex/contrib/transducer/transducer.py index 42990f96..78439627 100755 --- a/apex/contrib/transducer/transducer.py +++ b/apex/contrib/transducer/transducer.py @@ -10,18 +10,34 @@ class TransducerJoint(torch.nn.Module): Arguments: pack_output (bool, optional): whether to pack the output in a compact form with don't-care data being removed. (default: False) + relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1 + (default: False) + dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1 + (default: False) opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. (default: 1) fwd_tile_size (int, optional): tile size used in forward operation. This argument will be ignored if opt != 1. (default: 4) + dropout_prob (float, optional): dropout probability. (default: 0.0) + probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout + operation. When this argument is set to True, the mask can be accessed through + self.mask_probe. 
(default: false) """ - def __init__(self, pack_output=False, opt=1, fwd_tile_size=4): + def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, + dropout_prob=0, probe_mask=False): super(TransducerJoint, self).__init__() self.pack_output = pack_output + self.relu = relu + self.dropout = dropout + self.dropout_prob = dropout_prob self.opt = opt self.fwd_tile_size = fwd_tile_size self.dummy_batch_offset = torch.empty(0) + masked = self.relu or self.dropout + self.mask_probe = [] if masked and probe_mask else None + if masked and opt != 1: + raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1") def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0): @@ -43,8 +59,10 @@ def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0): my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset if self.pack_output and (batch_offset is None or packed_batch == 0): raise Exception("Please specify batch_offset and packed_batch when packing is enabled") - return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, my_batch_offset, - packed_batch, self.opt, self.fwd_tile_size) + dropout = self.dropout and self.training # only dropout for training + return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, + my_batch_offset, packed_batch, self.opt, + self.fwd_tile_size, self.dropout_prob, self.mask_probe) class TransducerLoss(torch.nn.Module): @@ -139,23 +157,39 @@ def backward(ctx, loss_grad): class TransducerJointFunc(torch.autograd.Function): @staticmethod - def forward(ctx, f, g, f_len, g_len, pack_output, batch_offset, packed_batch, opt, - fwd_tile_size): + def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, + opt, fwd_tile_size, dropout_prob, mask_probe): h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, - pack_output, fwd_tile_size) - ctx.save_for_backward(f_len, g_len, batch_offset) + pack_output, relu, dropout, dropout_prob, fwd_tile_size) + masked = relu or dropout + if masked: + ctx.save_for_backward(h[1], f_len, g_len, batch_offset) + if mask_probe is not None: + mask_probe.append(h[1]) + else: + ctx.save_for_backward(f_len, g_len, batch_offset) + ctx.pack_output = pack_output + ctx.masked = relu or dropout ctx.max_f_len = f.size(1) ctx.max_g_len = g.size(1) - return h + ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1 + return h[0] @staticmethod def backward(ctx, loss_grad): - f_len, g_len, batch_offset = ctx.saved_tensors - f_grad, g_grad = transducer_joint_cuda.backward(loss_grad, f_len, g_len, batch_offset, - ctx.max_f_len, ctx.max_g_len, - ctx.pack_output) + if ctx.masked: + mask, f_len, g_len, batch_offset = ctx.saved_tensors + inp = [loss_grad, mask] + else: + f_len, g_len, batch_offset = ctx.saved_tensors + inp = [loss_grad] + + f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset, + ctx.max_f_len, ctx.max_g_len, + ctx.pack_output, ctx.scale) - return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None + return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \ + None, None, None diff --git a/setup.py b/setup.py index 0b932cd6..38a8674a 100644 --- a/setup.py +++ b/setup.py @@ -512,7 +512,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): 'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'], include_dirs=[os.path.join(this_dir, 'csrc')], 
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3'] + version_dependent_macros})) + 'nvcc':['-O3', + '-I./apex/contrib/csrc/multihead_attn/'] + version_dependent_macros})) ext_modules.append( CUDAExtension(name='transducer_loss_cuda', sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
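Taken together, the new options added by this patch are exercised like this (a usage sketch based on the updated TransducerJoint API; it assumes a CUDA device, extensions built from this patch, and illustrative shapes):

```python
import torch
from apex.contrib.transducer import TransducerJoint

B, T, U, H = 4, 20, 10, 256
f = torch.randn(B, T, H, dtype=torch.float16, device="cuda", requires_grad=True)
g = torch.randn(B, U, H, dtype=torch.float16, device="cuda", requires_grad=True)
f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
g_len = torch.full((B,), U, dtype=torch.int, device="cuda")

# ReLU + dropout are fused into the joint kernel; the dropout mask generated in the
# forward pass is saved and replayed in backward, so no separate modules are needed.
joint = TransducerJoint(pack_output=False, relu=True, dropout=True, dropout_prob=0.1)
h = joint(f=f, g=g, f_len=f_len, g_len=g_len)   # [B, T, U, H]
h.backward(torch.ones_like(h))                  # f.grad / g.grad account for mask and 1/keep_prob
```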