Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle TensorRT No.8] pd_op.anchor_generator #70667

Merged
merged 6 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 285 additions & 0 deletions paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,291 @@ nvinfer1::IPluginV2Ext* AnchorGeneratorPluginDynamicCreator::deserializePlugin(
}
#endif

PIRAnchorGeneratorPluginDynamic::PIRAnchorGeneratorPluginDynamic(
const nvinfer1::DataType data_type,
const std::vector<float>& anchor_sizes,
const std::vector<float>& aspect_ratios,
const std::vector<float>& stride,
const std::vector<float>& variances,
const float offset,
const int num_anchors)
: data_type_(data_type),
anchor_sizes_(anchor_sizes),
aspect_ratios_(aspect_ratios),
stride_(stride),
variances_(variances),
offset_(offset),
num_anchors_(num_anchors) {
// data_type_ is used to determine the output data type
// data_type_ can only be float32
// height, width, num_anchors are calculated at configurePlugin
PADDLE_ENFORCE_EQ(data_type_,
nvinfer1::DataType::kFLOAT,
common::errors::InvalidArgument(
"TRT anchor generator plugin only accepts float32."));
PADDLE_ENFORCE_GE(
num_anchors_,
0,
common::errors::InvalidArgument(
"TRT anchor generator plugin only accepts number of anchors greater "
"than 0, but receive number of anchors = %d.",
num_anchors_));
PrepareParamsOnDevice();
}

PIRAnchorGeneratorPluginDynamic::~PIRAnchorGeneratorPluginDynamic() {
auto release_device_ptr = [](void* ptr) {
if (ptr) {
cudaFree(ptr);
ptr = nullptr;
}
};
release_device_ptr(anchor_sizes_device_);
release_device_ptr(aspect_ratios_device_);
release_device_ptr(stride_device_);
release_device_ptr(variances_device_);
}

PIRAnchorGeneratorPluginDynamic::PIRAnchorGeneratorPluginDynamic(
void const* data, size_t length) {
DeserializeValue(&data, &length, &data_type_);
DeserializeValue(&data, &length, &anchor_sizes_);
DeserializeValue(&data, &length, &aspect_ratios_);
DeserializeValue(&data, &length, &stride_);
DeserializeValue(&data, &length, &variances_);
DeserializeValue(&data, &length, &offset_);
DeserializeValue(&data, &length, &num_anchors_);
PrepareParamsOnDevice();
}

nvinfer1::IPluginV2DynamicExt* PIRAnchorGeneratorPluginDynamic::clone() const
TRT_NOEXCEPT {
auto plugin = new PIRAnchorGeneratorPluginDynamic(data_type_,
anchor_sizes_,
aspect_ratios_,
stride_,
variances_,
offset_,
num_anchors_);
plugin->setPluginNamespace(namespace_.c_str());
return plugin;
}

nvinfer1::DimsExprs PIRAnchorGeneratorPluginDynamic::getOutputDimensions(
int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret{};
ret.nbDims = 4;
ret.d[0] = inputs[0].d[2]; // feature height
ret.d[1] = inputs[0].d[3]; // feature width
ret.d[2] = exprBuilder.constant(num_anchors_);
ret.d[3] = exprBuilder.constant(4);
return ret;
}

bool PIRAnchorGeneratorPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
// input can be any, doesn't matter
// anchor generator doesn't read input raw data, only need the shape info
auto type = inOut[pos].type;
auto format = inOut[pos].format;
#if IS_TRT_VERSION_GE(7234)
if (pos == 0) return true;
#else
if (pos == 0) return format == nvinfer1::TensorFormat::kLINEAR;
#endif
return (type == nvinfer1::DataType::kFLOAT &&
format == nvinfer1::TensorFormat::kLINEAR);
}

void PIRAnchorGeneratorPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT {}

size_t PIRAnchorGeneratorPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0;
}

template <typename T>
int PIRAnchorGeneratorPluginDynamic::enqueue_impl(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) {
const int height = inputDesc[0].dims.d[2];
const int width = inputDesc[0].dims.d[3];
const int box_num = height * width * num_anchors_;
const int block = 512;
const int gen_anchor_grid = (box_num + block - 1) / block;
T* anchors = static_cast<T*>(outputs[0]);
T* vars = static_cast<T*>(outputs[1]);
const T* anchor_sizes_device = static_cast<const T*>(anchor_sizes_device_);
const T* aspect_ratios_device = static_cast<const T*>(aspect_ratios_device_);
const T* stride_device = static_cast<const T*>(stride_device_);
const T* variances_device = static_cast<const T*>(variances_device_);
phi::GenAnchors<T>
<<<gen_anchor_grid, block, 0, stream>>>(anchors,
aspect_ratios_device,
aspect_ratios_.size(),
anchor_sizes_device,
anchor_sizes_.size(),
stride_device,
stride_.size(),
height,
width,
offset_);
const int var_grid = (box_num * 4 + block - 1) / block;
phi::SetVariance<T><<<var_grid, block, 0, stream>>>(
vars, variances_device, variances_.size(), box_num * 4);
return cudaGetLastError() != cudaSuccess;
}

int PIRAnchorGeneratorPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
assert(outputDesc[0].type == nvinfer1::DataType::kFLOAT);
assert(outputDesc[1].type == nvinfer1::DataType::kFLOAT);
return enqueue_impl<float>(
inputDesc, outputDesc, inputs, outputs, workspace, stream);
}

nvinfer1::DataType PIRAnchorGeneratorPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}

const char* PIRAnchorGeneratorPluginDynamic::getPluginType() const
TRT_NOEXCEPT {
return "pir_anchor_generator_plugin_dynamic";
}

int PIRAnchorGeneratorPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
return 2;
}

int PIRAnchorGeneratorPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

void PIRAnchorGeneratorPluginDynamic::terminate() TRT_NOEXCEPT {}

size_t PIRAnchorGeneratorPluginDynamic::getSerializationSize() const
TRT_NOEXCEPT {
size_t serialize_size = 0;
serialize_size += SerializedSize(data_type_);
serialize_size += SerializedSize(anchor_sizes_);
serialize_size += SerializedSize(aspect_ratios_);
serialize_size += SerializedSize(stride_);
serialize_size += SerializedSize(variances_);
serialize_size += SerializedSize(offset_);
serialize_size += SerializedSize(num_anchors_);
return serialize_size;
}

void PIRAnchorGeneratorPluginDynamic::serialize(void* buffer) const
TRT_NOEXCEPT {
SerializeValue(&buffer, data_type_);
SerializeValue(&buffer, anchor_sizes_);
SerializeValue(&buffer, aspect_ratios_);
SerializeValue(&buffer, stride_);
SerializeValue(&buffer, variances_);
SerializeValue(&buffer, offset_);
SerializeValue(&buffer, num_anchors_);
}

void PIRAnchorGeneratorPluginDynamic::destroy() TRT_NOEXCEPT {}

void PIRAnchorGeneratorPluginDynamicCreator::setPluginNamespace(
const char* lib_namespace) TRT_NOEXCEPT {
namespace_ = std::string(lib_namespace);
}

const char* PIRAnchorGeneratorPluginDynamicCreator::getPluginNamespace() const
TRT_NOEXCEPT {
return namespace_.c_str();
}

const char* PIRAnchorGeneratorPluginDynamicCreator::getPluginName() const
TRT_NOEXCEPT {
return "pir_anchor_generator_plugin_dynamic";
}

const char* PIRAnchorGeneratorPluginDynamicCreator::getPluginVersion() const
TRT_NOEXCEPT {
return "1";
}

const nvinfer1::PluginFieldCollection*
PIRAnchorGeneratorPluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
return &field_collection_;
}

nvinfer1::IPluginV2Ext* PIRAnchorGeneratorPluginDynamicCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
const nvinfer1::PluginField* fields = fc->fields;
std::vector<float> anchor_sizes, aspect_ratios, stride, variances;
float offset = .5;
int num_anchors = -1;

for (int i = 0; i < fc->nbFields; ++i) {
const nvinfer1::PluginField& f = fc->fields[i];
const std::string field_name(f.name);
if (field_name.compare("anchor_sizes") == 0) {
const float* data = static_cast<const float*>(f.data);
anchor_sizes.assign(data, data + f.length);
} else if (field_name.compare("aspect_ratios") == 0) {
const float* data = static_cast<const float*>(f.data);
aspect_ratios.assign(data, data + f.length);
} else if (field_name.compare("stride") == 0) {
const float* data = static_cast<const float*>(f.data);
stride.assign(data, data + f.length);
} else if (field_name.compare("variances") == 0) {
const float* data = static_cast<const float*>(f.data);
variances.assign(data, data + f.length);
} else if (field_name.compare("offset") == 0) {
offset = *static_cast<const float*>(f.data);
} else if (field_name.compare("num_anchors") == 0) {
num_anchors = *static_cast<const int*>(f.data);
} else {
assert(false && "unknown plugin field name.");
}
}
return new PIRAnchorGeneratorPluginDynamic(nvinfer1::DataType::kFLOAT,
anchor_sizes,
aspect_ratios,
stride,
variances,
offset,
num_anchors);
}

nvinfer1::IPluginV2Ext*
PIRAnchorGeneratorPluginDynamicCreator::deserializePlugin(
const char* name,
const void* serial_data,
size_t serial_length) TRT_NOEXCEPT {
auto plugin = new PIRAnchorGeneratorPluginDynamic(serial_data, serial_length);
plugin->setPluginNamespace(namespace_.c_str());
return plugin;
}

} // namespace plugin
} // namespace tensorrt
} // namespace inference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,105 @@ class AnchorGeneratorPluginDynamicCreator : public nvinfer1::IPluginCreator {
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};

class PIRAnchorGeneratorPluginDynamic : public DynamicPluginTensorRT {
public:
explicit PIRAnchorGeneratorPluginDynamic(
const nvinfer1::DataType data_type,
const std::vector<float>& anchor_sizes,
const std::vector<float>& aspect_ratios,
const std::vector<float>& stride,
const std::vector<float>& variances,
const float offset,
const int num_anchors);
PIRAnchorGeneratorPluginDynamic(void const* data, size_t length);
~PIRAnchorGeneratorPluginDynamic();
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) // NOLINT
TRT_NOEXCEPT override;

bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;

void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;

size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
const char* getPluginType() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;

private:
template <typename T>
int enqueue_impl(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream);
nvinfer1::DataType data_type_;
std::vector<float> anchor_sizes_;
std::vector<float> aspect_ratios_;
std::vector<float> stride_;
std::vector<float> variances_;
float offset_;
void* anchor_sizes_device_;
void* aspect_ratios_device_;
void* stride_device_;
void* variances_device_;
int num_anchors_;
std::string namespace_;
};

class PIRAnchorGeneratorPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
PIRAnchorGeneratorPluginDynamicCreator() = default;
~PIRAnchorGeneratorPluginDynamicCreator() override = default;
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
const char* getPluginNamespace() const TRT_NOEXCEPT override;
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* createPlugin(
const char* name,
const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override;

private:
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};

REGISTER_TRT_PLUGIN_V2(AnchorGeneratorPluginDynamicCreator);
REGISTER_TRT_PLUGIN_V2(PIRAnchorGeneratorPluginDynamicCreator);
#endif

} // namespace plugin
Expand Down
Loading
Loading