Skip to content

Commit

Permalink
支持动态shape (Support dynamic shape)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhink committed Jan 4, 2024
1 parent 58ca933 commit 1c31c9a
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 114 deletions.
197 changes: 154 additions & 43 deletions paddle/fluid/inference/tensorrt/convert/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@ limitations under the License. */
if (op_desc.HasAttr(#attr_name__)) { \
vec_##attr_name__ = PADDLE_GET_CONST(std::vector<int64_t>, \
op_desc.GetAttr(#attr_name__)); \
if (!vec_##attr_name__.empty()) attr_name__ = vec_##attr_name__[0]; \
if (vec_##attr_name__.size() > 0) { \
attr_name__ = vec_##attr_name__[0]; \
PADDLE_ENFORCE_EQ(vec_##attr_name__.size(), \
1UL, \
platform::errors::InvalidArgument( \
"attr axes/starst/ends/steps 's size in " \
"set_value must be one, but got %d", \
vec_##attr_name__.size())); \
} \
} \
} while (0)

Expand All @@ -44,83 +52,186 @@ class SetValueConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr);

auto* inputs = engine_->GetITensor(op_desc.Input("Input")[0]);
auto* updates = engine_->GetITensor(op_desc.Input("ValueTensor")[0]);
auto output_name = op_desc.Output("Out")[0];
nvinfer1::ITensor* updates;
if (op_desc.Input("ValueTensor").size() > 0) {
updates = engine_->GetITensor(op_desc.Input("ValueTensor")[0]);
} else {
PADDLE_ENFORCE_EQ(PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")), 5);
float value = PADDLE_GET_CONST(std::vector<paddle::experimental::Scalar>,
op_desc.GetAttr("values"))[0]
.to<int>();
VLOG(3) << "the attribute value is: " << value;
nvinfer1::Dims tmp_dim;
tmp_dim.nbDims = inputs->getDimensions().nbDims;
for (int i = 0; i < tmp_dim.nbDims; i++) tmp_dim.d[i] = 1;
updates = AddConstantLayer(&value, tmp_dim);
}

// for log
{
nvinfer1::Dims tmp_dims = inputs->getDimensions();
std::vector<int> tmp_vec;
for (int i = 0; i < tmp_dims.nbDims; i++)
tmp_vec.push_back(tmp_dims.d[i]);
VLOG(3) << "Input(Name:" << op_desc.Input("Input")[0] << ")"
<< "'s dimension is :[" << string::join_strings(tmp_vec, ',')
<< "]";

tmp_vec.clear();
tmp_dims = updates->getDimensions();
for (int i = 0; i < tmp_dims.nbDims; i++)
tmp_vec.push_back(tmp_dims.d[i]);
VLOG(3) << "updates tensor"
<< "'s dimension is :[" << string::join_strings(tmp_vec, ',')
<< "]";
}

const auto decrease_axes = PADDLE_GET_CONST(
std::vector<int64_t>, op_desc.GetAttr("decrease_axes"));
std::vector<int32_t> decr_axes{decrease_axes.begin(), decrease_axes.end()};
auto value_rank = updates->getDimensions().nbDims;
auto input_rank = inputs->getDimensions().nbDims;
if (!decrease_axes.empty() && value_rank != input_rank) {
// GLOG_vmodule=op_teller=6
VLOG(3) << "decrease_axes is: [" << string::join_strings(decrease_axes, ',')
<< "]";

if (decrease_axes.size() > 0 && value_rank != input_rank) {
updates = Unsqueeze(updates, decr_axes);
}

PADDLE_ENFORCE_EQ(
updates->getDimensions().nbDims,
input_rank,
platform::errors::InvalidArgument(
"ValueTensor‘s rank not equal to Input's rank, "
"you should try use C++ API "
"config.exp_disable_tensorrt_ops({\"%s\"}) to forbind this op "
"enter into TRT, "
"please find the %s's real name from .pdmodel or shape.txt",
output_name,
output_name));

// for log
{
auto tmp_dims = updates->getDimensions();
std::vector<int> tmp_vec;
tmp_vec.clear();
tmp_dims = updates->getDimensions();
for (int i = 0; i < tmp_dims.nbDims; i++)
tmp_vec.push_back(tmp_dims.d[i]);
VLOG(3) << "updates tensor"
<< "'s dimension is :[" << string::join_strings(tmp_vec, ',')
<< "]";
}

int64_t axes = 0;
int64_t starts = 0;
int64_t steps = 1;
int64_t ends = 0;

GET_ATTR_FROM_VECTOR(axes);
GET_ATTR_FROM_VECTOR(starts);
GET_ATTR_FROM_VECTOR(steps);
GET_ATTR_FROM_VECTOR(ends);

VLOG(3) << "axes is: " << axes;
VLOG(3) << "starts is: " << starts;
VLOG(3) << "steps is: " << steps;
VLOG(3) << "ends is: " << ends;

// calculate dims
auto input_dims = inputs->getDimensions();
auto update_dims = updates->getDimensions();

PADDLE_ENFORCE_GT(
input_dims.d[axes],
0,
platform::errors::InvalidArgument(
"the input_dims.d[%d] must be greater than 0, but received %d",
axes,
input_dims.d[axes]));

PADDLE_ENFORCE_GT(
update_dims.d[axes],
0,
platform::errors::InvalidArgument(
"the update_dims.d[%d] must be greater than 0, but received %d",
axes,
update_dims.d[axes]));

// check params and refill
if (axes == -1) {
axes = input_dims.nbDims - 1;
if (axes < 0) {
axes += input_dims.nbDims;
}

if (ends == -1 || ends > input_dims.d[axes]) {
if (ends < 0) {
ends += input_dims.d[axes];
}
if (ends >= input_dims.d[axes]) {
ends = input_dims.d[axes];
}

if (axes >= input_dims.nbDims) {
platform::errors::InvalidArgument(
"The axes %d is larger than total axes %d", axes, input_dims.nbDims);
}
if (starts >= input_dims.d[axes]) {
platform::errors::InvalidArgument(
"The start %d of dim %d is larger than origin shape %d",
starts,
axes,
input_dims.d[axes]);
}
if (update_dims.d[axes] != (input_dims.d[axes] - starts) / steps) {
platform::errors::InvalidArgument("The update dim error, should be %d",
(input_dims.d[axes] - starts) / steps);
}
VLOG(3) << "after standardization" << axes;
VLOG(3) << "axes is: " << axes;
VLOG(3) << "starts is: " << starts;
VLOG(3) << "steps is: " << steps;
VLOG(3) << "ends is: " << ends;

PADDLE_ENFORCE_LE(axes,
input_dims.nbDims,
platform::errors::InvalidArgument(
"The axes %d is larger than total axes %d",
axes,
input_dims.nbDims));

PADDLE_ENFORCE_LE(
starts,
input_dims.d[axes],
platform::errors::InvalidArgument(
"The start %d of dim %d is larger than origin shape %d",
starts,
axes,
input_dims.d[axes]));

PADDLE_ENFORCE_EQ(
update_dims.d[axes],
(ends - 1 - starts) / steps + 1,
platform::errors::InvalidArgument(
"the %dth axis of update dim error, should be %d, but we got %d",
axes,
(ends - 1 - starts) / steps + 1,
update_dims.d[axes]));

if (engine_->with_dynamic_shape()) {
// generate indice
int post_size = 1;
for (int j = axes + 1; j < update_dims.nbDims; ++j) {
post_size = post_size * update_dims.d[j];
}
std::vector<int> axes_index;
for (int i = starts; i < ends; i += steps) {
for (int j = 0; j < post_size; ++j) {
axes_index.emplace_back(i);
}
nvinfer1::Dims shape_0;
shape_0.nbDims = update_dims.nbDims;
for (int i = 0; i < shape_0.nbDims; ++i) {
shape_0.d[i] = 1;
}
int pre_size = 1;
for (int i = 0; i < axes; ++i) {
pre_size *= update_dims.d[i];
std::vector<float> tmp_0(1, 0);
auto zero_tensor = AddConstantLayer(tmp_0.data(), shape_0);
auto indice_tensor = Prod(zero_tensor, updates);
auto cast_layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *indice_tensor);
cast_layer->setOutputType(0, nvinfer1::DataType::kINT32);
indice_tensor = cast_layer->getOutput(0);

nvinfer1::Dims shape_1;
shape_1.nbDims = update_dims.nbDims;
for (int i = 0; i < update_dims.nbDims; ++i) {
shape_1.d[i] = 1;
}
std::vector<int> indices;
for (int i = 0; i < pre_size; ++i) {
indices.insert(indices.end(), axes_index.begin(), axes_index.end());
shape_1.d[axes] = update_dims.d[axes];
std::vector<int> tmp_1;
for (int i = starts; i < ends; i += steps) {
tmp_1.push_back(i);
}

auto output_name = op_desc.Output("Out")[0];
const auto const_layer = AddConstantLayer(
indices.data(), update_dims, "set_value_index_" + output_name);
auto one_tensor = AddConstantLayer(tmp_1.data(), shape_1);
indice_tensor = Sum(indice_tensor, one_tensor);

auto* layer = TRT_ENGINE_ADD_LAYER(engine_,
Scatter,
*inputs,
*const_layer,
*indice_tensor,
*updates,
nvinfer1::ScatterMode::kELEMENT);

Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2470,6 +2470,16 @@ struct SimpleOpTypeSetTeller : public Teller {
"starts or steps)";
return false;
}
if (desc.HasAttr("axes")) {
auto axes =
PADDLE_GET_CONST(std::vector<int64_t>, desc.GetAttr("axes"));
if (axes.size() != 1UL) {
VLOG(3) << "the set_value op"
<< "has more than one element in attribute axes, it can not "
"enter into trt.";
return false;
}
}
}

if (op_type == "top_k_v2" || op_type == "top_k") {
Expand Down
Loading

0 comments on commit 1c31c9a

Please sign in to comment.