Skip to content

Commit

Permalink
Fix multiple softmax issues (#2992)
Browse files Browse the repository at this point in the history
* use datatype for softmax problem descriptor, fix #2966

* use strides for softmax problem descriptor, fix #2813
  • Loading branch information
CAHEK7 authored May 29, 2024
1 parent 88bbde5 commit 6a46e1d
Showing 1 changed file with 30 additions and 33 deletions.
63 changes: 30 additions & 33 deletions src/softmax/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,56 @@
*
*******************************************************************************/

#include <miopen/datatype.hpp>
#include <miopen/softmax/problem_description.hpp>
#include <miopen/names.hpp>

#include <sstream>
#include <string_view>

namespace miopen {

namespace softmax {

NetworkConfig ProblemDescription::MakeNetworkConfig() const
{
std::ostringstream ss;
std::ostringstream ss(isForward ? "sfmfwd-" : "sfmbwd-");

// all the tensors must be the same size and types
// so we can use only one set of values
const auto& desc = isForward ? xdxDesc : yDesc;
const auto [sn, sc, sh, sw] = tien<4>(desc.GetLengths());
ss << "n" << sn << "c" << sc << "h" << sh << "w" << sw;
ss << GetDataType(desc.GetType());
ss << "a" << alpha;
ss << "b" << beta;
ss << "algo" << static_cast<int>(algorithm);
ss << "mode" << static_cast<int>(mode);

ss << "sfmfwd-";
auto printStrides = [&ss](std::string_view name, const miopen::TensorDescriptor& d) {
if(d.IsPacked())
{
ss << name << "pk1";
}
else
{
const auto [n, c, h, w] = tien<4>(d.GetStrides());
ss << name << "pk0strides" << n << "x" << c << "x" << h << "x" << w;
}
};

if(isForward)
{
int n_x, c_x, h_x, w_x;
int n_y, c_y, h_y, w_y;

std::tie(n_x, c_x, h_x, w_x) = tien<4>(xdxDesc.GetLengths());
std::tie(n_y, c_y, h_y, w_y) = tien<4>(yDesc.GetLengths());

ss << "n_x" << n_x << "c_x" << c_x << "h_x" << h_x << "w_x" << w_x;
ss << "n_y" << n_y << "c_y" << c_y << "h_y" << h_y << "w_y" << w_y;

ss << "xpk" << static_cast<int>(xdxDesc.IsPacked());
ss << "ypk" << static_cast<int>(yDesc.IsPacked());
printStrides("x", xdxDesc);
printStrides("y", yDesc);
}
else
{
int n_y, c_y, h_y, w_y;
int n_dy, c_dy, h_dy, w_dy;
int n_dx, c_dx, h_dx, w_dx;

std::tie(n_y, c_y, h_y, w_y) = tien<4>(yDesc.GetLengths());
std::tie(n_dy, c_dy, h_dy, w_dy) = tien<4>(dyDesc.GetLengths());
std::tie(n_dx, c_dx, h_dx, w_dx) = tien<4>(xdxDesc.GetLengths());

ss << "n_y" << n_y << "c_y" << c_y << "h_y" << h_y << "w_y" << w_y;
ss << "n_dy" << n_dy << "c_dy" << c_dy << "h_dy" << h_dy << "w_dy" << w_dy;
ss << "n_dx" << n_dx << "c_dx" << c_dx << "h_dx" << h_dx << "w_dx" << w_dx;

ss << "ypk" << static_cast<int>(yDesc.IsPacked());
ss << "dypk" << static_cast<int>(dyDesc.IsPacked());
ss << "dxpk" << static_cast<int>(xdxDesc.IsPacked());
printStrides("y", yDesc);
printStrides("dy", dyDesc);
printStrides("dx", xdxDesc);
}

ss << "a" << alpha;
ss << "b" << beta;
ss << "algo" << static_cast<int>(algorithm);
ss << "mode" << static_cast<int>(mode);

return NetworkConfig{ss.str()};
}

Expand Down

0 comments on commit 6a46e1d

Please sign in to comment.