Skip to content

Commit

Permalink
[Mha] Mask is added for Forward pass (#3254)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vsevolod Golovko authored Sep 25, 2024
1 parent 49a69de commit 92e174c
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 10 deletions.
14 changes: 12 additions & 2 deletions include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -5439,9 +5439,10 @@ typedef enum

miopenTensorArgumentIsScalar = 1U << 31,

miopenTensorMhaMask = miopenTensorArgumentIsScalar | 1,
#ifdef MIOPEN_BETA_API
miopenScalarBatchnormExpAvgFactor = miopenTensorArgumentIsScalar | 1,
miopenScalarBatchnormEpsilon = miopenTensorArgumentIsScalar | 2,
miopenScalarBatchnormExpAvgFactor = miopenTensorArgumentIsScalar | 2,
miopenScalarBatchnormEpsilon = miopenTensorArgumentIsScalar | 3,
#endif
} miopenTensorArgumentId_t;

Expand Down Expand Up @@ -5473,6 +5474,15 @@ MIOPEN_EXPORT miopenStatus_t miopenCreateConvProblem(miopenProblem_t* problem,
* @return miopenStatus_t
*/

/*! @enum miopenMhaMask_t
 * Different masks for Mha.
 * Selects the attention-mask variant applied during a Multi-Head Attention
 * (Mha) forward pass. Passed to the solver as an optional scalar argument
 * under the miopenTensorMhaMask id; defaults to miopenMhaMaskNone when the
 * argument is omitted.
 */
typedef enum
{
miopenMhaMaskNone = 0, /*!< No mask is applied to the attention scores. */
miopenMhaMaskCausal = 1, /*!< Causal mask — presumably restricts attention to preceding
                              positions (lower-triangular); confirm against the kernel
                              implementation. */
} miopenMhaMask_t;

MIOPEN_EXPORT miopenStatus_t miopenCreateMhaProblem(miopenProblem_t* problem,
miopenMhaDescriptor_t operatorDesc,
miopenProblemDirection_t direction);
Expand Down
1 change: 1 addition & 0 deletions src/api/find2_0_commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ inline std::ostream& operator<<(std::ostream& stream, const miopenTensorArgument
case miopenTensorMhaAmaxDV: stream << "miopenTensorMhaAmaxDV"; break;
case miopenTensorMhaAmaxDS: stream << "miopenTensorMhaAmaxDS"; break;
case miopenTensorMhaBias: stream << "miopenTensorMhaBias"; break;
case miopenTensorMhaMask: stream << "miopenTensorMhaMask"; break;
case miopenTensorSoftmaxX: stream << "SoftmaxX"; break;
case miopenTensorSoftmaxY: stream << "SoftmaxY"; break;
case miopenTensorSoftmaxDX: stream << "SoftmaxDX"; break;
Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/graphapi/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ inline std::string_view tensorEnumIdToStr(miopenTensorArgumentId_t id)
ENUM_CASE(miopenTensorMhaAmaxDV)
ENUM_CASE(miopenTensorMhaAmaxDS)
ENUM_CASE(miopenTensorMhaBias)
ENUM_CASE(miopenTensorMhaMask)
default: MIOPEN_THROW(miopenStatusInternalError, "unknown tensor enum id");
}
#undef ENUM_CASE
Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/mha/mha.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ struct MhaDataForward
ConstData_t dropoutOffsetData;

ConstData_t biasData;
miopenMhaMask_t mask;

// output tensors
Data_t oData;
Expand Down
17 changes: 13 additions & 4 deletions src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,19 @@ void Solution::RunImpl(Handle& handle,
get_input_checked(miopenTensorMhaDropoutOffset, "miopenTensorMhaDropoutOffset");

// reading bias buffer as an optional parameter
Data_t biasBuffer = nullptr;
const auto& found = inputs.find(miopenTensorMhaBias);
if(found != inputs.end())
Data_t biasBuffer = nullptr;
const auto& biasIt = inputs.find(miopenTensorMhaBias);
if(biasIt != inputs.end())
{
biasBuffer = found->second.buffer;
biasBuffer = biasIt->second.buffer;
}

// reading a mask as an optional parameter
miopenMhaMask_t mask = miopenMhaMaskNone;
const auto& maskIt = inputs.find(miopenTensorMhaMask);
if(maskIt != inputs.end())
{
mask = *(static_cast<miopenMhaMask_t*>(maskIt->second.buffer));
}

const auto invoke_ctx = [&]() -> AnyInvokeParams {
Expand All @@ -331,6 +339,7 @@ void Solution::RunImpl(Handle& handle,
dropoutSeed.buffer,
dropoutOffset.buffer,
biasBuffer,
mask,
o.buffer,
amaxO.buffer,
amaxS.buffer,
Expand Down
16 changes: 12 additions & 4 deletions test/gtest/mha_find20.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ class MhaFind20Test

const size_t numTensors = tensors.size();

auto arguments = std::make_unique<miopenTensorArgument_t[]>(numTensors);
unsigned int numberOfBuffers = numTensors + 1; // +1 for passing a scalar for mhaMask

auto arguments = std::make_unique<miopenTensorArgument_t[]>(numberOfBuffers);

std::vector<miopenTensorDescriptor_t> descVector(numTensors);

Expand All @@ -228,6 +230,11 @@ class MhaFind20Test
++i;
}

// Passing a scalar is a special case for current Find 2.0 implementation
arguments[i].id = miopenTensorMhaMask;
arguments[i].descriptor = nullptr;
arguments[i].buffer = &mhaMask;

std::vector<miopenTensorArgumentId_t> output_ids;

if(isForward)
Expand Down Expand Up @@ -264,7 +271,7 @@ class MhaFind20Test
std::cerr << "Run a solution." << std::endl;
EXPECT_EQUAL(miopenRunSolution(&handle,
solution,
numTensors,
numberOfBuffers,
arguments.get(),
workspace.ptr(),
workspace.size()),
Expand Down Expand Up @@ -381,7 +388,6 @@ class MhaFind20Test
CreateTensor(miopenTensorMhaO, test_n, test_h, test_s, test_d).InitWithRandom();

CreateTensor(miopenTensorMhaDO, test_n, test_h, test_s, test_d).InitWithRandom();
;

CreateTensor(miopenTensorMhaM, test_n, test_h, test_s, 1).InitWithFloatValue(0.0f);
CreateTensor(miopenTensorMhaZInv, test_n, test_h, test_s, 1).InitWithFloatValue(1.0f);
Expand Down Expand Up @@ -465,6 +471,7 @@ class MhaFind20Test
mhads->gpuBuffer.get(),
mhado->gpuBuffer.get(),
mhabias->gpuBuffer.get(),
mhaMask,
outputResultsMap[miopenTensorMhaO]->gpuBuffer.get(),
outputResultsMap[miopenTensorMhaAmaxO]->gpuBuffer.get(),
outputResultsMap[miopenTensorMhaAmaxS]->gpuBuffer.get(),
Expand Down Expand Up @@ -645,7 +652,8 @@ class MhaFind20Test
const unsigned int test_s = 8;
const unsigned int test_d = 16;

float scale = 1.0f;
float scale = 1.0f;
miopenMhaMask_t mhaMask = miopenMhaMaskNone;
};

TEST(GPU_TestMhaFind20_FP32, MhaForward)
Expand Down

0 comments on commit 92e174c

Please sign in to comment.