 // limitations under the License.

 #pragma once
+#include <absl/types/variant.h>
 #include <memory>
 #include <unordered_map>
 #include "paddle/cinn/common/context.h"
@@ -30,9 +31,15 @@ namespace cinn {
 namespace hlir {
 namespace framework {

-// TODO(Aurelius): Need to add name mapping logic in REGISTER_CINN_OP
-// macros, or attempt to unify op names between Paddle and CINN.
-static const std::unordered_map<std::string, std::string> OP_NAMES = {
+struct CompatibleInfo {
+  static constexpr char* kInputPrefix = "input_";
+  static constexpr char* kOutputPrefix = "output_";
+  // TODO(Aurelius): Need to add name mapping logic in REGISTER_CINN_OP
+  // macros, or attempt to unify op names between Paddle and CINN.
+  static const std::unordered_map<std::string, std::string> OP_NAMES;
+};
+
+const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
     {"pd.full", "fill_constant"}, {"pd.matmul", "matmul"}};

 // TODO(Aurelius84): Need to abstract this logic to implement Proxy for
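For reference, a standalone sketch (not part of the patch) of how the new `CompatibleInfo::OP_NAMES` table is meant to be used: given a new-IR op name such as `"pd.full"`, the lookup yields the registered CINN operator name. The map below simply mirrors the two entries in the diff.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Mirrors CompatibleInfo::OP_NAMES from the diff above.
  const std::unordered_map<std::string, std::string> op_names = {
      {"pd.full", "fill_constant"}, {"pd.matmul", "matmul"}};
  // A Paddle new-IR op name resolves to the CINN operator it lowers to.
  std::cout << op_names.at("pd.full") << std::endl;    // fill_constant
  std::cout << op_names.at("pd.matmul") << std::endl;  // matmul
  return 0;
}
```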
@@ -70,18 +77,32 @@ class NewIRCompiler final {
     compiler_->Build(build_module, "");

     auto instructions = BuildInstructions(groups);
+
+    // TODO(Aurelius84): Instantiate all tensors at compile time, which is
+    // controlled by 'options.with_instantiate_variables' in GraphCompiler.
+    // Moreover, it's better to implement InsertBufferHandlers() logic
+    // to automatically insert Malloc and Free instructions.
+    for (auto& name : scope_->var_names()) {
+      std::string var_name({name.data(), name.size()});
+      VLOG(4) << "Instantiate " << var_name << " on compile-time";
+      auto* var = scope_->Var<Tensor>(var_name);
+      auto& tensor = absl::get<Tensor>(*var);
+      tensor->mutable_data(target_, tensor->type());
+    }

     return std::make_unique<Program>(scope_, std::move(instructions));
   }

   std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx) {
     std::vector<ir::Tensor> inputs;
     std::vector<common::CINNValue> cinn_inputs;
-    VLOG(4) << "GetOpFunc for op: " << op.name();
+    auto op_name = op.name();
+    VLOG(4) << "GetOpFunc for op: " << op_name;
     // step 1: Deal with Operands
     for (int i = 0; i < op.num_operands(); ++i) {
       auto in_value = op.operand(i);
       // TODO(Aurelius84): For now, use addr as the name, but it's not wise.
-      std::string input_id = std::to_string(std::hash<::ir::Value>()(in_value));
+      std::string input_id = CompatibleInfo::kInputPrefix +
+                             std::to_string(std::hash<::ir::Value>()(in_value));
       // NOTE(Aurelius84): do we need to support other Types?
       auto type_info =
           in_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
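The prefixed naming scheme introduced here recurs throughout the patch, so a minimal standalone sketch may help. An opaque pointer stands in for `::ir::Value`, and `MakeName` is a hypothetical helper invented for illustration; the real code inlines the expression with `std::hash<::ir::Value>`.

```cpp
#include <functional>
#include <iostream>
#include <string>

// Hypothetical helper for illustration only; the patch inlines this
// expression and hashes the ::ir::Value itself, not a raw pointer.
std::string MakeName(const std::string& prefix, const void* value) {
  return prefix + std::to_string(std::hash<const void*>()(value));
}

int main() {
  int dummy = 0;  // stands in for an ::ir::Value
  // The same value now yields distinct names on the input and output sides,
  // instead of a bare, prefix-less hash string.
  std::cout << MakeName("input_", &dummy) << "\n"
            << MakeName("output_", &dummy) << std::endl;
  return 0;
}
```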
@@ -100,8 +121,7 @@ class NewIRCompiler final {
       cinn_inputs.push_back(common::CINNValue(temp));
     }
     for (auto out_name : OpGetOutputNames(op)) {
-      cinn_inputs.push_back(
-          common::CINNValue(op.name().substr(3) + "_" + out_name));
+      cinn_inputs.push_back(common::CINNValue(out_name));
     }

     VLOG(4) << "inputs.size(): " << inputs.size();
@@ -124,14 +144,14 @@ class NewIRCompiler final {
     {
       VLOG(4) << "op.attributes():" << op.attributes().size();
       auto attrs = utils::ConvertAttributes(op.attributes());
-      node_attrs.node_name = OP_NAMES.at(op.name());
+      node_attrs.node_name = CompatibleInfo::OP_NAMES.at(op_name);
       node_attrs.attr_store = std::move(attrs);
     }
     auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
     // NOTE(Aurelius84): Do we need to replace all hlir::framework Operators
     // with ::ir::Program?
     const hlir::framework::Operator* cinn_op =
-        Operator::Get(OP_NAMES.at(op.name()));
+        Operator::Get(CompatibleInfo::OP_NAMES.at(op_name));
     auto impl = OpStrategy::SelectImpl(
         strategy[cinn_op](node_attrs, inputs, out_types, out_shapes, target_));
     common::CINNValuePack C =
@@ -223,7 +243,8 @@ class NewIRCompiler final {
     std::unordered_set<std::string> repeat;
     for (int i = 0; i < op.num_operands(); ++i) {
       auto value = op.operand(i);
-      std::string name = std::to_string(std::hash<::ir::Value>()(value));
+      std::string name = CompatibleInfo::kInputPrefix +
+                         std::to_string(std::hash<::ir::Value>()(value));
       if (repeat.count(name)) {
         continue;
       }
@@ -237,7 +258,8 @@ class NewIRCompiler final {
     std::vector<std::string> names;
     for (int i = 0; i < op.num_results(); ++i) {
       auto value = op.result(i);
-      std::string name = std::to_string(std::hash<::ir::Value>()(value));
+      std::string name = CompatibleInfo::kOutputPrefix +
+                         std::to_string(std::hash<::ir::Value>()(value));
       names.push_back(std::move(name));
     }
     return names;
@@ -257,11 +279,12 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
   std::unordered_set<::ir::Value> visited;
   auto scope = std::make_shared<Scope>();

-  auto create_var = [&](::ir::Value value) {
+  auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
     if (visited.count(value) > 0) return;
     visited.emplace(value);

-    std::string name = std::to_string(std::hash<::ir::Value>()(value));
+    std::string name =
+        name_prefix + std::to_string(std::hash<::ir::Value>()(value));
     auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
     auto* var = scope->Var<Tensor>(name);
     auto& tensor = absl::get<Tensor>(*var);
@@ -279,12 +302,12 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
     // visit OpOperands
     for (auto i = 0; i < (*it)->num_operands(); ++i) {
       auto in_value = (*it)->operand(i);
-      create_var(in_value);
+      create_var(CompatibleInfo::kInputPrefix, in_value);
    }

     for (auto i = 0; i < (*it)->num_results(); ++i) {
       auto out_value = (*it)->result(i);
-      create_var(out_value);
+      create_var(CompatibleInfo::kOutputPrefix, out_value);
     }
   }
   return scope;
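Finally, a minimal sketch of the `create_var` pattern in `BuildScope`, with plain integers standing in for `::ir::Value` and a string set standing in for `Scope` (both are assumptions for illustration). Note that `visited` is keyed on the value alone, not on (prefix, value), so a value seen as both an operand and a result is instantiated once, under whichever prefix it is first visited with.

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>

int main() {
  std::unordered_set<int> visited;        // stands in for visited ::ir::Values
  std::unordered_set<std::string> scope;  // stands in for Scope variables

  auto create_var = [&](const std::string& name_prefix, int value) {
    if (visited.count(value) > 0) return;  // each value is created only once
    visited.emplace(value);
    scope.emplace(name_prefix + std::to_string(std::hash<int>()(value)));
  };

  create_var("input_", 42);
  create_var("input_", 42);   // skipped: already visited
  create_var("output_", 42);  // also skipped: visited is keyed on value alone
  create_var("output_", 7);
  std::cout << scope.size() << std::endl;  // 2
  return 0;
}
```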