Skip to content

Commit

Permalink
[Inference]Open TensorRT unittest (#67638)
Browse files Browse the repository at this point in the history
* open tensorrt unittest

* Revert "【Error Message No. 11】Update error massage of paddle\fluid\inference\tensorrt\convert\test_custom_op_plugin.h (#67429)"

This reverts commit 6d660ac.

* fix bug

* open trt download

* fix docker file

* fix timeout

* fix py3

* install tensorrt local
  • Loading branch information
YuanRisheng authored Aug 26, 2024
1 parent c1d7f52 commit c781cd2
Show file tree
Hide file tree
Showing 14 changed files with 297 additions and 406 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ option(WITH_SHARED_PHI "Compile PaddlePaddle with SHARED LIB of PHI" ON)
option(CINN_WITH_CUDNN "Compile CINN with CUDNN support" ON)
option(WITH_PIP_CUDA_LIBRARIES
"Paddle uses the CUDA library provided by NVIDIA" OFF)
option(WITH_PIP_TENSORRT "Paddle uses the tensorrt provided by NVIDIA" OFF)
option(WITH_NIGHTLY_BUILD
"Compile nightly paddle whl package of the develop branch" OFF)
option(WITH_CPP_TEST "Compile PaddlePaddle skip cpp test" ON)
Expand Down
189 changes: 28 additions & 161 deletions paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,196 +139,63 @@ class custom_op_plugin_creator : public nvinfer1::IPluginCreator {
// Creates a custom_op_plugin from the attribute fields supplied in `fc`,
// verifying that each field has the type, length and value this test fixture
// expects.
//
// Expected layout of fc->fields:
//   [0] float_attr  : kFLOAT32, length 1, value 1.0
//   [1] int_attr    : kINT32,   length 1, value 1
//   [2] bool_attr   : kINT32,   length 1, value 1
//   [3] string_attr : kCHAR,    "test_string_attr" plus one trailing byte
//   [4] ints_attr   : kINT32,   length 3, values {1, 2, 3}
//   [5] floats_attr : kFLOAT32, length 3, values {1.0, 2.0, 3.0}
//   [6] bools_attr  : kINT32,   length 3, values {true, false, true}
//
// @param name  Unused by this implementation.
// @param fc    Field collection to validate; must not be null.
// @return      A newly allocated custom_op_plugin built from float_attr.
//
// This override is noexcept, so validation uses CHECK/CHECK_EQ only. An
// exception escaping a noexcept function calls std::terminate, which is why
// the exception-style PADDLE_ENFORCE_EQ checks are not used here.
// NOTE(review): CHECK/CHECK_EQ are assumed to be the glog-style macros that
// log and abort on failure -- confirm against the file's includes.
nvinfer1::IPluginV2* createPlugin(
    const char* name,
    const nvinfer1::PluginFieldCollection* fc) noexcept override {
  CHECK(fc != nullptr);
  CHECK_EQ(fc->nbFields, 7);

  // float_attr
  const auto& float_field = (fc->fields)[0];
  CHECK(float_field.type == nvinfer1::PluginFieldType::kFLOAT32);
  CHECK_EQ(float_field.length, 1);
  const float float_value =
      (reinterpret_cast<const float*>(float_field.data))[0];
  CHECK_EQ(float_value, 1.0);

  // int_attr
  const auto& int_field = (fc->fields)[1];
  CHECK(int_field.type == nvinfer1::PluginFieldType::kINT32);
  CHECK_EQ(int_field.length, 1);
  const int int_value = (reinterpret_cast<const int*>(int_field.data))[0];
  CHECK_EQ(int_value, 1);

  // bool_attr (transported as a 32-bit integer field)
  const auto& bool_field = (fc->fields)[2];
  CHECK(bool_field.type == nvinfer1::PluginFieldType::kINT32);
  CHECK_EQ(bool_field.length, 1);
  const int bool_value = (reinterpret_cast<const int*>(bool_field.data))[0];
  CHECK_EQ(bool_value, 1);

  // string_attr: the field length is the string size plus one extra byte.
  const auto& string_field = (fc->fields)[3];
  CHECK(string_field.type == nvinfer1::PluginFieldType::kCHAR);
  const std::string expect_string_attr = "test_string_attr";
  CHECK_EQ(static_cast<size_t>(string_field.length),
           expect_string_attr.size() + 1);
  const char* receive_string_attr =
      reinterpret_cast<const char*>(string_field.data);
  CHECK(expect_string_attr == std::string(receive_string_attr));

  // ints_attr
  const auto& ints_field = (fc->fields)[4];
  CHECK(ints_field.type == nvinfer1::PluginFieldType::kINT32);
  CHECK_EQ(ints_field.length, 3);
  const int* ints_value = reinterpret_cast<const int*>(ints_field.data);
  CHECK_EQ(ints_value[0], 1);
  CHECK_EQ(ints_value[1], 2);
  CHECK_EQ(ints_value[2], 3);

  // floats_attr
  const auto& floats_field = (fc->fields)[5];
  CHECK(floats_field.type == nvinfer1::PluginFieldType::kFLOAT32);
  CHECK_EQ(floats_field.length, 3);
  const float* floats_value =
      reinterpret_cast<const float*>(floats_field.data);
  CHECK_EQ(floats_value[0], 1.0);
  CHECK_EQ(floats_value[1], 2.0);
  CHECK_EQ(floats_value[2], 3.0);

  // bools_attr (each element transported as a 32-bit integer)
  const auto& bools_field = (fc->fields)[6];
  CHECK(bools_field.type == nvinfer1::PluginFieldType::kINT32);
  CHECK_EQ(bools_field.length, 3);
  const int* bools_value = reinterpret_cast<const int*>(bools_field.data);
  CHECK_EQ(bools_value[0], true);
  CHECK_EQ(bools_value[1], false);
  CHECK_EQ(bools_value[2], true);

  // Returned as a raw pointer because the overridden interface requires
  // nvinfer1::IPluginV2*.
  return new custom_op_plugin(float_value);
}
Expand Down
2 changes: 1 addition & 1 deletion python/env_dict.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ env_dict={
'PADDLE_INSTALL_DIR':'@PADDLE_INSTALL_DIR@',
'PADDLE_LIB_TEST_DIR':'@PADDLE_LIB_TEST_DIR@',
'WITH_PIP_CUDA_LIBRARIES':'@WITH_PIP_CUDA_LIBRARIES@',
'TENSORRT_FOUND':'@TENSORRT_FOUND@',
'WITH_PIP_TENSORRT':'@WITH_PIP_TENSORRT@',
'TR_INFER_RT':'@TR_INFER_RT@',
'TENSORRT_LIBRARY_DIR':'@TENSORRT_LIBRARY_DIR@',
}
1 change: 0 additions & 1 deletion python/paddle/tensorrt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def map_dtype(pd_dtype):
def run_pir_pass(program, partition_mode=False):
pm = pir.PassManager(opt_level=4)
pm.enable_print_statistics()
pm.enable_ir_printing()
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
passes = [
{'multihead_matmul_fuse_pass': {}},
Expand Down
9 changes: 6 additions & 3 deletions python/setup.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ commit = '%(commit)s'
with_mkl = '%(with_mkl)s'
cinn_version = '%(cinn)s'
with_pip_cuda_libraries = '%(with_pip_cuda_libraries)s'
with_pip_tensorrt ='%(with_pip_tensorrt)s'

__all__ = ['cuda', 'cudnn', 'nccl', 'show', 'xpu', 'xpu_xre', 'xpu_xccl', 'xpu_xhpc']

Expand Down Expand Up @@ -389,6 +390,7 @@ def cinn() -> str:
'is_tagged': is_tagged(),
'with_mkl': '@WITH_MKL@',
'cinn': get_cinn_version(),
'with_pip_tensorrt':'@WITH_PIP_TENSORRT@',
'with_pip_cuda_libraries': '@WITH_PIP_CUDA_LIBRARIES@'})

def get_cinn_config_jsons():
Expand Down Expand Up @@ -559,7 +561,7 @@ def get_paddle_extra_install_requirements():

paddle_cuda_requires = PADDLE_CUDA_INSTALL_REQUIREMENTS[cuda_major_version].split("|")

if '@TENSORRT_FOUND@' == 'ON':
if '@WITH_PIP_TENSORRT@' == 'ON':
version_str = get_tensorrt_version()
version_default = int(version_str.split(".")[0])
if platform.system() =='Linux' or (platform.system()=='Windows' and version_default>=10):
Expand Down Expand Up @@ -593,7 +595,7 @@ def get_paddle_extra_install_requirements():
)
return paddle_cuda_requires, []

