Skip to content

Commit

Permalink
Revert "【Error Message No. 11】Update error massage of paddle\fluid\in…
Browse files Browse the repository at this point in the history
…ference\tensorrt\convert\test_custom_op_plugin.h (PaddlePaddle#67429)"

This reverts commit 6d660ac.
  • Loading branch information
YuanRisheng committed Aug 22, 2024
1 parent 2a2e55e commit 31dfe0a
Showing 1 changed file with 28 additions and 161 deletions.
189 changes: 28 additions & 161 deletions paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,196 +139,63 @@ class custom_op_plugin_creator : public nvinfer1::IPluginCreator {
nvinfer1::IPluginV2* createPlugin(
const char* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override {
PADDLE_ENFORCE_EQ(
fc->nbFields,
7,
phi::errors::InvalidArgument("fc->nbFields is invalid. "
"Expected 7, but received %d.",
fc->nbFields));
CHECK_EQ(fc->nbFields, 7);
// float_attr
auto attr_field = (fc->fields)[0];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kFLOAT32,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kFLOAT32"));
PADDLE_ENFORCE_EQ(attr_field.length,
1,
phi::errors::InvalidArgument(
"The length of attr_field is invalid. "
"Expected attr_field.length=1, but received %d.",
attr_field.length));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kFLOAT32);
CHECK_EQ(attr_field.length, 1);
float float_value = (reinterpret_cast<const float*>(attr_field.data))[0];
PADDLE_ENFORCE_EQ(
float_value,
1.0,
phi::errors::InvalidArgument("float_value is invalid. "
"Expected 1.0, but received %f.",
float_value));
CHECK_EQ(float_value, 1.0);

// int_attr
attr_field = (fc->fields)[1];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kINT32,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kINT32"));
PADDLE_ENFORCE_EQ(attr_field.length,
1,
phi::errors::InvalidArgument(
"The length of attr_field is invalid. "
"Expected attr_field.length=1, but received %d.",
attr_field.length));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 1);
int int_value = (reinterpret_cast<const int*>(attr_field.data))[0];
PADDLE_ENFORCE_EQ(
int_value,
1,
phi::errors::InvalidArgument("int_value is invalid. "
"Expected 1, but received %d.",
int_value));
CHECK_EQ(int_value, 1);

// bool_attr
attr_field = (fc->fields)[2];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kINT32,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kINT32"));
PADDLE_ENFORCE_EQ(attr_field.length,
1,
phi::errors::InvalidArgument(
"The length of attr_field is invalid. "
"Expected attr_field.length=1, but received %d.",
attr_field.length));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 1);
int bool_value = (reinterpret_cast<const int*>(attr_field.data))[0];
PADDLE_ENFORCE_EQ(
bool_value,
1,
phi::errors::InvalidArgument("bool_value is invalid. "
"Expected 1, but received %d.",
bool_value));
CHECK_EQ(bool_value, 1);

// string_attr
attr_field = (fc->fields)[3];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kCHAR,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kCHAR"));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kCHAR);
std::string expect_string_attr = "test_string_attr";
PADDLE_ENFORCE_EQ(static_cast<size_t>(attr_field.length),
expect_string_attr.size() + 1,
phi::errors::InvalidArgument(
"The length of attr_field must be equal to "
"the size of expect_string_attr plus 1. "
"Expected %llu, but received %llu.",
static_cast<size_t>(expect_string_attr.size() + 1),
static_cast<size_t>(attr_field.length)));
CHECK_EQ((size_t)attr_field.length, expect_string_attr.size() + 1);
const char* receive_string_attr =
reinterpret_cast<const char*>(attr_field.data);
PADDLE_ENFORCE_EQ(
expect_string_attr,
std::string(receive_string_attr),
phi::errors::InvalidArgument("The received string attribute '%s' "
"does not match the expected value '%s'.",
receive_string_attr,
expect_string_attr.c_str()));
CHECK(expect_string_attr == std::string(receive_string_attr));

// ints_attr
attr_field = (fc->fields)[4];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kINT32,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kINT32"));
PADDLE_ENFORCE_EQ(attr_field.length,
3,
phi::errors::InvalidArgument(
"The length of attr_field is invalid. "
"Expected attr_field.length=3, but received %d.",
attr_field.length));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 3);
const int* ints_value = reinterpret_cast<const int*>(attr_field.data);
PADDLE_ENFORCE_EQ(
ints_value[0],
1,
phi::errors::InvalidArgument("ints_value[0] is invalid. "
"Expected 1, but received %d.",
ints_value[0]));
PADDLE_ENFORCE_EQ(
ints_value[1],
2,
phi::errors::InvalidArgument("ints_value[1] is invalid. "
"Expected 2, but received %d.",
ints_value[1]));
PADDLE_ENFORCE_EQ(
ints_value[2],
3,
phi::errors::InvalidArgument("ints_value[2] is invalid. "
"Expected 3, but received %d.",
ints_value[2]));
CHECK_EQ(ints_value[0], 1);
CHECK_EQ(ints_value[1], 2);
CHECK_EQ(ints_value[2], 3);

