
Commit

Merge branch 'PaddlePaddle:develop' into paddle_test_15
zade23 authored Jan 23, 2024
2 parents 72bd9f1 + f522f04 commit 557bae9
Showing 54 changed files with 609 additions and 209 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.cc
@@ -222,7 +222,7 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {

llvm::Value* CodeGenCUDA_Host::LowerParseArgsValueCall(
const ir::Call* call_ir) {
auto ret_type = CinnTypeToLLVMType(Int(32), m_);
auto ret_type = CinnTypeToLLVMType(Int(64), m_);
std::vector<llvm::Type*> args_type;
CHECK_EQ(call_ir->read_args.size(), 2);
CHECK(call_ir->read_args[0].is_var() &&
7 changes: 4 additions & 3 deletions paddle/cinn/backends/codegen_cuda_util.cc
@@ -114,15 +114,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
for (int i = 0; i < args.size(); ++i) {
if (args[i].is_var()) {
ir::Expr call_get_value_in_kernel_args =
ir::Call::Make(Int(32),
ir::Call::Make(Int(64),
runtime::intrinsic::get_value_in_cuda_kernel_args,
{kernel_args_, ir::Expr(i)},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);
ir::Expr stmt = ir::Let::Make(ir::Expr(args[i].var_arg()),
call_get_value_in_kernel_args);
ir::Expr let_symbol = ir::Expr(args[i].var_arg());
let_symbol->set_type(type_of<int64_t>());
ir::Expr stmt = ir::Let::Make(let_symbol, call_get_value_in_kernel_args);
arg_defs_.push_back(stmt);
}
}
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
@@ -152,7 +152,7 @@ struct CollectBucketStrategyHostFunctionVisitor
kernel_args_(KERNEL_ARGS, type_of<void*>()),
kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
kernel_stream_(KERNEL_STREAM, type_of<void*>()),
tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}
tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int64_t**>()) {}

std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
ir::IRMutator<>::Visit(expr, expr);
6 changes: 5 additions & 1 deletion paddle/cinn/common/ir_util.cc
@@ -143,8 +143,12 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
VLOG(3) << "indices is : " << utils::Join(indices, ",");
CHECK_LE(shape.size(), indices.size());
Expr res;
ir::TryElevateInt32ToInt64(shape);
for (int i = 0; i < shape.size(); i++) {
CHECK_EQ(shape[i].type(), Int(32));
CHECK(shape[i].type() == Int(64) || shape[i].type() == Int(32))
<< "The shape data type currently supports only int32 or int64, but "
"the current data type of shape["
<< i << "] is " << shape[i].type();
Expr indice_prod = indices[i];
optim::SimplifyCast(&indice_prod);
for (int j = i + 1; j < shape.size(); j++) {
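For context, the computation IndiceToAbsOffset performs can be summarized with a small standalone sketch; it uses plain 64-bit integers in place of CINN's Expr type, and the function and parameter names below are illustrative only, not part of the CINN API.

#include <cstdint>
#include <vector>

// Illustrative scalar version of the indices-to-absolute-offset mapping:
// offset = sum_i indices[i] * prod_{j > i} shape[j], accumulated Horner-style.
// Accumulating in int64_t mirrors the int32 -> int64 widening in this commit.
int64_t RowMajorOffset(const std::vector<int64_t>& shape,
                       const std::vector<int64_t>& indices) {
  int64_t offset = 0;
  for (size_t i = 0; i < shape.size(); ++i) {
    offset = offset * shape[i] + indices[i];
  }
  return offset;
}

For example, with shape {2, 3, 4} and indices {1, 2, 3}, the offset is 1 * 12 + 2 * 4 + 3 = 23.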
6 changes: 6 additions & 0 deletions paddle/cinn/common/type.h
@@ -263,6 +263,12 @@ inline Type type_of<int32_t**>() {
return x;
}
template <>
inline Type type_of<int64_t**>() {
Type x = Int(64);
x.set_cpp_handle2();
return x;
}
template <>
inline Type type_of<void*>() {
Type x = type_of<void>();
x.set_cpp_handle();
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_impl.cc
@@ -599,7 +599,7 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
auto master_loops = ir_sch.GetLoops(GetNodeData(master)->id());
std::vector<int> splits;
for (auto loop : master_loops) {
splits.push_back(loop.As<ir::For>()->extent.as_int32());
splits.push_back(loop.As<ir::For>()->extent.as_int64());
}
loops = ir_sch.GetLoops(GetNodeData(node)->id());
ir_sch.Split(loops[0], splits);
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/framework/op_lowering_util.cc
@@ -1537,8 +1537,8 @@ void MergeReduceLoop(
auto dst_loops = ir_sch.GetLoops(tensor_->name);
auto src_loops = ir_sch.GetLoops(tensor__->name);
int index = -1;
while (src_loops[index + 1].As<ir::For>()->extent.as_int32() ==
dst_loops[index + 1].As<ir::For>()->extent.as_int32()) {
while (src_loops[index + 1].As<ir::For>()->extent.as_int64() ==
dst_loops[index + 1].As<ir::For>()->extent.as_int64()) {
++index;
if (src_loops.size() == index + 1 || dst_loops.size() == index + 1) {
break;
@@ -1661,8 +1661,8 @@ void LoopComputeAt(
int index = std::min(node_loops.size(), master_loops.size()) - 1;
do {
// if loop range is not equal.
if (node_loops[index].As<ir::For>()->extent.as_int32() !=
master_loops[index].As<ir::For>()->extent.as_int32()) {
if (node_loops[index].As<ir::For>()->extent.as_int64() !=
master_loops[index].As<ir::For>()->extent.as_int64()) {
continue;
}
MergeLoops(
46 changes: 31 additions & 15 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -494,7 +494,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
int_args_set.insert(symbol_name);
group_func_args->emplace_back(
ir::_Var_::Make(symbol_name, cinn::common::Int(32)));
ir::_Var_::Make(symbol_name, cinn::common::Int(64)));
group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx,
tensor_arg_dim_idx};
VLOG(4) << "device kernel func's " << non_tensor_arg_idx << " is from "
@@ -550,15 +550,20 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(
std::vector<Type> out_types;
std::vector<std::vector<ir::Dim>> out_shapes;
CollectOutputInfo(op, &out_types, &out_shapes, group);
CHECK_EQ(out_types.size(), out_shapes.size());
VLOG(4) << "out_types.size(): " << out_types.size();
NodeAttr node_attrs = details::CollectAttrs(*op);
auto& strategy =
auto& strategy_map =
Operator::GetAttrs<StrategyFunctionSymbolic>("CINNStrategySymbolic");
op_impl = OpStrategy::SelectImpl(strategy[cinn_op](node_attrs,
op_func_arg_tensors,
out_types,
out_shapes,
this->target_));
StrategyFunctionSymbolic strategy = strategy_map[cinn_op];
CHECK(static_cast<bool>(strategy))
<< " cinn_op_name: " << cinn_op_name
<< "has no CINNStrategySymbolic registered.";
op_impl = OpStrategy::SelectImpl(strategy(node_attrs,
op_func_arg_tensors,
out_types,
out_shapes,
this->target_));
} else {
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
@@ -797,14 +802,25 @@ void OpLowererImpl::CollectOutputInfo(
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();

out_types->push_back(CompatibleInfo::ConvertIRType(type_info.dtype()));
if (!group->value_to_shape_or_data_exprs.empty()) {
auto sym_vec = group->GetShapeOrDataExprs(out_value).shape();
std::vector<ir::Dim> sym_shape;
for (auto& sym : sym_vec) {
sym_shape.emplace_back(output_id, sym);

auto ForEachDimExpr = [&](const auto& DoEach) {
if (!group->value_to_shape_or_data_exprs.empty()) {
auto sym_vec = group->GetShapeOrDataExprs(out_value).shape();
std::vector<ir::Dim> sym_shape;
for (const auto& sym : sym_vec) {
DoEach(sym);
}
} else {
auto out_shape = ::common::vectorize<int64_t>(type_info.dims());
for (int64_t dim : out_shape) {
DoEach(symbol::DimExpr{dim});
}
}
out_shapes->push_back(std::move(sym_shape));
}
};
std::vector<ir::Dim> sym_shape;
ForEachDimExpr(
[&](const auto& sym) { sym_shape.emplace_back(output_id, sym); });
out_shapes->emplace_back(std::move(sym_shape));
}
}

@@ -860,7 +876,7 @@ ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc(
int tensor_dim_size = tensor_dim.size();
auto tensor_shape = group_func_arg_tensors[tensor_arg_idx]->shape;

ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int32_t**>());
ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int64_t**>());
for (int i = 0; i < tensor_shape.size(); i++) {
ir::Expr call_set_infer_shape_value =
ir::Call::Make(type_of<void>(),
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/op/broadcast.cc
@@ -301,12 +301,14 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastToSymbolic(
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
CHECK_EQ(output_shapes.size(), 1);
std::vector<ir::Expr> out_shape(output_shapes[0].size());
std::transform(output_shapes[0].begin(),
output_shapes[0].end(),
out_shape.begin(),
[](const ir::Dim &dim) { return dim->dim_expr; });
std::vector<int> broadcast_axes;
CHECK_GT(attrs.attr_store.count("broadcast_axes"), 0);
broadcast_axes =
absl::get<std::vector<int>>(attrs.attr_store.at("broadcast_axes"));
VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", ");
126 changes: 126 additions & 0 deletions paddle/cinn/hlir/op/elementwise.cc
@@ -251,6 +251,75 @@ std::shared_ptr<OpStrategy> StrategyForScale(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForScaleSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
float scale = 1.f;
float bias = 0.f;
bool bias_after_scale = true;
for (auto &iter : attrs.attr_store) {
if (iter.first == "scale") {
scale = absl::get<float>(iter.second);
} else if (iter.first == "bias") {
bias = absl::get<float>(iter.second);
} else if (iter.first == "bias_after_scale") {
bias_after_scale = absl::get<bool>(iter.second);
}
}
framework::CINNCompute scale_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty())
<< "The input arguments of scale compute is empty! Please check.";
CINNValuePack pack_args = args[0];
CHECK(!pack_args.empty())
<< "The input tensors of scale compute is empty! Please check.";
Expr A_expr = pack_args[0];
CHECK(A_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor out;
CHECK_EQ(pack_args.size(), 2);
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();

// Paddle upscales float16 or bfloat16 compute to float32;
// we keep CINN consistent with this behavior of Paddle
bool should_upscale_fp32 = A->type() == cinn::common::F16() ||
A->type() == cinn::common::BF16();

out = Compute(
A->shape,
[=](const std::vector<Expr> &indice) {
Expr cast_scale = should_upscale_fp32
? Expr(scale)
: ir::Cast::Make(A->type(), Expr(scale));
Expr cast_bias = should_upscale_fp32
? Expr(bias)
: ir::Cast::Make(A->type(), Expr(bias));
Expr cast_A_indice =
should_upscale_fp32
? ir::Cast::Make(cinn::common::F32(), A(indice))
: A(indice);
Expr add_result = bias_after_scale
? cast_scale * cast_A_indice + cast_bias
: cast_scale * (cast_A_indice + cast_bias);
return should_upscale_fp32 ? ir::Cast::Make(A->type(), add_result)
: add_result;
},
tensor_name);

auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(scale_compute, lang::PackedFunc(), "strategy.scale.x86", 1);

return strategy;
}

Expr GetScalarExpr(const framework::NodeAttr::attr_t &attr) {
Expr scalar;
struct Visitor {
@@ -450,6 +519,58 @@ std::shared_ptr<OpStrategy> StrategyForFillConstant(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForFillConstantSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
framework::CINNCompute fill_constant_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of fill_constant compute "
"is empty! Please check.";
bool force_cpu = false;
CHECK(attrs.attr_store.count("shape"));
auto shape = absl::get<std::vector<int>>(attrs.attr_store.at("shape"));
CHECK(attrs.attr_store.count("value"));
auto value = GetScalarExpr(attrs.attr_store.at("value"));
CHECK(attrs.attr_store.count("force_cpu"));
force_cpu = absl::get<bool>(attrs.attr_store.at("force_cpu"));

if (force_cpu && target != cinn::common::DefaultHostTarget()) {
LOG(WARNING) << "The attribute \"force_cpu\" of \"fill_constant\" "
"not supported in CINN! The \"fill_constant\"'s "
"output tensor will placed on "
<< target;
}

CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 1U);
CHECK(arg_pack[0].is_string());
std::string tensor_name = arg_pack[0].operator std::string();
CHECK(!shape.empty()) << "shape attr is empty!";
auto shape_exprs = ToCinnExprs(shape);
auto out = lang::Compute(
shape_exprs,
[=](const std::vector<Expr> &indice) {
return ir::Cast::Make(out_type[0], value);
},
tensor_name);
CHECK(out.defined())
<< "can't create fill_constant with the given type " << out_type[0];
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(fill_constant_compute,
lang::PackedFunc(),
"strategy.fill_constant.x86",
1);

return strategy;
}

std::vector<shape_t> InferShapeForFillConstant(
const std::vector<shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
@@ -1178,6 +1299,8 @@ CINN_REGISTER_HELPER(elementwise_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForScale)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForScaleSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForElementwise))
.set_attr("inferdtype",
Expand Down Expand Up @@ -1226,6 +1349,9 @@ CINN_REGISTER_HELPER(elementwise_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForFillConstant)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic",
cinn::hlir::op::StrategyForFillConstantSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForFillConstant))
.set_attr("inferdtype",
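As a point of reference, the per-element arithmetic emitted by the scale compute above reduces to a one-line scalar model: bias_after_scale selects between scale * x + bias and scale * (x + bias). A minimal sketch follows, with the fp16/bf16 upcast branch modelled as plain float math and names that are illustrative rather than CINN API.

// Hypothetical scalar model of the kernel lowered by StrategyForScaleSymbolic.
// The real compute casts half-precision inputs to float and back; here the
// input is already float, so only the bias_after_scale choice is shown.
inline float ScaleOne(float x, float scale, float bias, bool bias_after_scale) {
  return bias_after_scale ? scale * x + bias     // y = scale * x + bias
                          : scale * (x + bias);  // y = scale * (x + bias)
}

For instance, ScaleOne(2.0f, 3.0f, 1.0f, true) evaluates to 7.0f, while the same inputs with bias_after_scale = false give 9.0f.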
10 changes: 7 additions & 3 deletions paddle/cinn/ir/buffer.cc
@@ -103,11 +103,15 @@ Var _Buffer_::buffer_addr() const {
return _Var_::Make(name, thetype);
}

int _Buffer_::numel() const {
int res = 1;
int64_t _Buffer_::numel() const {
int64_t res = 1;
for (auto &i : shape) {
CHECK(i.is_constant());
res *= i.as_int32();
if (i->type() == Int(64)) {
res *= i.as_int64();
} else {
res *= i.as_int32();
}
}
return res;
}
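A small worked example of why widening numel() to int64_t matters, assuming buffers whose element count can exceed the int32 range (the values below are illustrative, not taken from the commit).

#include <cstdint>
#include <vector>

// With 32-bit accumulation a shape such as {65536, 65536} overflows:
// 65536 * 65536 = 2^32, which does not fit in int32_t. Accumulating in
// int64_t, as _Buffer_::numel() now does, keeps the product exact.
int64_t Numel(const std::vector<int64_t>& shape) {
  int64_t res = 1;
  for (int64_t d : shape) res *= d;
  return res;  // Numel({65536, 65536}) == 4294967296
}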
2 changes: 1 addition & 1 deletion paddle/cinn/ir/buffer.h
@@ -141,7 +141,7 @@ class _Buffer_ : public ExprNode<_Buffer_> {

void Verify() const override;

int numel() const;
int64_t numel() const;

static const IrNodeTy _node_type_ = IrNodeTy::_Buffer_;

8 changes: 2 additions & 6 deletions paddle/cinn/ir/dim.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/ir/ir.h"

namespace cinn {
@@ -31,12 +32,7 @@ Dim _Dim_::Make(const std::string& name, const symbol::DimExpr& sym_dim) {
auto* n = make_shared<_Dim_>();
n->name = name;
n->sym_dim = sym_dim;
if (sym_dim.isa<std::string>()) {
n->dim_expr =
Expr(Var(sym_dim.dyn_cast<std::string>(), cinn::common::Int(32)));
} else {
n->dim_expr = Expr(static_cast<int32_t>(sym_dim.dyn_cast<int64_t>()));
}
n->dim_expr = common::DimExprConverter().ConvertToIrExpr(sym_dim);

return Dim(n);
}
(Diffs for the remaining changed files are not shown here.)