return paddle_cuda_requires
return paddle_cuda_requires,paddle_tensorrt_requires



Expand Down Expand Up @@ -790,8 +792,9 @@ if sys.version_info >= (3,8):
setup_requires_tmp+=[setup_requires_i]
setup_requires = setup_requires_tmp
if '@WITH_GPU@' == 'ON' and platform.system() in ('Linux', 'Windows') and platform.machine() in ('x86_64', 'AMD64'):
paddle_cuda_requires= get_paddle_extra_install_requirements()
paddle_cuda_requires,paddle_tensorrt_requires= get_paddle_extra_install_requirements()
setup_requires += paddle_cuda_requires
setup_requires += paddle_tensorrt_requires


# the prefix is sys.prefix which should always be usr
Expand Down
11 changes: 8 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def write_version_py(filename='paddle/version/__init__.py'):
with_mkl = '%(with_mkl)s'
cinn_version = '%(cinn)s'
with_pip_cuda_libraries = '%(with_pip_cuda_libraries)s'
with_pip_tensorrt ='%(with_pip_tensorrt)s'
__all__ = ['cuda', 'cudnn', 'nccl', 'show', 'xpu', 'xpu_xre', 'xpu_xccl', 'xpu_xhpc']
Expand Down Expand Up @@ -713,6 +714,7 @@ def cinn() -> str:
'with_pip_cuda_libraries': env_dict.get(
"WITH_PIP_CUDA_LIBRARIES"
),
'with_pip_tensorrt': env_dict.get("WITH_PIP_TENSORRT"),
}
)

