@@ -20,7 +20,15 @@ limitations under the License. */
     if (op_desc.HasAttr(#attr_name__)) {                                   \
       vec_##attr_name__ = PADDLE_GET_CONST(std::vector<int64_t>,           \
                                            op_desc.GetAttr(#attr_name__)); \
-      if (!vec_##attr_name__.empty()) attr_name__ = vec_##attr_name__[0];  \
+      if (vec_##attr_name__.size() > 0) {                                  \
+        attr_name__ = vec_##attr_name__[0];                                \
+        PADDLE_ENFORCE_EQ(vec_##attr_name__.size(),                        \
+                          1UL,                                             \
+                          platform::errors::InvalidArgument(               \
+                              "attr axes/starts/ends/steps's size in "     \
+                              "set_value must be one, but got %d",         \
+                              vec_##attr_name__.size()));                  \
+      }                                                                    \
     }                                                                      \
   } while (0)
 
@@ -42,85 +50,200 @@ class SetValueConverter : public OpConverter {
                   bool test_mode) override {
     VLOG(3) << "convert a set value op to tensorrt";
     framework::OpDesc op_desc(op, nullptr);
-
-    auto* inputs = engine_->GetITensor(op_desc.Input("Input")[0]);
-    auto* updates = engine_->GetITensor(op_desc.Input("ValueTensor")[0]);
-    const auto decrease_axes = PADDLE_GET_CONST(
-        std::vector<int64_t>, op_desc.GetAttr("decrease_axes"));
-    std::vector<int32_t> decr_axes{decrease_axes.begin(), decrease_axes.end()};
-    auto value_rank = updates->getDimensions().nbDims;
-    auto input_rank = inputs->getDimensions().nbDims;
-    if (!decrease_axes.empty() && value_rank != input_rank) {
-      updates = Unsqueeze(updates, decr_axes);
-    }
-
     int64_t axes = 0;
     int64_t starts = 0;
     int64_t steps = 1;
     int64_t ends = 0;
-
     GET_ATTR_FROM_VECTOR(axes);
     GET_ATTR_FROM_VECTOR(starts);
     GET_ATTR_FROM_VECTOR(steps);
     GET_ATTR_FROM_VECTOR(ends);
 
-    // calculate dims
+    VLOG(3) << "axes is: " << axes;
+    VLOG(3) << "starts is: " << starts;
+    VLOG(3) << "steps is: " << steps;
+    VLOG(3) << "ends is: " << ends;
+
+    auto* inputs = engine_->GetITensor(op_desc.Input("Input")[0]);
+
     auto input_dims = inputs->getDimensions();
-    auto update_dims = updates->getDimensions();
 
     // check params and refill
-    if (axes == -1) {
-      axes = input_dims.nbDims - 1;
+    if (axes < 0) {
+      axes += input_dims.nbDims;
     }
 
-    if (ends == -1 || ends > input_dims.d[axes]) {
+    if (ends < 0) {
+      ends += input_dims.d[axes];
+    }
+    if (ends >= input_dims.d[axes]) {
       ends = input_dims.d[axes];
     }
 
-    if (axes >= input_dims.nbDims) {
-      platform::errors::InvalidArgument(
-          "The axes %d is larger than total axes %d", axes, input_dims.nbDims);
+    VLOG(3) << "after standardization" << axes;
+    VLOG(3) << "axes is: " << axes;
+    VLOG(3) << "starts is: " << starts;
+    VLOG(3) << "steps is: " << steps;
+    VLOG(3) << "ends is: " << ends;
+
+    auto output_name = op_desc.Output("Out")[0];
+    nvinfer1::ITensor* updates;
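+    // If a ValueTensor input is wired up, it supplies the update values
+    // directly; otherwise the scalar "values" attribute is materialized
+    // below as a constant tensor shaped like the target slice.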
+    if (op_desc.HasInput("ValueTensor") &&
+        op_desc.Input("ValueTensor").size() > 0) {
+      updates = engine_->GetITensor(op_desc.Input("ValueTensor")[0]);
+    } else {
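+      // dtype 5 is proto::VarType::FP32 in Paddle's framework.proto, so
+      // only float32 fill values are accepted here.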
+      int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
+      PADDLE_ENFORCE_EQ(dtype,
+                        5,
+                        platform::errors::InvalidArgument(
+                            "set_value OP dtype must be float"));
+      float value = PADDLE_GET_CONST(std::vector<paddle::experimental::Scalar>,
+                                     op_desc.GetAttr("values"))[0]
+                        .to<float>();
+      VLOG(3) << "the attribute value is: " << value;
+
+      nvinfer1::ITensor* input_shape_tensor = Shape(inputs);
+      std::vector<nvinfer1::ITensor*> vec_tensor;
+      for (int32_t i = 0; i < input_dims.nbDims; ++i) {
+        vec_tensor.push_back(GetEleTensorOfShape(input_shape_tensor, i));
+      }
+      std::vector<int32_t> axes_vec(1, (ends - 1 - starts) / steps + 1);
+      vec_tensor[axes] = Add1DConstantLayer(axes_vec, "axes_vec", false);
+      nvinfer1::ITensor* output_shape_tensor = Concat(vec_tensor, 0);
+      updates = FillConstantLayer(
+          output_shape_tensor, inputs->getDimensions().nbDims, value);
+    }
+
+    // for log
+    {
+      std::vector<int> tmp_vec;
+      for (int i = 0; i < input_dims.nbDims; i++)
+        tmp_vec.push_back(input_dims.d[i]);
+      VLOG(3) << "Input(Name:" << op_desc.Input("Input")[0] << ")"
+              << "'s dimension is :[" << string::join_strings(tmp_vec, ',')
+              << "]";
+
+      tmp_vec.clear();
+      nvinfer1::Dims tmp_dims = updates->getDimensions();
+      for (int i = 0; i < tmp_dims.nbDims; i++)
+        tmp_vec.push_back(tmp_dims.d[i]);
+      VLOG(3) << "updates tensor"
+              << "'s dimension is :[" << string::join_strings(tmp_vec, ',')
+              << "]";
     }
-    if (starts >= input_dims.d[axes]) {
-      platform::errors::InvalidArgument(
-          "The start %d of dim %d is larger than origin shape %d",
-          starts,
-          axes,
-          input_dims.d[axes]);
+
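+    // decrease_axes lists the dimensions dropped by basic indexing on the
+    // Python side (e.g. x[:, 2, :] = v drops axis 1); Unsqueeze re-inserts
+    // them so the updates tensor matches the input rank before scattering.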
+    const auto decrease_axes = PADDLE_GET_CONST(
+        std::vector<int64_t>, op_desc.GetAttr("decrease_axes"));
+    std::vector<int32_t> decr_axes{decrease_axes.begin(), decrease_axes.end()};
+    auto value_rank = updates->getDimensions().nbDims;
+    auto input_rank = inputs->getDimensions().nbDims;
+    // GLOG_vmodule=op_teller=6
+    VLOG(3) << "decrease_axes is: [" << string::join_strings(decrease_axes, ',')
+            << "]";
+
+    if (decrease_axes.size() > 0 && value_rank != input_rank) {
+      updates = Unsqueeze(updates, decr_axes);
     }
-    if (update_dims.d[axes] != (input_dims.d[axes] - starts) / steps) {
-      platform::errors::InvalidArgument("The update dim error, should be %d",
-                                        (input_dims.d[axes] - starts) / steps);
+
+    PADDLE_ENFORCE_EQ(
+        updates->getDimensions().nbDims,
+        input_rank,
+        platform::errors::InvalidArgument(
+            "ValueTensor's rank is not equal to Input's rank; "
+            "you should try the C++ API "
+            "config.exp_disable_tensorrt_ops({\"%s\"}) to forbid this op "
+            "from entering TRT; "
+            "please find %s's real name from .pdmodel or shape.txt",
+            output_name,
+            output_name));
+
+    // for log
+    {
+      auto tmp_dims = updates->getDimensions();
+      std::vector<int> tmp_vec;
+      tmp_vec.clear();
+      tmp_dims = updates->getDimensions();
+      for (int i = 0; i < tmp_dims.nbDims; i++)
+        tmp_vec.push_back(tmp_dims.d[i]);
+      VLOG(3) << "updates tensor"
+              << "'s dimension is :[" << string::join_strings(tmp_vec, ',')
+              << "]";
     }
+
+    // calculate dims
+    auto update_dims = updates->getDimensions();
+
+    PADDLE_ENFORCE_GT(
+        input_dims.d[axes],
+        0,
+        platform::errors::InvalidArgument(
+            "the input_dims.d[%d] must be greater than 0, but received %d",
+            axes,
+            input_dims.d[axes]));
+
+    PADDLE_ENFORCE_GT(
+        update_dims.d[axes],
+        0,
+        platform::errors::InvalidArgument(
+            "the update_dims.d[%d] must be greater than 0, but received %d",
+            axes,
+            update_dims.d[axes]));
+
+    PADDLE_ENFORCE_LE(axes,
+                      input_dims.nbDims,
+                      platform::errors::InvalidArgument(
+                          "The axes %d is larger than total axes %d",
+                          axes,
+                          input_dims.nbDims));
+
+    PADDLE_ENFORCE_LE(
+        starts,
+        input_dims.d[axes],
+        platform::errors::InvalidArgument(
+            "The start %d of dim %d is larger than origin shape %d",
+            starts,
+            axes,
+            input_dims.d[axes]));
+
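+    // For positive steps, the slice [starts, ends) with stride `steps`
+    // covers (ends - 1 - starts) / steps + 1 elements along `axes`.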
+    PADDLE_ENFORCE_EQ(
+        update_dims.d[axes],
+        (ends - 1 - starts) / steps + 1,
+        platform::errors::InvalidArgument(
+            "the %dth axis of the update dims is wrong; it should be %d "
+            "but got %d",
+            axes,
+            (ends - 1 - starts) / steps + 1,
+            update_dims.d[axes]));
+
     if (engine_->with_dynamic_shape()) {
-      // generate indice
-      int post_size = 1;
-      for (int j = axes + 1; j < update_dims.nbDims; ++j) {
-        post_size = post_size * update_dims.d[j];
-      }
-      std::vector<int> axes_index;
-      for (int i = starts; i < ends; i += steps) {
-        for (int j = 0; j < post_size; ++j) {
-          axes_index.emplace_back(i);
-        }
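+      // Build element-wise scatter indices: multiply a broadcast zero
+      // constant by `updates` to get a zero tensor with the update shape,
+      // cast it to int32, then add a constant holding
+      // {starts, starts + steps, ...} laid out along `axes`, so broadcasting
+      // assigns every update element its destination index on that axis.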
+      nvinfer1::Dims shape_0;
+      shape_0.nbDims = update_dims.nbDims;
+      for (int i = 0; i < shape_0.nbDims; ++i) {
+        shape_0.d[i] = 1;
       }
-      int pre_size = 1;
-      for (int i = 0; i < axes; ++i) {
-        pre_size *= update_dims.d[i];
+      std::vector<float> tmp_0(1, 0);
+      auto zero_tensor = AddConstantLayer(tmp_0.data(), shape_0);
+      auto indice_tensor = Prod(zero_tensor, updates);
+      auto cast_layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *indice_tensor);
+      cast_layer->setOutputType(0, nvinfer1::DataType::kINT32);
+      indice_tensor = cast_layer->getOutput(0);
+
+      nvinfer1::Dims shape_1;
+      shape_1.nbDims = update_dims.nbDims;
+      for (int i = 0; i < update_dims.nbDims; ++i) {
+        shape_1.d[i] = 1;
       }
-      std::vector<int> indices;
-      for (int i = 0; i < pre_size; ++i) {
-        indices.insert(indices.end(), axes_index.begin(), axes_index.end());
+      shape_1.d[axes] = update_dims.d[axes];
+      std::vector<int> tmp_1;
+      for (int i = starts; i < ends; i += steps) {
+        tmp_1.push_back(i);
       }
-
-      auto output_name = op_desc.Output("Out")[0];
-      const auto const_layer = AddConstantLayer(
-          indices.data(), update_dims, "set_value_index_" + output_name);
+      auto one_tensor = AddConstantLayer(tmp_1.data(), shape_1);
+      indice_tensor = Sum(indice_tensor, one_tensor);
 
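+      // ScatterMode::kELEMENT behaves like ONNX ScatterElements: each
+      // element of `updates` is written into `inputs` at the position its
+      // matching entry in `indice_tensor` names along the scatter axis.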
       auto* layer = TRT_ENGINE_ADD_LAYER(engine_,
                                          Scatter,
                                          *inputs,
-                                         *const_layer,
+                                         *indice_tensor,
                                          *updates,
                                          nvinfer1::ScatterMode::kELEMENT);
0 commit comments