Merge branch 'develop' into apply_tensor
yangguohao authored Dec 22, 2023
2 parents e5b8f61 + 0942bcf commit 0c91d95
Showing 314 changed files with 11,583 additions and 4,954 deletions.
8 changes: 8 additions & 0 deletions paddle/cinn/hlir/framework/op_strategy.h
@@ -37,6 +37,14 @@ using StrategyFunction = std::function<std::shared_ptr<OpStrategy>(
const std::vector<Type>&,
const std::vector<std::vector<int>>&,
const cinn::common::Target&)>;

using StrategyFunctionSymbolic = std::function<std::shared_ptr<OpStrategy>(
const NodeAttr&,
const std::vector<ir::Tensor>&,
const std::vector<Type>&,
const std::vector<std::vector<ir::Dim>>&,
const cinn::common::Target&)>;

using InferShapeFunction = std::function<std::vector<std::vector<int>>(
const std::vector<std::vector<int>>&, const AttrMapType&)>;

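For reference, `StrategyFunctionSymbolic` differs from `StrategyFunction` only in its shape parameter: `std::vector<std::vector<ir::Dim>>` (symbolic dimensions) replaces `std::vector<std::vector<int>>` (static extents). A minimal sketch of a lookup-and-call site, mirroring the `LowerOps` change further below; the op name here is illustrative:

```cpp
// Hedged sketch, not part of the diff; mirrors the LowerOps usage below.
// "elementwise_add" is an illustrative op name.
const auto& strategy =
    Operator::GetAttrs<StrategyFunctionSymbolic>("CINNStrategySymbolic");
const Operator* cinn_op = Operator::Get("elementwise_add");
std::shared_ptr<OpStrategy> op_strategy = strategy[cinn_op](
    node_attrs,     // NodeAttr collected from the op
    input_tensors,  // std::vector<ir::Tensor>
    out_types,      // std::vector<Type>
    out_shapes,     // std::vector<std::vector<ir::Dim>>, symbolic
    target);        // cinn::common::Target
```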
73 changes: 60 additions & 13 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -221,7 +221,7 @@ void OpLowererImpl::LowerOpsForMapExpr(
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;

CollectOutputInfo(op, &out_types, &out_shapes);
CollectOutputInfo(op, &out_types, &out_shapes, group);
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs = details::CollectAttrs(*op);

@@ -364,7 +364,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerCustomCall(

std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
CollectOutputInfo(op, &out_types, &out_shapes);
CollectOutputInfo(op, &out_types, &out_shapes, group);
VLOG(4) << "out_types.size(): " << out_types.size();

NodeAttr node_attrs = details::CollectAttrs(*op);
@@ -484,6 +484,8 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
ir::_Var_::Make(symbol_name, cinn::common::Int(32)));
group->int_args_map[non_tensor_arg_idx] = {tensor_arg_idx,
                                           tensor_arg_dim_idx};
VLOG(4) << "device kernel func's int arg " << non_tensor_arg_idx
        << " is from tensor arg " << tensor_arg_idx << ".shape("
        << tensor_arg_dim_idx << ")";
++non_tensor_arg_idx;
}
}
}
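Each `int_args_map` entry ties one scalar kernel argument to the tensor argument and dimension that supply its value at run time. A hedged sketch of the launch-side consumption; the container and member names are assumptions, not from this commit:

```cpp
// Hedged sketch: filling the extra int arguments before a kernel launch.
// `int_args`, `tensor_args`, `arg_idx`, and `dim_idx` are assumed names.
for (const auto& [non_tensor_arg_idx, src] : group->int_args_map) {
  // src pairs a tensor argument index with a dimension index.
  int_args[non_tensor_arg_idx] = tensor_args[src.arg_idx]->dims[src.dim_idx];
}
```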
@@ -519,20 +521,38 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(
std::vector<Expr> func_bodies;
for (auto* op : ops) {
// 1.Select Op impl
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
CollectOutputInfo(op, &out_types, &out_shapes);
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs = details::CollectAttrs(*op);

std::vector<ir::Tensor> op_func_arg_tensors =
CollectInputTensor(group, op, group_func_arg_tensors, tensor_map);
VLOG(4) << "input size:" << op_func_arg_tensors.size();

std::string cinn_op_name = CompatibleInfo::OpName(*op);
const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name);
auto op_impl = OpStrategy::SelectImpl(strategy[cinn_op](
node_attrs, op_func_arg_tensors, out_types, out_shapes, this->target_));
std::shared_ptr<OpImpl> op_impl = nullptr;
if (FLAGS_cinn_bucket_compile) {
std::vector<Type> out_types;
std::vector<std::vector<ir::Dim>> out_shapes;
CollectOutputInfo(op, &out_types, &out_shapes, group);
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs = details::CollectAttrs(*op);
auto& strategy =
Operator::GetAttrs<StrategyFunctionSymbolic>("CINNStrategySymbolic");
op_impl = OpStrategy::SelectImpl(strategy[cinn_op](node_attrs,
op_func_arg_tensors,
out_types,
out_shapes,
this->target_));
} else {
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
CollectOutputInfo(op, &out_types, &out_shapes, group);
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs = details::CollectAttrs(*op);
op_impl = OpStrategy::SelectImpl(strategy[cinn_op](node_attrs,
op_func_arg_tensors,
out_types,
out_shapes,
this->target_));
}
// 2.Perform the lower process of Op
std::vector<ir::LoweredFunc> funcs = DoOpLower(
op_impl, op, tensor_map, tmp_tensor_info, &op_func_arg_tensors);
@@ -921,10 +941,28 @@ std::vector<ir::Tensor> OpLowererImpl::CollectInputTensor(
return tensors;
}

void OpLowererImpl::CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<int>>* out_shapes,
const GroupPtr& group) {
auto op_results = op->results();
for (auto& out_value : op_results) {
std::string output_id = ValueName(out_value);

auto type_info =
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();

out_types->push_back(CompatibleInfo::ConvertIRType(type_info.dtype()));
auto out_shape = ::common::vectorize<int>(type_info.dims());
out_shapes->push_back(std::move(out_shape));
}
}

void OpLowererImpl::CollectOutputInfo(
::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<int>>* out_shapes) {
std::vector<std::vector<ir::Dim>>* out_shapes,
const GroupPtr& group) {
auto op_results = op->results();
for (auto& out_value : op_results) {
std::string output_id = ValueName(out_value);
@@ -933,8 +971,17 @@ void OpLowererImpl::CollectOutputInfo(
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();

out_types->push_back(CompatibleInfo::ConvertIRType(type_info.dtype()));
auto out_shape = ::common::vectorize<int>(type_info.dims());
out_shapes->push_back(std::move(out_shape));
if (group->shape_analysis != nullptr) {
auto sym_vec =
group->shape_analysis->GetOrCreateSymbolicDimsForRankedValue(
out_value);
std::vector<ir::Dim> sym_shape;
for (auto& sym : sym_vec) {
sym_shape.emplace_back(
ir::Dim(output_id + "_" + sym.GetSymName(), sym));
}
out_shapes->push_back(std::move(sym_shape));
}
}
}

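Note the naming scheme in the symbolic overload: each `ir::Dim` is named after the producing value plus the symbol, so an output `var_0` with symbolic dims `[S0, S1]` yields dims named `var_0_S0` and `var_0_S1`. A hedged restatement of that step in isolation; the value and symbol names are illustrative:

```cpp
// Hedged sketch of the naming step in the symbolic CollectOutputInfo above;
// "var_0" and the symbols inside sym_vec are illustrative.
std::string output_id = "var_0";
std::vector<ir::Dim> sym_shape;
for (auto& sym : sym_vec) {  // from GetOrCreateSymbolicDimsForRankedValue
  // e.g. produces dims named "var_0_S0", "var_0_S1", ...
  sym_shape.emplace_back(ir::Dim(output_id + "_" + sym.GetSymName(), sym));
}
```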
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_impl.h
@@ -232,7 +232,13 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {

void CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<int>>* out_shapes);
std::vector<std::vector<int>>* out_shapes,
const GroupPtr& group);

void CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<ir::Dim>>* out_shapes,
const GroupPtr& group);

std::string ValueName(::pir::Value value);

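With both declarations in place, overload resolution is driven purely by the pointer type of `out_shapes` at the call site, as in this sketch:

```cpp
// Hedged sketch: the two CollectOutputInfo call forms inside OpLowererImpl.
std::vector<Type> out_types;
std::vector<std::vector<int>> static_shapes;        // static-extent path
std::vector<std::vector<ir::Dim>> symbolic_shapes;  // bucket-compile path
CollectOutputInfo(op, &out_types, &static_shapes, group);
CollectOutputInfo(op, &out_types, &symbolic_shapes, group);
```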
57 changes: 57 additions & 0 deletions paddle/cinn/hlir/op/broadcast.cc
@@ -46,6 +46,15 @@ using framework::StrategyFunction;
const Target &target) { \
return StrategyForBroadcast( \
attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \
} \
std::shared_ptr<OpStrategy> StrategyFor##pe__##Symbolic( \
const framework::NodeAttr &attrs, \
const std::vector<ir::Tensor> &inputs, \
const std::vector<Type> &out_type, \
const std::vector<std::vector<ir::Dim>> &output_shapes, \
const Target &target) { \
return StrategyForBroadcastSymbolic( \
attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \
}

std::shared_ptr<OpStrategy> StrategyForBroadcast(
@@ -95,6 +104,51 @@ std::shared_ptr<OpStrategy> StrategyForBroadcast(
1);
return strategy;
}
std::shared_ptr<OpStrategy> StrategyForBroadcastSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target,
const std::string &op_name,
ir::Tensor (*pe_func)(const ir::Tensor &A,
const ir::Tensor &B,
const std::string &output_name,
const Expr &axis)) {
framework::CINNCompute binary_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name
<< " compute is empty! Please check.";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 3U)
    << "at least 2 input tensors and 1 output tensor name for " << op_name
    << " compute";
CHECK(pack_args[2].is_string());
std::string tensor_name = pack_args[2].operator std::string();
Expr A_expr = pack_args[0];
Expr B_expr = pack_args[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();
Expr axis;
for (auto &iter : attrs.attr_store) {
if (iter.first == "axis") {
axis = Expr(absl::get<int>(iter.second));
break;
}
}
auto out = pe_func(A, B, tensor_name, axis);
auto stages = CreateStages({A, B, out});
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
binary_compute, lang::PackedFunc(), "strategy." + op_name + ".x86", 1);
return strategy;
}

std::vector<shape_t> InferShapeForBroadcast(
const std::vector<shape_t> &inputs_shape,
@@ -453,6 +507,9 @@ CINN_REGISTER_HELPER(broadcast_ops) {
.set_num_outputs(1) \
.set_attr<cinn::hlir::framework::StrategyFunction>( \
"CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>( \
"CINNStrategySymbolic", \
cinn::hlir::op::StrategyFor##op_stragegy__##Symbolic) \
.set_attr("infershape", \
MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \
.set_attr("inferdtype", \
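The hunk at the top of this file extends an existing per-op macro (its name is cut off above) so that every broadcast op also gets a `StrategyFor<Pe>Symbolic` wrapper. A hedged hand-expansion for one op; the `elementwise_add`/`pe::Add` pairing is illustrative:

```cpp
// Hedged hand-expansion of the macro body shown above, with op_name__ =
// elementwise_add and pe__ = Add (illustrative pairing).
std::shared_ptr<OpStrategy> StrategyForAddSymbolic(
    const framework::NodeAttr &attrs,
    const std::vector<ir::Tensor> &inputs,
    const std::vector<Type> &out_type,
    const std::vector<std::vector<ir::Dim>> &output_shapes,
    const Target &target) {
  return StrategyForBroadcastSymbolic(
      attrs, inputs, out_type, output_shapes, target,
      "elementwise_add", pe::Add);
}
```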
96 changes: 96 additions & 0 deletions paddle/cinn/hlir/op/elementwise.cc
@@ -49,6 +49,15 @@ using PeFunc = std::function<std::vector<ir::Tensor>(
const Target &target) { \
return StrategyForElementwise( \
attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \
} \
std::shared_ptr<OpStrategy> StrategyFor##pe__##Symbolic( \
const framework::NodeAttr &attrs, \
const std::vector<ir::Tensor> &inputs, \
const std::vector<Type> &out_type, \
const std::vector<std::vector<ir::Dim>> &output_shapes, \
const Target &target) { \
return StrategyForElementwiseSymbolic( \
attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \
}

std::shared_ptr<OpStrategy> StrategyForElementwise(
@@ -91,6 +100,44 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(

return strategy;
}
std::shared_ptr<OpStrategy> StrategyForElementwiseSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target,
const std::string &op_name,
const PeFunc &pe_func) {
framework::CINNCompute unary_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name
<< " compute is empty! Please check.";
CINNValuePack pack_args = args[0];
CHECK_EQ(pack_args.size(), 2U)
    << "1 input tensor and 1 output tensor name for " << op_name
    << " compute";
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();
Expr A_expr = pack_args[0];
CHECK(A_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
auto out = pe_func(A, tensor_name);
auto stages = CreateStages({A});
std::vector<CINNValue> res;
for (auto &t : out) {
stages->InsertLazily(t);
res.push_back(CINNValue(t));
}
res.push_back(CINNValue(stages));
*ret = CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
unary_compute, lang::PackedFunc(), "strategy." + op_name + ".x86", 1);

return strategy;
}

std::vector<shape_t> InferShapeForElementwise(
const std::vector<shape_t> &inputs_shape,
@@ -830,6 +877,50 @@ std::shared_ptr<OpStrategy> StrategyForReshape(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForReshapeSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
framework::CINNCompute reshape_compute([=](lang::Args args,
lang::RetValue *ret) {
CHECK(!args.empty())
<< "The input arguments of Reshape compute is empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 1U)
<< "at least 1 input tensors for Reshape compute\n";
Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
auto attr_store = attrs.attr_store;
CHECK(attr_store.count("shape")) << "find no attr of shape";
auto tensor_A = A.as_tensor_ref();
auto stages = CreateStages({tensor_A});
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");

CHECK_EQ(pack_args.size(), 2);
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();

ir::Tensor out = pe::Reshape(tensor_A, output_shapes[0], tensor_name);
std::vector<CINNValue> res;
stages->InsertLazily(out);
res.push_back(CINNValue(out));
CHECK(!out_type.empty())
<< "Output type of Reshape is empty! Please check.\n";
res.push_back(CINNValue(stages));

*ret = CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
reshape_compute, lang::PackedFunc(), "strategy.reshape.x86", 1);
return strategy;
}

std::vector<std::vector<int>> InferShapeForReshape(
const std::vector<std::vector<int>> &inputs_shape,
const framework::AttrMapType &attrs) {
@@ -1006,6 +1097,9 @@ CINN_REGISTER_HELPER(elementwise_ops) {
.set_num_outputs(1) \
.set_attr<cinn::hlir::framework::StrategyFunction>( \
"CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>( \
"CINNStrategySymbolic", \
cinn::hlir::op::StrategyFor##op_stragegy__##Symbolic) \
.set_attr("infershape", \
MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \
.set_attr("inferdtype", \
@@ -1203,6 +1297,8 @@ CINN_REGISTER_HELPER(elementwise_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForReshape)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForReshapeSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForReshape))
.set_attr("inferdtype",
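With both `CINNStrategy` and `CINNStrategySymbolic` registered, the `FLAGS_cinn_bucket_compile` branch in `op_lowering_impl.cc` above decides which attribute to query. A hedged sketch of exercising the symbolic reshape strategy directly; the inputs are illustrative and the function is file-local to elementwise.cc:

```cpp
// Hedged sketch: driving StrategyForReshapeSymbolic with illustrative inputs.
auto strategy = StrategyForReshapeSymbolic(
    node_attrs,                 // must carry a "shape" attribute
    {tensor_A},                 // single input ir::Tensor
    {cinn::common::Float(32)},  // out_type
    {symbolic_out_shape},       // one std::vector<ir::Dim> per output
    cinn::common::DefaultHostTarget());
auto op_impl = framework::OpStrategy::SelectImpl(strategy);
```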