Expand Down Expand Up @@ -1079,7 +1081,7 @@ def get_paddle_extra_install_requirements():
cuda_major_version
].split("|")

if env_dict.get("TENSORRT_FOUND") == "ON":
if env_dict.get("WITH_PIP_TENSORRT") == "ON":
version_str = get_tensorrt_version()
version_default = int(version_str.split(".")[0])
if platform.system() == 'Linux' or (
Expand Down Expand Up @@ -1120,7 +1122,7 @@ def get_paddle_extra_install_requirements():
)
return paddle_cuda_requires, []

return paddle_cuda_requires
return paddle_cuda_requires, paddle_tensorrt_requires


def get_cinn_config_jsons():
Expand Down Expand Up @@ -1711,8 +1713,11 @@ def get_setup_parameters():
'AMD64',
)
):
paddle_cuda_requires = get_paddle_extra_install_requirements()
paddle_cuda_requires, paddle_tensorrt_requires = (
get_paddle_extra_install_requirements()
)
setup_requires += paddle_cuda_requires
setup_requires += paddle_tensorrt_requires

packages = [
'paddle',
Expand Down
22 changes: 13 additions & 9 deletions test/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Register the Python TensorRT tests found in this directory.
# The whole block is skipped on Windows and when TENSORRT_FOUND is not set.
if(NOT WIN32 AND TENSORRT_FOUND)
# Collect every test_*.py file, as a path relative to this directory.
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
# Drop the ".py" suffix so each entry can be used as a test/module name.
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

# One test per collected module.
# NOTE(review): py_test_modules is defined elsewhere in the project; it is
# presumed to create a CTest test named after its first argument -- confirm.
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()
# Give the model-level converter tests an explicit TIMEOUT of 500.
# NOTE(review): these names must match tests registered by the loop above,
# i.e. the corresponding test_*.py files must exist in this directory.
set_tests_properties(test_converter_model_bert PROPERTIES TIMEOUT "500")
set_tests_properties(test_converter_model_dummy PROPERTIES TIMEOUT "500")
set_tests_properties(test_converter_model_resnet50 PROPERTIES TIMEOUT "500")
endif()
Loading

0 comments on commit c781cd2

Please sign in to comment.