// floats_attr
attr_field = (fc->fields)[5];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kFLOAT32,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kFLOAT32"));
PADDLE_ENFORCE_EQ(attr_field.length,
3,
phi::errors::InvalidArgument(
"The length of attr_field is invalid. "
"Expected attr_field.length=3, but received %d.",
attr_field.length));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kFLOAT32);
CHECK_EQ(attr_field.length, 3);
const float* floats_value = reinterpret_cast<const float*>(attr_field.data);
PADDLE_ENFORCE_EQ(
floats_value[0],
1.0,
phi::errors::InvalidArgument("floats_value[0] is invalid. "
"Expected 1.0, but received %f.",
floats_value[0]));
PADDLE_ENFORCE_EQ(
floats_value[1],
2.0,
phi::errors::InvalidArgument("floats_value[1] is invalid. "
"Expected 2.0, but received %f.",
floats_value[1]));
PADDLE_ENFORCE_EQ(
floats_value[2],
3.0,
phi::errors::InvalidArgument("floats_value[2] is invalid. "
"Expected 3.0, but received %f.",
floats_value[2]));
CHECK_EQ(floats_value[0], 1.0);
CHECK_EQ(floats_value[1], 2.0);
CHECK_EQ(floats_value[2], 3.0);

// bools_attr
attr_field = (fc->fields)[6];
PADDLE_ENFORCE_EQ(
attr_field.type,
nvinfer1::PluginFieldType::kINT32,
phi::errors::InvalidArgument("The attr_field type must be "
"nvinfer1::PluginFieldType::kINT32"));
PADDLE_ENFORCE_EQ(attr_field.length,
3,
phi::errors::InvalidArgument(
"The length of attr_field is invalid. "
"Expected attr_field.length=3, but received %d.",
attr_field.length));
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 3);
ints_value = reinterpret_cast<const int*>(attr_field.data);
PADDLE_ENFORCE_EQ(
ints_value[0],
true,
phi::errors::InvalidArgument("ints_value[0] is invalid. "
"Expected true, but received false."));
PADDLE_ENFORCE_EQ(
ints_value[1],
false,
phi::errors::InvalidArgument("ints_value[1] is invalid. "
"Expected false, but received true."));
PADDLE_ENFORCE_EQ(
ints_value[2],
true,
phi::errors::InvalidArgument("ints_value[2] is invalid. "
"Expected true, but received false."));
CHECK_EQ(ints_value[0], true);
CHECK_EQ(ints_value[1], false);
CHECK_EQ(ints_value[2], true);

return new custom_op_plugin(float_value);
}
Expand Down

0 comments on commit 31dfe0a

Please sign in to comment.