diff --git a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu index f7adaab13d1167..2378e8e11097b7 100644 --- a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu @@ -667,6 +667,291 @@ nvinfer1::IPluginV2Ext* AnchorGeneratorPluginDynamicCreator::deserializePlugin( } #endif +PIRAnchorGeneratorPluginDynamic::PIRAnchorGeneratorPluginDynamic( + const nvinfer1::DataType data_type, + const std::vector<float>& anchor_sizes, + const std::vector<float>& aspect_ratios, + const std::vector<float>& stride, + const std::vector<float>& variances, + const float offset, + const int num_anchors) + : data_type_(data_type), + anchor_sizes_(anchor_sizes), + aspect_ratios_(aspect_ratios), + stride_(stride), + variances_(variances), + offset_(offset), + num_anchors_(num_anchors) { + // data_type_ is used to determine the output data type + // data_type_ can only be float32 + // height, width, num_anchors are calculated at configurePlugin + PADDLE_ENFORCE_EQ(data_type_, + nvinfer1::DataType::kFLOAT, + common::errors::InvalidArgument( + "TRT anchor generator plugin only accepts float32.")); + PADDLE_ENFORCE_GE( + num_anchors_, + 0, + common::errors::InvalidArgument( + "TRT anchor generator plugin only accepts number of anchors greater " + "than 0, but receive number of anchors = %d.", + num_anchors_)); + PrepareParamsOnDevice(); +} + +PIRAnchorGeneratorPluginDynamic::~PIRAnchorGeneratorPluginDynamic() { + auto release_device_ptr = [](void* ptr) { + if (ptr) { + cudaFree(ptr); + ptr = nullptr; + } + }; + release_device_ptr(anchor_sizes_device_); + release_device_ptr(aspect_ratios_device_); + release_device_ptr(stride_device_); + release_device_ptr(variances_device_); +} + +PIRAnchorGeneratorPluginDynamic::PIRAnchorGeneratorPluginDynamic( + void const* data, size_t length) { + DeserializeValue(&data, &length, &data_type_); +
DeserializeValue(&data, &length, &anchor_sizes_); + DeserializeValue(&data, &length, &aspect_ratios_); + DeserializeValue(&data, &length, &stride_); + DeserializeValue(&data, &length, &variances_); + DeserializeValue(&data, &length, &offset_); + DeserializeValue(&data, &length, &num_anchors_); + PrepareParamsOnDevice(); +} + +nvinfer1::IPluginV2DynamicExt* PIRAnchorGeneratorPluginDynamic::clone() const + TRT_NOEXCEPT { + auto plugin = new PIRAnchorGeneratorPluginDynamic(data_type_, + anchor_sizes_, + aspect_ratios_, + stride_, + variances_, + offset_, + num_anchors_); + plugin->setPluginNamespace(namespace_.c_str()); + return plugin; +} + +nvinfer1::DimsExprs PIRAnchorGeneratorPluginDynamic::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs ret{}; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[2]; // feature height + ret.d[1] = inputs[0].d[3]; // feature width + ret.d[2] = exprBuilder.constant(num_anchors_); + ret.d[3] = exprBuilder.constant(4); + return ret; +} + +bool PIRAnchorGeneratorPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + // input can be any, doesn't matter + // anchor generator doesn't read input raw data, only need the shape info + auto type = inOut[pos].type; + auto format = inOut[pos].format; +#if IS_TRT_VERSION_GE(7234) + if (pos == 0) return true; +#else + if (pos == 0) return format == nvinfer1::TensorFormat::kLINEAR; +#endif + return (type == nvinfer1::DataType::kFLOAT && + format == nvinfer1::TensorFormat::kLINEAR); +} + +void PIRAnchorGeneratorPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT {} + +size_t PIRAnchorGeneratorPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, + int 
nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT { + return 0; +} + +template <typename T> +int PIRAnchorGeneratorPluginDynamic::enqueue_impl( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) { + const int height = inputDesc[0].dims.d[2]; + const int width = inputDesc[0].dims.d[3]; + const int box_num = height * width * num_anchors_; + const int block = 512; + const int gen_anchor_grid = (box_num + block - 1) / block; + T* anchors = static_cast<T*>(outputs[0]); + T* vars = static_cast<T*>(outputs[1]); + const T* anchor_sizes_device = static_cast<const T*>(anchor_sizes_device_); + const T* aspect_ratios_device = static_cast<const T*>(aspect_ratios_device_); + const T* stride_device = static_cast<const T*>(stride_device_); + const T* variances_device = static_cast<const T*>(variances_device_); + phi::GenAnchors<T> + <<<gen_anchor_grid, block, 0, stream>>>(anchors, + aspect_ratios_device, + aspect_ratios_.size(), + anchor_sizes_device, + anchor_sizes_.size(), + stride_device, + stride_.size(), + height, + width, + offset_); + const int var_grid = (box_num * 4 + block - 1) / block; + phi::SetVariance<T><<<var_grid, block, 0, stream>>>( + vars, variances_device, variances_.size(), box_num * 4); + return cudaGetLastError() != cudaSuccess; +} + +int PIRAnchorGeneratorPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + assert(outputDesc[0].type == nvinfer1::DataType::kFLOAT); + assert(outputDesc[1].type == nvinfer1::DataType::kFLOAT); + return enqueue_impl<float>( + inputDesc, outputDesc, inputs, outputs, workspace, stream); +} + +nvinfer1::DataType PIRAnchorGeneratorPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT { + return inputTypes[0]; +} + +const char* 
PIRAnchorGeneratorPluginDynamic::getPluginType() const + TRT_NOEXCEPT { + return "pir_anchor_generator_plugin_dynamic"; +} + +int PIRAnchorGeneratorPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { + return 2; +} + +int PIRAnchorGeneratorPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } + +void PIRAnchorGeneratorPluginDynamic::terminate() TRT_NOEXCEPT {} + +size_t PIRAnchorGeneratorPluginDynamic::getSerializationSize() const + TRT_NOEXCEPT { + size_t serialize_size = 0; + serialize_size += SerializedSize(data_type_); + serialize_size += SerializedSize(anchor_sizes_); + serialize_size += SerializedSize(aspect_ratios_); + serialize_size += SerializedSize(stride_); + serialize_size += SerializedSize(variances_); + serialize_size += SerializedSize(offset_); + serialize_size += SerializedSize(num_anchors_); + return serialize_size; +} + +void PIRAnchorGeneratorPluginDynamic::serialize(void* buffer) const + TRT_NOEXCEPT { + SerializeValue(&buffer, data_type_); + SerializeValue(&buffer, anchor_sizes_); + SerializeValue(&buffer, aspect_ratios_); + SerializeValue(&buffer, stride_); + SerializeValue(&buffer, variances_); + SerializeValue(&buffer, offset_); + SerializeValue(&buffer, num_anchors_); +} + +void PIRAnchorGeneratorPluginDynamic::destroy() TRT_NOEXCEPT {} + +void PIRAnchorGeneratorPluginDynamicCreator::setPluginNamespace( + const char* lib_namespace) TRT_NOEXCEPT { + namespace_ = std::string(lib_namespace); +} + +const char* PIRAnchorGeneratorPluginDynamicCreator::getPluginNamespace() const + TRT_NOEXCEPT { + return namespace_.c_str(); +} + +const char* PIRAnchorGeneratorPluginDynamicCreator::getPluginName() const + TRT_NOEXCEPT { + return "pir_anchor_generator_plugin_dynamic"; +} + +const char* PIRAnchorGeneratorPluginDynamicCreator::getPluginVersion() const + TRT_NOEXCEPT { + return "1"; +} + +const nvinfer1::PluginFieldCollection* +PIRAnchorGeneratorPluginDynamicCreator::getFieldNames() TRT_NOEXCEPT { + return &field_collection_; +} + 
+nvinfer1::IPluginV2Ext* PIRAnchorGeneratorPluginDynamicCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT { + const nvinfer1::PluginField* fields = fc->fields; + std::vector<float> anchor_sizes, aspect_ratios, stride, variances; + float offset = .5; + int num_anchors = -1; + + for (int i = 0; i < fc->nbFields; ++i) { + const nvinfer1::PluginField& f = fc->fields[i]; + const std::string field_name(f.name); + if (field_name.compare("anchor_sizes") == 0) { + const float* data = static_cast<const float*>(f.data); + anchor_sizes.assign(data, data + f.length); + } else if (field_name.compare("aspect_ratios") == 0) { + const float* data = static_cast<const float*>(f.data); + aspect_ratios.assign(data, data + f.length); + } else if (field_name.compare("stride") == 0) { + const float* data = static_cast<const float*>(f.data); + stride.assign(data, data + f.length); + } else if (field_name.compare("variances") == 0) { + const float* data = static_cast<const float*>(f.data); + variances.assign(data, data + f.length); + } else if (field_name.compare("offset") == 0) { + offset = *static_cast<const float*>(f.data); + } else if (field_name.compare("num_anchors") == 0) { + num_anchors = *static_cast<const int*>(f.data); + } else { + assert(false && "unknown plugin field name."); + } + } + return new PIRAnchorGeneratorPluginDynamic(nvinfer1::DataType::kFLOAT, + anchor_sizes, + aspect_ratios, + stride, + variances, + offset, + num_anchors); +} + +nvinfer1::IPluginV2Ext* +PIRAnchorGeneratorPluginDynamicCreator::deserializePlugin( + const char* name, + const void* serial_data, + size_t serial_length) TRT_NOEXCEPT { + auto plugin = new PIRAnchorGeneratorPluginDynamic(serial_data, serial_length); + plugin->setPluginNamespace(namespace_.c_str()); + return plugin; +} + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h index 72f11c76767ebb..20f145e9095694 
100644 --- a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h @@ -227,7 +227,105 @@ class AnchorGeneratorPluginDynamicCreator : public nvinfer1::IPluginCreator { std::string namespace_; nvinfer1::PluginFieldCollection field_collection_; }; + +class PIRAnchorGeneratorPluginDynamic : public DynamicPluginTensorRT { + public: + explicit PIRAnchorGeneratorPluginDynamic( + const nvinfer1::DataType data_type, + const std::vector<float>& anchor_sizes, + const std::vector<float>& aspect_ratios, + const std::vector<float>& stride, + const std::vector<float>& variances, + const float offset, + const int num_anchors); + PIRAnchorGeneratorPluginDynamic(void const* data, size_t length); + ~PIRAnchorGeneratorPluginDynamic(); + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) // NOLINT + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + const char* getPluginType() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT 
override; + int initialize() TRT_NOEXCEPT override; + void terminate() TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + void destroy() TRT_NOEXCEPT override; + + private: + template <typename T> + int enqueue_impl(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream); + nvinfer1::DataType data_type_; + std::vector<float> anchor_sizes_; + std::vector<float> aspect_ratios_; + std::vector<float> stride_; + std::vector<float> variances_; + float offset_; + void* anchor_sizes_device_; + void* aspect_ratios_device_; + void* stride_device_; + void* variances_device_; + int num_anchors_; + std::string namespace_; +}; + +class PIRAnchorGeneratorPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + PIRAnchorGeneratorPluginDynamicCreator() = default; + ~PIRAnchorGeneratorPluginDynamicCreator() override = default; + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override; + const char* getPluginNamespace() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override; + nvinfer1::IPluginV2Ext* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override; + + private: + std::string namespace_; + nvinfer1::PluginFieldCollection field_collection_; +}; + REGISTER_TRT_PLUGIN_V2(AnchorGeneratorPluginDynamicCreator); +REGISTER_TRT_PLUGIN_V2(PIRAnchorGeneratorPluginDynamicCreator); #endif } // namespace plugin diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc 
b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 0ad509a9601882..78eeb58a19133d 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -94,6 +94,7 @@ DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp) DEFINE_GENERAL_PATTERN(Mish, paddle::dialect::MishOp) DEFINE_GENERAL_PATTERN(AssignValue, paddle::dialect::AssignValueOp) DEFINE_GENERAL_PATTERN(AssignValue_, paddle::dialect::AssignValue_Op) +DEFINE_GENERAL_PATTERN(Anchor_Generator, paddle::dialect::AnchorGeneratorOp) DEFINE_GENERAL_PATTERN(Exp, paddle::dialect::ExpOp) DEFINE_GENERAL_PATTERN(Abs, paddle::dialect::AbsOp) DEFINE_GENERAL_PATTERN(Abs_, paddle::dialect::Abs_Op) @@ -2294,6 +2295,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass { ADD_PATTERN(Mish) ADD_PATTERN(AssignValue) ADD_PATTERN(AssignValue_) + ADD_PATTERN(Anchor_Generator) ADD_PATTERN(Exp) ADD_PATTERN(Abs) ADD_PATTERN(Abs_) diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index 4363fc6c8630d5..f0cf95ee7f66fb 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -28,6 +28,7 @@ #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/op_callstack_utils.h" #include "paddle/fluid/pybind/op_function_common.h" +#include "paddle/fluid/pybind/static_op_function.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/infermeta/spmd_rules/rules.h" @@ -1188,6 +1189,18 @@ static PyObject *fused_gemm_epilogue(PyObject *self, } } +static PyObject *anchor_generator(PyObject *self, + PyObject *args, + PyObject *kwargs) { + if (egr::Controller::Instance().GetCurrentTracer() == nullptr) { + VLOG(6) << "Call static_api_anchor_generator"; + return static_api_anchor_generator(self, args, kwargs); + } else { + ThrowExceptionToPython(std::current_exception()); + return 
nullptr; + } +} + static PyObject *share_var(PyObject *self, PyObject *args, PyObject *kwargs) { try { VLOG(6) << "Add share_var op into program"; @@ -1267,6 +1280,10 @@ static PyMethodDef ManualOpsAPI[] = { (PyCFunction)(void (*)(void))fused_gemm_epilogue, METH_VARARGS | METH_KEYWORDS, "C++ interface function for fused_gemm_epilogue."}, + {"anchor_generator", + (PyCFunction)(void (*)(void))anchor_generator, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for anchor_generator."}, {"_run_custom_op", (PyCFunction)(void (*)(void))run_custom_op, METH_VARARGS | METH_KEYWORDS, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 61c0a8e55ecb2f..b59e431a8480d4 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -235,6 +235,7 @@ limitations under the License. */ #include "pybind11/stl.h" #ifdef PADDLE_WITH_TENSORRT #include "paddle/fluid/inference/tensorrt/pir/declare_plugin.h" +#include "paddle/fluid/platform/tensorrt/trt_plugin.h" #endif COMMON_DECLARE_bool(use_mkldnn); @@ -3422,6 +3423,10 @@ All parameter, weight, gradient are variables in Paddle. 
m.def("clear_shape_info", []() { paddle::framework::CollectShapeManager::Instance().ClearShapeInfo(); }); +#ifdef PADDLE_WITH_TENSORRT + m.def("register_paddle_plugin", + []() { paddle::platform::TrtPluginRegistry::Global()->RegistToTrt(); }); +#endif #if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS) BindHeterWrapper(&m); diff --git a/python/paddle/tensorrt/converter.py b/python/paddle/tensorrt/converter.py index cab46618c4c0ee..3e7b32d400042b 100644 --- a/python/paddle/tensorrt/converter.py +++ b/python/paddle/tensorrt/converter.py @@ -17,6 +17,10 @@ import logging import numpy as np + +import paddle + +paddle.base.core.register_paddle_plugin() import tensorrt as trt import paddle diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index f2f571f6953129..8f9cafbccf758c 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -303,6 +303,66 @@ def share_data_converter(network, paddle_op, inputs): return identity_layer.get_output(0) +@converter_registry.register("pd_op.anchor_generator", trt_version="8.x") +def anchor_generator_converter(network, paddle_op, inputs): + inputs = inputs[0] + input_dims = inputs.shape + anchor_sizes = paddle_op.attrs().get("anchor_sizes") + aspect_ratios = paddle_op.attrs().get("aspect_ratios") + stride = paddle_op.attrs().get("stride") + variances = paddle_op.attrs().get("variances") + offset = paddle_op.attrs().get("offset") + num_anchors = len(aspect_ratios) * len(anchor_sizes) + + height = input_dims[1] + width = input_dims[2] + box_num = width * height * num_anchors + data_type = trt.float32 + + plugin_fields = [ + trt.PluginField( + "anchor_sizes", + np.array(anchor_sizes, dtype=np.float32), + trt.PluginFieldType.FLOAT32, + ), + trt.PluginField( + "aspect_ratios", + np.array(aspect_ratios, dtype=np.float32), + trt.PluginFieldType.FLOAT32, + ), + trt.PluginField( + "stride", + np.array(stride, dtype=np.float32), + 
trt.PluginFieldType.FLOAT32, + ), + trt.PluginField( + "variances", + np.array(variances, dtype=np.float32), + trt.PluginFieldType.FLOAT32, + ), + trt.PluginField( + "offset", + np.array(offset, dtype=np.float32), + trt.PluginFieldType.FLOAT32, + ), + trt.PluginField( + "num_anchors", + np.array(num_anchors, dtype=np.int32), + trt.PluginFieldType.INT32, + ), + ] + plugin_field_collection = trt.PluginFieldCollection(plugin_fields) + plugin_name = "pir_anchor_generator_plugin_dynamic" + plugin_version = "1" + plugin = get_trt_plugin( + plugin_name, plugin_field_collection, plugin_version + ) + anchor_generator_layer = network.add_plugin_v2([inputs], plugin) + out0 = anchor_generator_layer.get_output(0) + out1 = anchor_generator_layer.get_output(1) + return (out0, out1) + + @converter_registry.register("pd_op.affine_channel", trt_version="8.x") def affine_channel_converter(network, paddle_op, inputs): x, scale_weights, bias_weights = inputs diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index 0c88733296f262..8b201467137eec 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -437,7 +437,7 @@ def test_fp16_trt_result(self): self.check_trt_result(precision_mode="fp16") -class TestAffineChannelCas1TRTPattern(TensorRTBaseTest): +class TestAffineChannelCase1TRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = affine_channel self.api_args = { @@ -458,5 +458,57 @@ def test_fp16_trt_result(self): self.check_trt_result(precision_mode="fp16") +def anchor_generator(x, anchor_sizes, aspect_ratios, variances, stride, offset): + return _C_ops.anchor_generator( + x, anchor_sizes, aspect_ratios, variances, stride, offset + ) + + +class TestAnchorGeneratorTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = anchor_generator + self.api_args = { + "x": np.random.random((2, 3, 3, 100)).astype("float32"), + "anchor_sizes": [64.0, 128.0, 256.0], + "aspect_ratios": [0.5, 
1, 2], + "variances": [1.0, 1.0, 1.0, 1.0], + "stride": [16.0, 16.0], + "offset": 0.5, + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [1, 3, 3, 100]} + self.opt_shape = {"x": [2, 3, 3, 100]} + self.max_shape = {"x": [3, 3, 3, 100]} + + def test_fp32_trt_result(self): + self.check_trt_result() + + def test_fp16_trt_result(self): + self.check_trt_result(precision_mode="fp16") + + +class TestAnchorGeneratorCase1TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = anchor_generator + self.api_args = { + "x": np.random.random((2, 3, 64, 64)).astype("float32"), + "anchor_sizes": [64.0, 128.0, 256.0], + "aspect_ratios": [0.4, 1.2, 3], + "variances": [0.5, 1.0, 0.5, 1.0], + "stride": [16.0, 32.0], + "offset": 0.8, + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 3, 64, 64]} + self.opt_shape = {"x": [2, 3, 64, 64]} + self.max_shape = {"x": [3, 3, 64, 64]} + + def test_fp32_trt_result(self): + self.check_trt_result() + + def test_fp16_trt_result(self): + self.check_trt_result(precision_mode="fp16") + + if __name__ == '__main__': unittest.main()