Adapt to dim expr #60843

Merged · 10 commits · Jan 22, 2024
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_host.cc
@@ -222,7 +222,7 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {

llvm::Value* CodeGenCUDA_Host::LowerParseArgsValueCall(
const ir::Call* call_ir) {
-  auto ret_type = CinnTypeToLLVMType(Int(32), m_);
+  auto ret_type = CinnTypeToLLVMType(Int(64), m_);
std::vector<llvm::Type*> args_type;
CHECK_EQ(call_ir->read_args.size(), 2);
CHECK(call_ir->read_args[0].is_var() &&
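Note: the host wrapper now parses each kernel argument as a 64-bit value, so the call's LLVM return type changes from i32 to i64. A minimal sketch of that type choice, assuming a plain llvm::LLVMContext rather than CINN's CinnTypeToLLVMType helper:

```cpp
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Type.h>

// Sketch: the return type used when lowering the parse-args-value call.
// With Int(64) on the CINN side, the matching LLVM type is a 64-bit integer.
llvm::Type* ParseArgsValueRetType(llvm::LLVMContext& ctx) {
  return llvm::Type::getInt64Ty(ctx);  // previously getInt32Ty(ctx)
}
```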
7 changes: 4 additions & 3 deletions paddle/cinn/backends/codegen_cuda_util.cc
@@ -114,15 +114,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
for (int i = 0; i < args.size(); ++i) {
if (args[i].is_var()) {
ir::Expr call_get_value_in_kernel_args =
-          ir::Call::Make(Int(32),
+          ir::Call::Make(Int(64),
runtime::intrinsic::get_value_in_cuda_kernel_args,
{kernel_args_, ir::Expr(i)},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);
-      ir::Expr stmt = ir::Let::Make(ir::Expr(args[i].var_arg()),
-                                    call_get_value_in_kernel_args);
+      ir::Expr let_symbol = ir::Expr(args[i].var_arg());
+      let_symbol->set_type(type_of<int64_t>());
+      ir::Expr stmt = ir::Let::Make(let_symbol, call_get_value_in_kernel_args);
arg_defs_.push_back(stmt);
}
}
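Note: with the Let symbol explicitly typed as int64_t, the generated host stub now reads every dynamic-dim argument at 64-bit width. A standalone model of that behavior; `KernelArgs` and `GetValueInKernelArgs` are hypothetical stand-ins for the packed argument buffer and the get_value_in_cuda_kernel_args intrinsic:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the packed kernel-argument buffer.
using KernelArgs = std::vector<int64_t>;

// Models the intrinsic: fetch the i-th scalar argument; the value is
// now produced as int64_t rather than int32_t.
int64_t GetValueInKernelArgs(const KernelArgs& args, int i) {
  return args[i];
}

int main() {
  KernelArgs args = {128, 4096};
  // Mirrors the emitted "let S0 = get_value_in_cuda_kernel_args(args, 0)",
  // where S0 is declared with type int64_t.
  int64_t S0 = GetValueInKernelArgs(args, 0);
  (void)S0;
}
```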
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
@@ -152,7 +152,7 @@ struct CollectBucketStrategyHostFunctionVisitor
kernel_args_(KERNEL_ARGS, type_of<void*>()),
kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
kernel_stream_(KERNEL_STREAM, type_of<void*>()),
-        tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}
+        tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int64_t**>()) {}

std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
ir::IRMutator<>::Visit(expr, expr);
6 changes: 5 additions & 1 deletion paddle/cinn/common/ir_util.cc
@@ -143,8 +143,12 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
VLOG(3) << "indices is : " << utils::Join(indices, ",");
CHECK_LE(shape.size(), indices.size());
Expr res;
+  ir::TryElevateInt32ToInt64(shape);
for (int i = 0; i < shape.size(); i++) {
-    CHECK_EQ(shape[i].type(), Int(32));
+    CHECK(shape[i].type() == Int(64) || shape[i].type() == Int(32))
+        << "The shape data type currently supports only int32 or int64, but "
+           "the current data type of shape["
+        << i << "] is " << shape[i].type();
Expr indice_prod = indices[i];
optim::SimplifyCast(&indice_prod);
for (int j = i + 1; j < shape.size(); j++) {
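Note: IndiceToAbsOffset folds indices into a row-major linear offset, which is why a single int64 extent must promote the whole computation. A minimal sketch of the promoted arithmetic, with plain integers standing in for ir::Exprs:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major linear offset with 64-bit accumulation, mirroring the
// promotion done by ir::TryElevateInt32ToInt64 before the loop.
int64_t IndiceToAbsOffset(const std::vector<int64_t>& shape,
                          const std::vector<int64_t>& indices) {
  assert(shape.size() <= indices.size());
  int64_t res = 0;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    res = res * shape[i] + indices[i];  // ((i0*d1 + i1)*d2 + i2) ...
  }
  return res;
}

int main() {
  // For shape {2, 3, 4} and indices {1, 2, 3}: 1*12 + 2*4 + 3 = 23.
  assert(IndiceToAbsOffset({2, 3, 4}, {1, 2, 3}) == 23);
}
```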
6 changes: 6 additions & 0 deletions paddle/cinn/common/type.h
@@ -263,6 +263,12 @@ inline Type type_of<int32_t**>() {
return x;
}
+template <>
+inline Type type_of<int64_t**>() {
+  Type x = Int(64);
+  x.set_cpp_handle2();
+  return x;
+}
template <>
inline Type type_of<void*>() {
Type x = type_of<void>();
x.set_cpp_handle();
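Note: the new specialization follows the file's existing idiom: take the scalar type and mark it as a pointer-to-pointer handle. A self-contained sketch of that specialization pattern over a simplified `Type` stand-in (not CINN's actual Type):

```cpp
#include <cstdint>

// Simplified stand-in for cinn::common::Type: bit width plus pointer depth.
struct Type {
  int bits = 0;
  int ptr_depth = 0;  // 0: value, 1: T* (cpp_handle), 2: T** (cpp_handle2)
};

template <typename T>
Type type_of();

// Mirrors the new specialization: int64_t** is Int(64) plus handle2.
template <>
Type type_of<int64_t**>() {
  Type x;
  x.bits = 64;
  x.ptr_depth = 2;
  return x;
}

int main() {
  Type t = type_of<int64_t**>();
  return (t.bits == 64 && t.ptr_depth == 2) ? 0 : 1;
}
```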
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_impl.cc
@@ -599,7 +599,7 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
auto master_loops = ir_sch.GetLoops(GetNodeData(master)->id());
std::vector<int> splits;
for (auto loop : master_loops) {
-    splits.push_back(loop.As<ir::For>()->extent.as_int32());
+    splits.push_back(loop.As<ir::For>()->extent.as_int64());
}
loops = ir_sch.GetLoops(GetNodeData(node)->id());
ir_sch.Split(loops[0], splits);
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/framework/op_lowering_util.cc
@@ -1537,8 +1537,8 @@ void MergeReduceLoop(
auto dst_loops = ir_sch.GetLoops(tensor_->name);
auto src_loops = ir_sch.GetLoops(tensor__->name);
int index = -1;
-  while (src_loops[index + 1].As<ir::For>()->extent.as_int32() ==
-         dst_loops[index + 1].As<ir::For>()->extent.as_int32()) {
+  while (src_loops[index + 1].As<ir::For>()->extent.as_int64() ==
+         dst_loops[index + 1].As<ir::For>()->extent.as_int64()) {
++index;
if (src_loops.size() == index + 1 || dst_loops.size() == index + 1) {
break;
@@ -1661,8 +1661,8 @@ void LoopComputeAt(
int index = std::min(node_loops.size(), master_loops.size()) - 1;
do {
// if loop range is not equal.
-    if (node_loops[index].As<ir::For>()->extent.as_int32() !=
-        master_loops[index].As<ir::For>()->extent.as_int32()) {
+    if (node_loops[index].As<ir::For>()->extent.as_int64() !=
+        master_loops[index].As<ir::For>()->extent.as_int64()) {
continue;
}
MergeLoops(
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -494,7 +494,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}
int_args_set.insert(symbol_name);
group_func_args->emplace_back(
-        ir::_Var_::Make(symbol_name, cinn::common::Int(32)));
+        ir::_Var_::Make(symbol_name, cinn::common::Int(64)));
group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx,
tensor_arg_dim_idx};
VLOG(4) << "device kernel func's " << non_tensor_arg_idx << " is from "
@@ -860,7 +860,7 @@ ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc(
int tensor_dim_size = tensor_dim.size();
auto tensor_shape = group_func_arg_tensors[tensor_arg_idx]->shape;

-  ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int32_t**>());
+  ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int64_t**>());
for (int i = 0; i < tensor_shape.size(); i++) {
ir::Expr call_set_infer_shape_value =
ir::Call::Make(type_of<void>(),
10 changes: 7 additions & 3 deletions paddle/cinn/ir/buffer.cc
@@ -103,11 +103,15 @@ Var _Buffer_::buffer_addr() const {
return _Var_::Make(name, thetype);
}

-int _Buffer_::numel() const {
-  int res = 1;
+int64_t _Buffer_::numel() const {
+  int64_t res = 1;
for (auto &i : shape) {
CHECK(i.is_constant());
-    res *= i.as_int32();
+    if (i->type() == Int(64)) {
+      res *= i.as_int64();
+    } else {
+      res *= i.as_int32();
+    }
}
return res;
}
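Note: returning int64_t from numel matters because the product of extents can exceed INT32_MAX even when each extent fits in 32 bits. A minimal sketch of the overflow case, with plain integers standing in for constant ir::Exprs:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// 64-bit accumulator: 65536 * 65536 overflows int32 but stays exact here.
int64_t Numel(const std::vector<int64_t>& shape) {
  int64_t res = 1;
  for (int64_t extent : shape) res *= extent;
  return res;
}

int main() {
  assert(Numel({65536, 65536}) == 4294967296LL);
}
```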
2 changes: 1 addition & 1 deletion paddle/cinn/ir/buffer.h
@@ -141,7 +141,7 @@ class _Buffer_ : public ExprNode<_Buffer_> {

void Verify() const override;

-  int numel() const;
+  int64_t numel() const;

static const IrNodeTy _node_type_ = IrNodeTy::_Buffer_;

8 changes: 2 additions & 6 deletions paddle/cinn/ir/dim.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/ir/ir.h"

namespace cinn {
@@ -31,12 +32,7 @@ Dim _Dim_::Make(const std::string& name, const symbol::DimExpr& sym_dim) {
auto* n = make_shared<_Dim_>();
n->name = name;
n->sym_dim = sym_dim;
-  if (sym_dim.isa<std::string>()) {
-    n->dim_expr =
-        Expr(Var(sym_dim.dyn_cast<std::string>(), cinn::common::Int(32)));
-  } else {
-    n->dim_expr = Expr(static_cast<int32_t>(sym_dim.dyn_cast<int64_t>()));
-  }
+  n->dim_expr = common::DimExprConverter().ConvertToIrExpr(sym_dim);

return Dim(n);
}
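Note: the hand-written branch only handled a symbolic name or an int64 constant; common::DimExprConverter also covers composite DimExprs (sums and products of symbols). A standalone sketch of the two base cases over a variant; `DimExpr` and `Expr` here are hypothetical stand-ins, not the real CINN types:

```cpp
#include <cstdint>
#include <string>
#include <variant>

// Hypothetical stand-ins for symbol::DimExpr and ir::Expr.
using DimExpr = std::variant<std::string, int64_t>;
struct Expr { std::string repr; };

// The two base cases the old code handled by hand; a real converter
// also recurses into composite expressions (Add, Mul, ...).
Expr ConvertToIrExpr(const DimExpr& d) {
  if (const auto* sym = std::get_if<std::string>(&d)) {
    return Expr{*sym};  // symbolic dim becomes a typed Var
  }
  return Expr{std::to_string(std::get<int64_t>(d))};  // constant dim
}
```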
20 changes: 15 additions & 5 deletions paddle/cinn/ir/ir.cc
@@ -58,6 +58,7 @@ Add::Add(Expr a, Expr b) : BinaryOpNode<Add>(a.type(), a, b) {}
void BinaryNodeVerify(const Expr &a, const Expr &b, absl::string_view ir_name) {
CHECK(a.defined());
CHECK(b.defined());
+  TryElevateInt32ToInt64({a, b});
CHECK_EQ(a.type(), b.type())
<< "The operands' types of the node [" << ir_name << "] don't match";
}
@@ -72,9 +73,7 @@ Expr Sub::Make(Expr a, Expr b) {
void Sub::Verify() const { BinaryNodeVerify(a(), b(), "Sub"); }

Expr Mul::Make(Expr a, Expr b) {
-  CHECK(a.defined());
-  CHECK(b.defined());
-  CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b;
+  BinaryNodeVerify(a, b, "Mul");
auto node = make_shared<Mul>(a, b);
return Expr(node);
}
@@ -203,6 +202,7 @@ void Let::Verify() const {
CHECK(symbol.defined());
// The default value(contained in body) is not required.
if (body.defined()) {
+    TryElevateInt32ToInt64({symbol, body});
CHECK_EQ(symbol.type(), body.type());
}
}
@@ -583,7 +583,11 @@ Var &Var::operator=(const _Var_ *x) {
Expr Load::Make(Expr tensor, const std::vector<Expr> &indices) {
CHECK(tensor->type().valid());
CHECK(!indices.empty());
-  for (auto &idx : indices) CHECK_EQ(idx.type().ElementOf(), Int(32));
+  TryElevateInt32ToInt64(indices);
+  for (auto &idx : indices) {
+    CHECK(idx.type().ElementOf() == Int(64) ||
+          idx.type().ElementOf() == Int(32));
+  }
auto node = make_shared<Load>();
node->tensor = tensor;
node->indices = indices;
@@ -695,8 +699,13 @@ Expr Sum::Make(const std::vector<Expr> &vs) {
if (vs.size() == 1) return vs.front();

auto *n = make_shared<Sum>();
+  TryElevateInt32ToInt64(vs);
auto type = vs.front().type();
for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v;
for (auto &v : vs) {
CHECK_EQ(v.type(), type) << "The operands' types of the node ["
<< n->node_type() << "] don't match: "
<< "(" << v << " vs " << vs.front() << ")";
}

n->operands() = vs;

Expand All @@ -709,6 +718,7 @@ Expr Product::Make(const std::vector<Expr> &vs) {
CHECK_GE(vs.size(), 1);

auto *n = make_shared<Product>();
+  TryElevateInt32ToInt64(vs);
auto type = vs.front().type();
for (auto &v : vs) CHECK_EQ(v.type(), type);

38 changes: 37 additions & 1 deletion paddle/cinn/ir/ir_base.cc
@@ -119,7 +119,7 @@ int32_t Expr::as_int32() const {
return As<IntImm>()->value;
}
int64_t Expr::as_int64() const {
-  CHECK(type().is_int(64));
+  CHECK(type().is_int(64) || type().is_int(32));
return As<IntImm>()->value;
}

@@ -235,5 +235,41 @@ const Expr &IrNode::operand(int i) {
return operands[i];
}

+void IrNode::set_type(Type type) { type_ = type; }
+
+void IrNode::convert_int32_to_int64() {
+  CHECK(type_ == Int(64) || type_ == Int(32) || type_.is_unk())
+      << "Currently only int32_t-to-int64_t conversion is supported, but "
+         "the type is " << type_;
+  type_ = Int(64);
+  for (Expr &operand : operands) {
+    operand->convert_int32_to_int64();
+  }
+}
+
+void TryElevateInt32ToInt64(const std::vector<Expr> &expr_vec) {
+  Type type = expr_vec.front()->type();
+  for (const Expr &expr : expr_vec) {
+    if (expr->type() == Int(64)) {
+      type = Int(64);
+      break;
+    }
+  }
+
+  // No need to elevate to Int(64).
+  if (type != Int(64)) {
+    return;
+  }
+  for (const Expr &expr : expr_vec) {
+    CHECK(expr->type() == Int(64) || expr->type() == Int(32) ||
+          expr->type().is_unk())
+        << "Currently only int32_t-to-int64_t conversion is supported, but "
+           "the type is " << expr->type();
+    if (expr->type() == Int(32)) {
+      expr->convert_int32_to_int64();
+    }
+  }
+}

} // namespace ir
} // namespace cinn
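Note: TryElevateInt32ToInt64 implements a widen-if-any rule: if any expression in the list is Int(64), every Int(32) expression is converted in place, recursing through its operands. A self-contained sketch of the same rule over a toy node type:

```cpp
#include <cassert>
#include <vector>

enum class Width { I32, I64 };

struct Node {
  Width w = Width::I32;
  std::vector<Node*> operands;
  void ConvertToI64() {
    w = Width::I64;
    for (Node* op : operands) op->ConvertToI64();  // recurse like IrNode
  }
};

// If any node is already 64-bit, widen all 32-bit nodes; otherwise no-op.
void TryElevateInt32ToInt64(const std::vector<Node*>& nodes) {
  bool any_i64 = false;
  for (const Node* n : nodes) any_i64 |= (n->w == Width::I64);
  if (!any_i64) return;
  for (Node* n : nodes) {
    if (n->w == Width::I32) n->ConvertToI64();
  }
}

int main() {
  Node a{Width::I32, {}}, b{Width::I64, {}};
  TryElevateInt32ToInt64({&a, &b});
  assert(a.w == Width::I64 && b.w == Width::I64);
}
```

This is why BinaryNodeVerify, Let::Verify, Load::Make, Sum::Make, and Product::Make above can keep their strict same-type checks: promotion runs first, so mixed int32/int64 operands no longer trip the CHECKs.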
6 changes: 5 additions & 1 deletion paddle/cinn/ir/ir_base.h
@@ -162,7 +162,9 @@ class IrNode : public cinn::common::Object {

virtual IrNodeTy node_type() const { return IrNodeTy::kUnk; }
virtual Type type() const { return type_; }
-  void set_type(Type type) { type_ = type; }
+  void set_type(Type type);
+  //! Elevate int32 to int64 if needed.
+  void convert_int32_to_int64();

//! Get i-th operand
const Expr& operand(int i);
@@ -502,6 +504,8 @@ Expr ExprNode<T>::Copy() const {
return Expr();
}

+void TryElevateInt32ToInt64(const std::vector<Expr>& expr_vec);

} // namespace ir
} // namespace cinn

6 changes: 5 additions & 1 deletion paddle/cinn/ir/lowered_func.cc
@@ -333,7 +333,8 @@ void _LoweredFunc_::PrepareArgumentExprs() {
// cast arg to cinn_pod_value_t*

// something like `_args[0]`
-    Expr load_expr = Load::Make(pod_value_ptr, {cinn::common::make_const(i)});
+    Expr load_expr = Load::Make(
+        pod_value_ptr, {cinn::common::make_const(static_cast<int32_t>(i))});
CHECK_EQ(load_expr.type(), type_of<cinn_pod_value_t>());
load_expr = ir::intrinsics::GetAddr::Make(load_expr);

@@ -404,6 +405,9 @@ void _LoweredFunc_::PrepareArgumentExprs() {
} else if (arg.type() == type_of<int32_t**>()) {
pod_cast_expr =
ir::intrinsics::PodValueToX::Make(load_expr, type_of<int32_t**>());
+    } else if (arg.type() == type_of<int64_t**>()) {
+      pod_cast_expr =
+          ir::intrinsics::PodValueToX::Make(load_expr, type_of<int64_t**>());
} else if (arg.type() == type_of<void**>()) {
pod_cast_expr =
ir::intrinsics::PodValueToX::Make(load_expr, type_of<void**>());
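Note: PrepareArgumentExprs selects a PodValueToX cast from the argument's static type, and the new branch simply extends that if/else chain for int64_t**. A compact sketch of the dispatch shape; the enum and the returned names are hypothetical, not CINN identifiers:

```cpp
#include <string>

// Hypothetical tags for the argument types in the dispatch chain.
enum class ArgType { Int32PtrPtr, Int64PtrPtr, VoidPtrPtr };

// Each supported type maps to its own pod-value cast; the int64_t**
// case models the branch added by this PR.
std::string PodCastFor(ArgType t) {
  switch (t) {
    case ArgType::Int32PtrPtr: return "PodValueToX<int32_t**>";
    case ArgType::Int64PtrPtr: return "PodValueToX<int64_t**>";
    case ArgType::VoidPtrPtr:  return "PodValueToX<void**>";
  }
  return "unsupported";
}
```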
4 changes: 2 additions & 2 deletions paddle/cinn/ir/schedule/impl/for_type.cc
@@ -132,7 +132,7 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
const std::array<int, 3> kMaxBlockDims = cur_dev_info->GetMaxBlockDims();
const std::array<int, 3> kMaxGridDims = cur_dev_info->GetMaxGridDims();
auto check_offset = [&](const char& c) -> bool {
-    auto extent = loop.As<ir::For>()->extent.as_int32();
+    auto extent = loop.As<ir::For>()->extent.as_int64();
return extent <= (c == 'b' ? kMaxGridDims[offset] : kMaxBlockDims[offset]);
};
if (thread_axis[0] == 'b') {
@@ -210,7 +210,7 @@ void StScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
const std::array<int, 3> kMaxBlockDims = cur_dev_info->GetMaxBlockDims();
const std::array<int, 3> kMaxGridDims = cur_dev_info->GetMaxGridDims();
auto check_offset = [&](const char& c) -> bool {
-    auto extent = loop.As<ir::For>()->extent.as_int32();
+    auto extent = loop.As<ir::For>()->extent.as_int64();
return extent <= (c == 'b' ? kMaxGridDims[offset] : kMaxBlockDims[offset]);
};
if (thread_axis[0] == 'b') {
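Note: check_offset compares a loop extent against the per-axis launch limit; reading the extent with as_int64 keeps extents above INT32_MAX from truncating before the comparison. A standalone sketch with assumed CUDA-like limits (the real values come from the device info at runtime):

```cpp
#include <array>
#include <cstdint>

// Assumed CUDA-like per-axis limits.
constexpr std::array<int64_t, 3> kMaxGridDims = {2147483647, 65535, 65535};
constexpr std::array<int64_t, 3> kMaxBlockDims = {1024, 1024, 64};

// Mirrors check_offset: 'b' selects grid limits, anything else block
// limits; a 64-bit extent compares correctly even past INT32_MAX.
bool CheckOffset(char c, int offset, int64_t extent) {
  return extent <= (c == 'b' ? kMaxGridDims[offset] : kMaxBlockDims[offset]);
}
```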
4 changes: 2 additions & 2 deletions paddle/cinn/ir/schedule/ir_schedule.cc
@@ -384,7 +384,7 @@ std::vector<Expr> IRSchedule::Split(const Expr& loop,
const std::vector<int>& factors) {
if (IsDynamicShape()) return impl_->Split(loop, factors);
std::vector<Expr> decision = SamplePerfectTile(
-      loop, factors.size(), loop.As<ir::For>()->extent.as_int32(), factors);
+      loop, factors.size(), loop.As<ir::For>()->extent.as_int64(), factors);
auto results = Split(loop, decision);
return results;
}
@@ -407,7 +407,7 @@ std::vector<Expr> IRSchedule::Split(const Expr& loop,
std::vector<int> int_factors;
std::vector<Expr> results;
std::for_each(factors.begin(), factors.end(), [&int_factors](const Expr& e) {
-    if (e.is_constant()) int_factors.push_back(e.as_int32());
+    if (e.is_constant()) int_factors.push_back(e.as_int64());
});
if (int_factors.size() == factors.size()) {
results = impl_->Split(loop, int_factors);
26 changes: 21 additions & 5 deletions paddle/cinn/ir/tensor.cc
@@ -227,12 +227,28 @@ isl::set _Tensor_::GenerateIslDomain() const {
auto _axis_with_reduce = axis_with_reduce();
for (int i = 0; i < domain.size(); i++) {
auto dim = domain[i];
-    if (dim.is_constant()) {
-      dims.emplace_back(_axis_with_reduce[i]->name, 0, dim.as_int32() - 1);
+    if (dim.type() == type_of<int64_t>()) {
+      if (dim.is_constant()) {
+        dims.emplace_back(_axis_with_reduce[i]->name,
+                          static_cast<int64_t>(0),
+                          static_cast<int64_t>(dim.as_int64() - 1));
+      } else {
+        dims.emplace_back(
+            _axis_with_reduce[i]->name,
+            Expr(static_cast<int64_t>(0)),
+            Sub::Make(dim,
+                      cinn::common::make_const(static_cast<int64_t>(1))));
+      }
     } else {
-      dims.emplace_back(_axis_with_reduce[i]->name,
-                        Expr(0),
-                        Sub::Make(dim, cinn::common::make_const(1)));
+      if (dim.is_constant()) {
+        dims.emplace_back(_axis_with_reduce[i]->name,
+                          static_cast<uint32_t>(0),
+                          dim.as_int32() - 1);
+      } else {
+        dims.emplace_back(_axis_with_reduce[i]->name,
+                          Expr(0),
+                          Sub::Make(dim, cinn::common::make_const(1)));
Review comment (Member): Should we make `make_const` a type template? Right now make_const seems to return only int32.

Author reply (Contributor): This code is kept for compatibility with the previous int32 path; if the shape property is passed through dim_expr, this branch is never entered.
+      }
}
}
}
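Note: GenerateIslDomain builds each axis domain as the closed range [0, extent - 1], taking a constant fast path when the extent is known and otherwise emitting `extent - 1` symbolically; the new branch repeats this at int64 width. A minimal sketch of the constant case:

```cpp
#include <cstdint>
#include <string>

// A closed ISL-style range [lower, upper] over one axis.
struct DimBound {
  std::string axis;
  int64_t lower;
  int64_t upper;
};

// A constant extent d yields [0, d - 1]; symbolic extents instead emit
// Sub::Make(dim, make_const(1)) in the real code.
DimBound MakeBound(const std::string& axis, int64_t extent) {
  return DimBound{axis, 0, extent - 1};
}
```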