Skip to content

Commit

Permalink
【Error Message No.18】 part 1 of 'paddle/cinn/frontend/op_mappers/*' (P…
Browse files Browse the repository at this point in the history
…addlePaddle#64348)

* fix part

* fix part

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
enkilee authored and chen2016013 committed May 26, 2024
1 parent 96150cc commit 6b4d369
Show file tree
Hide file tree
Showing 15 changed files with 542 additions and 214 deletions.
32 changes: 24 additions & 8 deletions paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/common/enforce.h"

namespace cinn {
namespace frontend {
Expand Down Expand Up @@ -47,24 +48,39 @@ Variable ArgImpl<ArgType::ArgMin>(NetBuilder* builder,
template <ArgType type>
void ArgOpMapperHelper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Argmax/Argmin op must be 1."));
auto x_name = op_desc.Input("X").front();

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument(
"The output of Argmax/Argmin op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto x = ctx.GetVar(x_name);
auto axis = utils::GetAttrOrDefault<int64_t>(op_desc, "axis", -1);
CHECK(op_desc.HasAttr("axis"))
<< "Argmax/Argmin op should has attribute \"axis\"! Please check.";
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("axis"),
true,
phi::errors::InvalidArgument("Argmax/Argmin op should has attribute "
"\"axis\"! Please check."));

auto keepdims = utils::GetAttrOrDefault<bool>(op_desc, "keepdims", false);
CHECK(op_desc.HasAttr("keepdims"))
<< "Argmax/Argmin op should has attribute \"keepdims\"! Please check.";
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("keepdims"),
true,
phi::errors::InvalidArgument("Argmax/Argmin op should has attribute"
" \"keepdims\"! Please check."));

auto flatten = utils::GetAttrOrDefault<bool>(op_desc, "flatten", false);
CHECK(op_desc.HasAttr("flatten"))
<< "Argmax/Argmin op should has attribute \"flatten\"! Please check.";
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("flatten"),
true,
phi::errors::InvalidArgument("Argmax/Argmin op should has attribute"
" \"flatten\"! Please check."));

auto dtype = utils::GetPaddleDtype(
op_desc, "dtype", paddle::cpp::VarDescAPI::Type::INT64);
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/frontend/op_mappers/paddle/argsort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,29 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/utils/string.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc,
const cinn::frontend::OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Argmax/Argmin op must be 1."));
auto x_name = op_desc.Input("X").front();

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument(
"The output of Argmax/Argmin op must be 1."));
auto out_name = op_desc.Output("Out").front();

CHECK_EQ(op_desc.Output("Indices").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Indices").size(),
1UL,
phi::errors::InvalidArgument(
"The output of Argmax/Argmin op must be 1."));
auto indices_name = op_desc.Output("Indices").front();

auto is_ascend =
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/frontend/op_mappers/paddle/atan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/utils/string.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void Atan2OpMapper(const paddle::cpp::OpDesc& op_desc,
const cinn::frontend::OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X1").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X1").size(),
1UL,
phi::errors::InvalidArgument("The input of Atan2 op must be 1."));
auto x1_name = op_desc.Input("X1").front();
CHECK_EQ(op_desc.Input("X2").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X2").size(),
1UL,
phi::errors::InvalidArgument("The input of Atan2 op must be 1."));
auto x2_name = op_desc.Input("X2").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Atan2 op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto x1 = ctx.GetVar(x1_name);
Expand Down
61 changes: 46 additions & 15 deletions paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {
Expand All @@ -29,7 +29,10 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc,
<< op_desc.Type();
return;
}
CHECK_EQ(op_desc.Output(pd_param_name).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output(pd_param_name).size(),
1UL,
phi::errors::InvalidArgument("The output of batch_norm op must be 1."));
auto output_name = op_desc.Output(pd_param_name).front();

VLOG(4) << "The " << op_desc.Type() << "'s output " << pd_param_name
Expand All @@ -39,15 +42,30 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc,
ctx.AddVarModelToProgram(output_name, out->id, can_inplace);
};

CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Input("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Scale").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto scale_name = op_desc.Input("Scale").front();
CHECK_EQ(op_desc.Input("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Bias").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto bias_name = op_desc.Input("Bias").front();
CHECK_EQ(op_desc.Input("Mean").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Mean").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto mean_name = op_desc.Input("Mean").front();
CHECK_EQ(op_desc.Input("Variance").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Variance").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto variance_name = op_desc.Input("Variance").front();

auto epsilon = utils::GetAttrOrDefault<float>(op_desc, "epsilon", 1e-5f);
Expand Down Expand Up @@ -105,8 +123,11 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc,
add_output("VarianceOut", variance_out, true);
} else {
VLOG(4) << "Invoke batch_norm OpMapper with train mode";
CHECK_EQ(outs.size(), 5U)
<< "batch_norm in train mode should only has 5 output! Please check.";
PADDLE_ENFORCE_EQ(outs.size(),
5U,
phi::errors::InvalidArgument(
"batch_norm in train mode should only has 5 output! "
"Please check."));

add_output("Y", outs[0]);
add_output("SavedMean", outs[1]);
Expand All @@ -122,7 +143,10 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc,
std::unordered_map<std::string, std::string> input_names_map;
auto get_input_var =
[&op_desc, &ctx, &input_names_map](const std::string& op_name) {
CHECK_EQ(op_desc.Input(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The input of batch_norm_grad op must be 1."));
auto var_name = op_desc.Input(op_name).front();
input_names_map.emplace(op_name, var_name);
return ctx.GetVar(var_name);
Expand All @@ -132,12 +156,17 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc,
auto get_output_name =
[&op_desc, &output_names_map](const std::string& op_name) -> std::string {
if (op_desc.Output(op_name).empty()) {
CHECK_NE(op_name, paddle::GradVarName("X"))
<< "The input X should not empty.";
PADDLE_ENFORCE_NE(
op_name,
paddle::GradVarName("X"),
phi::errors::InvalidArgument("The input X should not empty."));
return "";
}

CHECK_EQ(op_desc.Output(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The output of batch_norm_grad op must be 1."));
auto var_name = op_desc.Output(op_name).front();
output_names_map.emplace(op_name, var_name);
return var_name;
Expand Down Expand Up @@ -174,8 +203,10 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc,
// batch norm grad, output(grad_x, grad_scale, grad_bias)
auto outs = ctx.Builder()->BatchNormGrad(
dy, x, scale, saved_mean, saved_variance, epsilon, data_layout);
CHECK_EQ(outs.size(), 3ul)
<< "batch_norm_grad APIs should return 3 Variable!";
PADDLE_ENFORCE_EQ(outs.size(),
3ul,
phi::errors::InvalidArgument(
"batch_norm_grad APIs should return 3 Variable!"));

for (int i = 0; i < outs.size(); i++) {
if (output_names[i].empty()) {
Expand Down
39 changes: 24 additions & 15 deletions paddle/cinn/frontend/op_mappers/paddle/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,34 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

#define BINARY_OPMAPPER_FUNCTION(OP_NAME) \
void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \
const OpMapperContext& ctx) { \
CHECK_EQ(op_desc.Input("X").size(), 1UL); \
auto x_name = op_desc.Input("X").front(); \
CHECK_EQ(op_desc.Input("Y").size(), 1UL); \
auto y_name = op_desc.Input("Y").front(); \
CHECK_EQ(op_desc.Output("Out").size(), 1UL); \
auto out_name = op_desc.Output("Out").front(); \
auto x = ctx.GetVar(x_name); \
auto y = ctx.GetVar(y_name); \
auto out = ctx.Builder()->OP_NAME(x, y); \
ctx.AddVar(out_name, out); \
ctx.AddVarModelToProgram(out_name, out->id); \
#define BINARY_OPMAPPER_FUNCTION(OP_NAME) \
void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \
const OpMapperContext& ctx) { \
PADDLE_ENFORCE_EQ( \
op_desc.Input("X").size(), \
1UL, \
phi::errors::InvalidArgument("The input of op must be 1.")); \
auto x_name = op_desc.Input("X").front(); \
PADDLE_ENFORCE_EQ( \
op_desc.Input("Y").size(), \
1UL, \
phi::errors::InvalidArgument("The input of op must be 1.")); \
auto y_name = op_desc.Input("Y").front(); \
PADDLE_ENFORCE_EQ( \
op_desc.Output("Out").size(), \
1UL, \
phi::errors::InvalidArgument("The output of op must be 1.")); \
auto out_name = op_desc.Output("Out").front(); \
auto x = ctx.GetVar(x_name); \
auto y = ctx.GetVar(y_name); \
auto out = ctx.Builder()->OP_NAME(x, y); \
ctx.AddVar(out_name, out); \
ctx.AddVarModelToProgram(out_name, out->id); \
}

BINARY_OPMAPPER_FUNCTION(LogicalAnd)
Expand Down
12 changes: 9 additions & 3 deletions paddle/cinn/frontend/op_mappers/paddle/cholesky.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void CholeskyOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of cholesky op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of cholesky op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto upper = utils::GetAttrOrDefault<bool>(op_desc, "upper", false);
Expand Down
Loading

0 comments on commit 6b4d369

Please sign in to comment.