diff --git a/paddle/cinn/backends/codegen_cuda_host.cc b/paddle/cinn/backends/codegen_cuda_host.cc
index 11e986bb9ace1..71b11f228acd6 100644
--- a/paddle/cinn/backends/codegen_cuda_host.cc
+++ b/paddle/cinn/backends/codegen_cuda_host.cc
@@ -222,7 +222,7 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
 llvm::Value* CodeGenCUDA_Host::LowerParseArgsValueCall(
     const ir::Call* call_ir) {
-  auto ret_type = CinnTypeToLLVMType(Int(32), m_);
+  auto ret_type = CinnTypeToLLVMType(Int(64), m_);
   std::vector<llvm::Type*> args_type;
   CHECK_EQ(call_ir->read_args.size(), 2);
   CHECK(call_ir->read_args[0].is_var() &&
diff --git a/paddle/cinn/backends/codegen_cuda_util.cc b/paddle/cinn/backends/codegen_cuda_util.cc
index 660eee9160a6b..1f9966b5b2881 100644
--- a/paddle/cinn/backends/codegen_cuda_util.cc
+++ b/paddle/cinn/backends/codegen_cuda_util.cc
@@ -114,15 +114,16 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
   for (int i = 0; i < args.size(); ++i) {
     if (args[i].is_var()) {
       ir::Expr call_get_value_in_kernel_args =
-          ir::Call::Make(Int(32),
+          ir::Call::Make(Int(64),
                          runtime::intrinsic::get_value_in_cuda_kernel_args,
                          {kernel_args_, ir::Expr(i)},
                          {},
                          ir::CallType::Extern,
                          ir::FunctionRef(),
                          0);
-      ir::Expr stmt = ir::Let::Make(ir::Expr(args[i].var_arg()),
-                                    call_get_value_in_kernel_args);
+      ir::Expr let_symbol = ir::Expr(args[i].var_arg());
+      let_symbol->set_type(type_of<int64_t>());
+      ir::Expr stmt = ir::Let::Make(let_symbol, call_get_value_in_kernel_args);
       arg_defs_.push_back(stmt);
     }
   }
diff --git a/paddle/cinn/backends/codegen_cuda_util.h b/paddle/cinn/backends/codegen_cuda_util.h
index 52296bd2a8807..01caff457a50c 100644
--- a/paddle/cinn/backends/codegen_cuda_util.h
+++ b/paddle/cinn/backends/codegen_cuda_util.h
@@ -152,7 +152,7 @@ struct CollectBucketStrategyHostFunctionVisitor
         kernel_args_(KERNEL_ARGS, type_of<void*>()),
         kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
         kernel_stream_(KERNEL_STREAM, type_of<void*>()),
-        tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}
+        tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int64_t**>()) {}
 
   std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
     ir::IRMutator<>::Visit(expr, expr);
diff --git a/paddle/cinn/common/ir_util.cc b/paddle/cinn/common/ir_util.cc
index 774d7514e6fb2..d326e652a7be7 100644
--- a/paddle/cinn/common/ir_util.cc
+++ b/paddle/cinn/common/ir_util.cc
@@ -143,8 +143,12 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
   VLOG(3) << "indices is : " << utils::Join(indices, ",");
   CHECK_LE(shape.size(), indices.size());
   Expr res;
+  ir::TryElevateInt32ToInt64(shape);
   for (int i = 0; i < shape.size(); i++) {
-    CHECK_EQ(shape[i].type(), Int(32));
+    CHECK(shape[i].type() == Int(64) || shape[i].type() == Int(32))
+        << "The shape data type currently supports only int32 or int64, but "
+           "the current data type of shape["
+        << i << "] is " << shape[i].type();
     Expr indice_prod = indices[i];
     optim::SimplifyCast(&indice_prod);
     for (int j = i + 1; j < shape.size(); j++) {
diff --git a/paddle/cinn/common/type.h b/paddle/cinn/common/type.h
index b11a320bbd5a1..420a31b5824c2 100644
--- a/paddle/cinn/common/type.h
+++ b/paddle/cinn/common/type.h
@@ -263,6 +263,12 @@ inline Type type_of<int32_t**>() {
   return x;
 }
 template <>
+inline Type type_of<int64_t**>() {
+  Type x = Int(64);
+  x.set_cpp_handle2();
+  return x;
+}
+template <>
 inline Type type_of<void*>() {
   Type x = type_of<void>();
   x.set_cpp_handle();
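A side note on the `type_of<int64_t**>()` specialization added above: CINN resolves IR types from C++ types at compile time through explicit function-template specializations, and each pointer level needs its own specialization. The sketch below is a minimal standalone analog of that pattern — the `Type` struct and its fields are hypothetical stand-ins, not CINN's real ones:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical IR type descriptor (CINN's real Type tracks much more).
struct Type {
  int bits = 0;
  int ptr_level = 0;  // 0: value, 1: T* (cpp_handle), 2: T** (cpp_handle2)
};

template <typename T>
Type type_of();  // primary template left undefined: unsupported types fail to link

template <>
Type type_of<int64_t>() { return Type{64, 0}; }

template <>
Type type_of<int64_t**>() {
  Type x = type_of<int64_t>();
  x.ptr_level = 2;  // plays the role of x.set_cpp_handle2() in CINN
  return x;
}

int main() {
  Type t = type_of<int64_t**>();
  std::cout << "bits=" << t.bits << " ptr_level=" << t.ptr_level << "\n";
  // prints: bits=64 ptr_level=2
}
```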
diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.cc b/paddle/cinn/hlir/framework/op_lowering_impl.cc
index 1b3a39850e2e4..cef5968639511 100644
--- a/paddle/cinn/hlir/framework/op_lowering_impl.cc
+++ b/paddle/cinn/hlir/framework/op_lowering_impl.cc
@@ -599,7 +599,7 @@ ir::Expr OpLowererImpl::DoGroupSchedule(
   auto master_loops = ir_sch.GetLoops(GetNodeData(master)->id());
   std::vector<int> splits;
   for (auto loop : master_loops) {
-    splits.push_back(loop.As<ir::For>()->extent.as_int32());
+    splits.push_back(loop.As<ir::For>()->extent.as_int64());
   }
   loops = ir_sch.GetLoops(GetNodeData(node)->id());
   ir_sch.Split(loops[0], splits);
diff --git a/paddle/cinn/hlir/framework/op_lowering_util.cc b/paddle/cinn/hlir/framework/op_lowering_util.cc
index 5a332324c7c89..a7b988a735cdb 100644
--- a/paddle/cinn/hlir/framework/op_lowering_util.cc
+++ b/paddle/cinn/hlir/framework/op_lowering_util.cc
@@ -1537,8 +1537,8 @@ void MergeReduceLoop(
   auto dst_loops = ir_sch.GetLoops(tensor_->name);
   auto src_loops = ir_sch.GetLoops(tensor__->name);
   int index = -1;
-  while (src_loops[index + 1].As<ir::For>()->extent.as_int32() ==
-         dst_loops[index + 1].As<ir::For>()->extent.as_int32()) {
+  while (src_loops[index + 1].As<ir::For>()->extent.as_int64() ==
+         dst_loops[index + 1].As<ir::For>()->extent.as_int64()) {
     ++index;
     if (src_loops.size() == index + 1 || dst_loops.size() == index + 1) {
       break;
     }
@@ -1661,8 +1661,8 @@ void LoopComputeAt(
   int index = std::min(node_loops.size(), master_loops.size()) - 1;
   do {
     // if loop range is not equal.
-    if (node_loops[index].As<ir::For>()->extent.as_int32() !=
-        master_loops[index].As<ir::For>()->extent.as_int32()) {
+    if (node_loops[index].As<ir::For>()->extent.as_int64() !=
+        master_loops[index].As<ir::For>()->extent.as_int64()) {
       continue;
     }
     MergeLoops(
diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
index 59ee965a4b91a..44f78f062874f 100644
--- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
+++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -494,7 +494,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
       }
       int_args_set.insert(symbol_name);
       group_func_args->emplace_back(
-          ir::_Var_::Make(symbol_name, cinn::common::Int(32)));
+          ir::_Var_::Make(symbol_name, cinn::common::Int(64)));
       group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx,
                                                    tensor_arg_dim_idx};
       VLOG(4) << "device kernel func's " << non_tensor_arg_idx << " is from "
@@ -860,7 +860,7 @@ ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc(
     int tensor_dim_size = tensor_dim.size();
     auto tensor_shape = group_func_arg_tensors[tensor_arg_idx]->shape;
 
-    ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int32_t**>());
+    ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int64_t**>());
     for (int i = 0; i < tensor_shape.size(); i++) {
       ir::Expr call_set_infer_shape_value =
           ir::Call::Make(type_of<void>(),
diff --git a/paddle/cinn/ir/buffer.cc b/paddle/cinn/ir/buffer.cc
index ada0d4487b7f0..350cde0189fdf 100644
--- a/paddle/cinn/ir/buffer.cc
+++ b/paddle/cinn/ir/buffer.cc
@@ -103,11 +103,15 @@ Var _Buffer_::buffer_addr() const {
   return _Var_::Make(name, thetype);
 }
 
-int _Buffer_::numel() const {
-  int res = 1;
+int64_t _Buffer_::numel() const {
+  int64_t res = 1;
   for (auto &i : shape) {
     CHECK(i.is_constant());
-    res *= i.as_int32();
+    if (i->type() == Int(64)) {
+      res *= i.as_int64();
+    } else {
+      res *= i.as_int32();
+    }
   }
   return res;
 }
diff --git a/paddle/cinn/ir/buffer.h b/paddle/cinn/ir/buffer.h
index 7e80b6de9297f..4b83a2bcd2e0f 100755
--- a/paddle/cinn/ir/buffer.h
+++ b/paddle/cinn/ir/buffer.h
@@ -141,7 +141,7 @@ class _Buffer_ : public ExprNode<_Buffer_> {
 
   void Verify() const override;
 
-  int numel() const;
+  int64_t numel() const;
 
   static const IrNodeTy _node_type_ = IrNodeTy::_Buffer_;
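The `numel()` widening above is the crux of the migration: a tensor only slightly larger than 2^31 elements silently overflows a 32-bit accumulator. A quick standalone check with a hypothetical 2048x2048x1024 shape:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2048, 2048, 1024};  // 2^32 elements in total
  int64_t numel = 1;
  for (int64_t d : shape) numel *= d;  // safe in 64-bit arithmetic
  std::cout << "numel = " << numel << "\n";  // 4294967296
  // The old `int numel` would have to hold 2^32 > INT32_MAX (2^31 - 1),
  // and signed overflow is undefined behavior in C++.
  std::cout << (numel <= std::numeric_limits<int32_t>::max()) << "\n";  // 0
}
```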
index 98ab391872091..fe63fb31158a9 100644
--- a/paddle/cinn/ir/dim.cc
+++ b/paddle/cinn/ir/dim.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/cinn/ir/dim.h"
+#include "paddle/cinn/common/dim_expr_converter.h"
 #include "paddle/cinn/ir/ir.h"
 
 namespace cinn {
@@ -31,12 +32,7 @@ Dim _Dim_::Make(const std::string& name, const symbol::DimExpr& sym_dim) {
   auto* n = make_shared<_Dim_>();
   n->name = name;
   n->sym_dim = sym_dim;
-  if (sym_dim.isa<std::string>()) {
-    n->dim_expr =
-        Expr(Var(sym_dim.dyn_cast<std::string>(), cinn::common::Int(32)));
-  } else {
-    n->dim_expr = Expr(static_cast<int32_t>(sym_dim.dyn_cast<int64_t>()));
-  }
+  n->dim_expr = common::DimExprConverter().ConvertToIrExpr(sym_dim);
   return Dim(n);
 }
diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc
index b556dad00cb32..d57344e77d238 100644
--- a/paddle/cinn/ir/ir.cc
+++ b/paddle/cinn/ir/ir.cc
@@ -58,6 +58,7 @@ Add::Add(Expr a, Expr b) : BinaryOpNode<Add>(a.type(), a, b) {}
 void BinaryNodeVerify(const Expr &a, const Expr &b, absl::string_view ir_name) {
   CHECK(a.defined());
   CHECK(b.defined());
+  TryElevateInt32ToInt64({a, b});
   CHECK_EQ(a.type(), b.type())
       << "The operands' types of the node [" << ir_name << "] don't match";
 }
@@ -72,9 +73,7 @@ Expr Sub::Make(Expr a, Expr b) {
 void Sub::Verify() const { BinaryNodeVerify(a(), b(), "Sub"); }
 
 Expr Mul::Make(Expr a, Expr b) {
-  CHECK(a.defined());
-  CHECK(b.defined());
-  CHECK_EQ(a.type(), b.type()) << "a=" << a << ", b=" << b;
+  BinaryNodeVerify(a, b, "Mul");
   auto node = make_shared<Mul>(a, b);
   return Expr(node);
 }
@@ -203,6 +202,7 @@ void Let::Verify() const {
   CHECK(symbol.defined());
   // The default value(contained in body) is not required.
   if (body.defined()) {
+    TryElevateInt32ToInt64({symbol, body});
     CHECK_EQ(symbol.type(), body.type());
   }
 }
@@ -583,7 +583,11 @@ Var &Var::operator=(const _Var_ *x) {
 Expr Load::Make(Expr tensor, const std::vector<Expr> &indices) {
   CHECK(tensor->type().valid());
   CHECK(!indices.empty());
-  for (auto &idx : indices) CHECK_EQ(idx.type().ElementOf(), Int(32));
+  TryElevateInt32ToInt64(indices);
+  for (auto &idx : indices) {
+    CHECK(idx.type().ElementOf() == Int(64) ||
+          idx.type().ElementOf() == Int(32));
+  }
   auto node = make_shared<Load>();
   node->tensor = tensor;
   node->indices = indices;
@@ -695,8 +699,13 @@ Expr Sum::Make(const std::vector<Expr> &vs) {
   if (vs.size() == 1) return vs.front();
 
   auto *n = make_shared<Sum>();
+  TryElevateInt32ToInt64(vs);
   auto type = vs.front().type();
-  for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v;
+  for (auto &v : vs) {
+    CHECK_EQ(v.type(), type) << "The operands' types of the node ["
+                             << n->node_type() << "] don't match: "
+                             << "(" << v << " vs " << vs.front() << ")";
+  }
 
   n->operands() = vs;
@@ -709,6 +718,7 @@ Expr Product::Make(const std::vector<Expr> &vs) {
   CHECK_GE(vs.size(), 1);
 
   auto *n = make_shared<Product>();
+  TryElevateInt32ToInt64(vs);
   auto type = vs.front().type();
   for (auto &v : vs) CHECK_EQ(v.type(), type);
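`BinaryNodeVerify`, `Let::Verify`, `Load::Make`, `Sum::Make`, and `Product::Make` above all now run `TryElevateInt32ToInt64` before their equal-type checks, so mixing an `Int(32)` constant with an `Int(64)` expression no longer trips the CHECK. A minimal sketch of the promotion rule on plain tags — the `Width` enum is a hypothetical stand-in; the real code mutates `ir::Expr` nodes in place:

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-in for a type's integer width.
enum class Width { I32, I64 };

// Two-pass rule, like TryElevateInt32ToInt64: first detect whether any
// operand is already 64-bit, then widen all remaining 32-bit operands.
void TryElevate(std::vector<Width>* ops) {
  bool has_i64 = false;
  for (Width w : *ops) has_i64 = has_i64 || (w == Width::I64);
  if (!has_i64) return;  // all int32: leave everything as-is
  for (Width& w : *ops) w = Width::I64;
}

int main() {
  std::vector<Width> ops = {Width::I32, Width::I64};
  TryElevate(&ops);
  std::cout << (ops[0] == Width::I64) << "\n";  // 1: the int32 operand widened
}
```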
diff --git a/paddle/cinn/ir/ir_base.cc b/paddle/cinn/ir/ir_base.cc
index ed1980511d686..b89342662eb7c 100644
--- a/paddle/cinn/ir/ir_base.cc
+++ b/paddle/cinn/ir/ir_base.cc
@@ -119,7 +119,7 @@ int32_t Expr::as_int32() const {
   return As<IntImm>()->value;
 }
 int64_t Expr::as_int64() const {
-  CHECK(type().is_int(64));
+  CHECK(type().is_int(64) || type().is_int(32));
   return As<IntImm>()->value;
 }
@@ -235,5 +235,41 @@ const Expr &IrNode::operand(int i) {
   return operands[i];
 }
 
+void IrNode::set_type(Type type) { type_ = type; }
+
+void IrNode::convert_int32_to_int64() {
+  CHECK(type_ == Int(64) || type_ == Int(32) || type_.is_unk())
+      << "Only conversion from int32_t to int64_t is supported, but the "
+         "current type is "
+      << type_;
+  type_ = Int(64);
+  for (Expr &operand : operands) {
+    operand->convert_int32_to_int64();
+  }
+}
+
+void TryElevateInt32ToInt64(const std::vector<Expr> &expr_vec) {
+  Type type = expr_vec.front()->type();
+  for (const Expr &expr : expr_vec) {
+    if (expr->type() == Int(64)) {
+      type = Int(64);
+      break;
+    }
+  }
+
+  // No need to elevate to Int(64).
+  if (type != Int(64)) {
+    return;
+  }
+  for (const Expr &expr : expr_vec) {
+    CHECK(expr->type() == Int(64) || expr->type() == Int(32) ||
+          expr->type().is_unk())
+        << "Only conversion from int32_t to int64_t is supported, but the "
+           "current type is "
+        << expr->type();
+    if (expr->type() == Int(32)) {
+      expr->convert_int32_to_int64();
+    }
+  }
+}
+
 }  // namespace ir
 }  // namespace cinn
diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h
index 0047100ebcfdf..24a7c2271d1fd 100644
--- a/paddle/cinn/ir/ir_base.h
+++ b/paddle/cinn/ir/ir_base.h
@@ -162,7 +162,9 @@ class IrNode : public cinn::common::Object {
   virtual IrNodeTy node_type() const { return IrNodeTy::kUnk; }
   virtual Type type() const { return type_; }
 
-  void set_type(Type type) { type_ = type; }
+  void set_type(Type type);
+  //! Elevate int32 to int64 if needed.
+  void convert_int32_to_int64();
 
   //! Get i-th operand
   const Expr& operand(int i);
@@ -502,6 +504,8 @@ Expr ExprNode<T>::Copy() const {
   return Expr();
 }
 
+void TryElevateInt32ToInt64(const std::vector<Expr>& expr_vec);
+
 }  // namespace ir
 }  // namespace cinn
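`convert_int32_to_int64` recurses through the operand list, so widening an outer node widens every nested operand as well. A toy tree version of that in-place mutation — the `Node` type is hypothetical; the real method also CHECKs that each node's type is int32, int64, or unknown before widening:

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical expression node; the real IrNode carries a Type and operands.
struct Node {
  int bits = 32;
  std::vector<std::shared_ptr<Node>> operands;
};

// Mirrors IrNode::convert_int32_to_int64: widen this node, then recurse.
void ConvertInt32ToInt64(Node* n) {
  n->bits = 64;
  for (auto& op : n->operands) ConvertInt32ToInt64(op.get());
}

int main() {
  auto add = std::make_shared<Node>();
  add->operands = {std::make_shared<Node>(), std::make_shared<Node>()};
  ConvertInt32ToInt64(add.get());
  std::cout << add->operands[0]->bits << "\n";  // 64: children widened too
}
```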
diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc
index d252a5e44954f..3537bfaf2fe4f 100644
--- a/paddle/cinn/ir/lowered_func.cc
+++ b/paddle/cinn/ir/lowered_func.cc
@@ -333,7 +333,8 @@ void _LoweredFunc_::PrepareArgumentExprs() {
 
     // cast arg to cinn_pod_value_t*
     // something like `_args[0]`
-    Expr load_expr = Load::Make(pod_value_ptr, {cinn::common::make_const(i)});
+    Expr load_expr = Load::Make(
+        pod_value_ptr, {cinn::common::make_const(static_cast<int64_t>(i))});
     CHECK_EQ(load_expr.type(), type_of<cinn_pod_value_t>());
     load_expr = ir::intrinsics::GetAddr::Make(load_expr);
@@ -404,6 +405,9 @@ void _LoweredFunc_::PrepareArgumentExprs() {
     } else if (arg.type() == type_of<int32_t>()) {
       pod_cast_expr =
           ir::intrinsics::PodValueToX::Make(load_expr, type_of<int32_t>());
+    } else if (arg.type() == type_of<int64_t>()) {
+      pod_cast_expr =
+          ir::intrinsics::PodValueToX::Make(load_expr, type_of<int64_t>());
     } else if (arg.type() == type_of<void*>()) {
       pod_cast_expr =
           ir::intrinsics::PodValueToX::Make(load_expr, type_of<void*>());
diff --git a/paddle/cinn/ir/schedule/impl/for_type.cc b/paddle/cinn/ir/schedule/impl/for_type.cc
index 6b045fcc2b342..bab2b312bde12 100644
--- a/paddle/cinn/ir/schedule/impl/for_type.cc
+++ b/paddle/cinn/ir/schedule/impl/for_type.cc
@@ -132,7 +132,7 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
   const std::array<int, 3> kMaxBlockDims = cur_dev_info->GetMaxBlockDims();
   const std::array<int, 3> kMaxGridDims = cur_dev_info->GetMaxGridDims();
   auto check_offset = [&](const char& c) -> bool {
-    auto extent = loop.As<ir::For>()->extent.as_int32();
+    auto extent = loop.As<ir::For>()->extent.as_int64();
     return extent <=
            (c == 'b' ? kMaxGridDims[offset] : kMaxBlockDims[offset]);
   };
   if (thread_axis[0] == 'b') {
@@ -210,7 +210,7 @@ void StScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
   const std::array<int, 3> kMaxBlockDims = cur_dev_info->GetMaxBlockDims();
   const std::array<int, 3> kMaxGridDims = cur_dev_info->GetMaxGridDims();
   auto check_offset = [&](const char& c) -> bool {
-    auto extent = loop.As<ir::For>()->extent.as_int32();
+    auto extent = loop.As<ir::For>()->extent.as_int64();
     return extent <=
            (c == 'b' ? kMaxGridDims[offset] : kMaxBlockDims[offset]);
   };
   if (thread_axis[0] == 'b') {
diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc
index fb151051f0b67..b4c44d062e47b 100644
--- a/paddle/cinn/ir/schedule/ir_schedule.cc
+++ b/paddle/cinn/ir/schedule/ir_schedule.cc
@@ -384,7 +384,7 @@ std::vector<Expr> IRSchedule::Split(const Expr& loop,
                                     const std::vector<int>& factors) {
   if (IsDynamicShape()) return impl_->Split(loop, factors);
   std::vector<Expr> decision = SamplePerfectTile(
-      loop, factors.size(), loop.As<ir::For>()->extent.as_int32(), factors);
+      loop, factors.size(), loop.As<ir::For>()->extent.as_int64(), factors);
   auto results = Split(loop, decision);
   return results;
 }
@@ -407,7 +407,7 @@ std::vector<Expr> IRSchedule::Split(const Expr& loop,
   std::vector<int> int_factors;
   std::vector<Expr> results;
   std::for_each(factors.begin(), factors.end(), [&int_factors](const Expr& e) {
-    if (e.is_constant()) int_factors.push_back(e.as_int32());
+    if (e.is_constant()) int_factors.push_back(e.as_int64());
   });
   if (int_factors.size() == factors.size()) {
     results = impl_->Split(loop, int_factors);
diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc
index 261db949b997b..f6897b81560dd 100644
--- a/paddle/cinn/ir/tensor.cc
+++ b/paddle/cinn/ir/tensor.cc
@@ -227,12 +227,28 @@ isl::set _Tensor_::GenerateIslDomain() const {
     auto _axis_with_reduce = axis_with_reduce();
     for (int i = 0; i < domain.size(); i++) {
       auto dim = domain[i];
-      if (dim.is_constant()) {
-        dims.emplace_back(_axis_with_reduce[i]->name, 0, dim.as_int32() - 1);
+      if (dim.type() == type_of<int64_t>()) {
+        if (dim.is_constant()) {
+          dims.emplace_back(_axis_with_reduce[i]->name,
+                            static_cast<int64_t>(0),
+                            static_cast<int64_t>(dim.as_int64() - 1));
+        } else {
+          dims.emplace_back(
+              _axis_with_reduce[i]->name,
+              Expr(static_cast<int64_t>(0)),
+              Sub::Make(dim,
+                        cinn::common::make_const(static_cast<int64_t>(1))));
+        }
       } else {
-        dims.emplace_back(_axis_with_reduce[i]->name,
-                          Expr(0),
-                          Sub::Make(dim, cinn::common::make_const(1)));
+        if (dim.is_constant()) {
+          dims.emplace_back(_axis_with_reduce[i]->name,
+                            static_cast<uint32_t>(0),
+                            dim.as_int32() - 1);
+        } else {
+          dims.emplace_back(_axis_with_reduce[i]->name,
+                            Expr(0),
+                            Sub::Make(dim, cinn::common::make_const(1)));
+        }
       }
     }
   }
diff --git a/paddle/cinn/optim/resize_buffer.cc b/paddle/cinn/optim/resize_buffer.cc
index dda54c44c0aad..f36eef0704946 100644
--- a/paddle/cinn/optim/resize_buffer.cc
+++ b/paddle/cinn/optim/resize_buffer.cc
@@ -125,10 +125,11 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> {
      for (int i = 0; i < indice_extent.size(); ++i) {
        if (stored_indice_extent[i].is_constant() &&
            indice_extent[i].is_constant()) {
-          int stored_extent = stored_indice_extent[i].as_int32();
-          int cur_extent = indice_extent[i].as_int32();
+          int64_t stored_extent = stored_indice_extent[i].as_int64();
+          int64_t cur_extent = indice_extent[i].as_int64();
          if (cur_extent > stored_extent) {
            stored_indice_extent[i] = ir::Expr(cur_extent);
+            stored_indice_extent[i]->set_type(indice_extent[i].type());
          }
        }
        // if there indice extent is not constant, which means dynamic shape
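The resize_buffer.cc change above compares extents in 64-bit and, when a larger constant extent wins, also copies the winning operand's type into the stored `Expr`. The same logic on plain integers, with hypothetical values:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a constant extent plus its IR type width.
struct Extent {
  int64_t value;
  int bits;  // 32 or 64, mirrors stored_indice_extent[i]->set_type(...)
};

int main() {
  std::vector<Extent> stored = {{8, 32}, {1LL << 33, 64}};
  std::vector<Extent> incoming = {{32, 32}, {16, 32}};
  for (size_t i = 0; i < stored.size(); ++i) {
    if (incoming[i].value > stored[i].value) {
      stored[i] = incoming[i];  // adopt both the larger extent and its type
    }
  }
  std::cout << stored[0].value << "/i" << stored[0].bits << " "
            << stored[1].value << "/i" << stored[1].bits << "\n";
  // prints: 32/i32 8589934592/i64
}
```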
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index 9923219235428..06da7f56c140a 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -408,7 +408,7 @@ class ReplaceVarToZero : public ir::IRMutator<> {
     auto var_name = for_ir->loop_var->name;
     auto extent_i = for_ir->extent;
 
-    if (extent_i.is_constant() && extent_i.as_int32() == 1)
+    if (extent_i.is_constant() && extent_i.as_int64() == 1)
       loop_var_.insert(var_name);
     ir::IRMutator<>::Visit(op, expr);
     loop_var_.erase(var_name);
diff --git a/paddle/cinn/optim/unroll_loops.cc b/paddle/cinn/optim/unroll_loops.cc
index 2bc7df1184477..9f2e8bf244e4c 100644
--- a/paddle/cinn/optim/unroll_loops.cc
+++ b/paddle/cinn/optim/unroll_loops.cc
@@ -65,7 +65,7 @@ struct UnrollMutator : public ir::IRMutator<Expr*> {
       VLOG(5) << "loop to be unrolled should have a constant extent";
       return;
     }
-    int extent = op->extent.as_int32();
+    int64_t extent = op->extent.as_int64();
 
     // predicate this for-loop can be unrolled by auto-unroll conditions
     bool unrollable =
@@ -109,7 +109,7 @@ struct UnrollMutator : public ir::IRMutator<Expr*> {
   int max_unroll_extent_ = 50;
 
   // the number of steps that have been unrolled or plain statement
-  int flat_step_ = 0;
+  int64_t flat_step_ = 0;
   // the number of nested loops not to be unrolled
   int not_unrolled_depth_ = 0;
 };
diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc
index c807e3210824d..69bd39f8a4c92 100644
--- a/paddle/cinn/optim/vectorize_loops.cc
+++ b/paddle/cinn/optim/vectorize_loops.cc
@@ -731,7 +731,8 @@ struct VectorizeLoops_ : public IRMutator<Expr*> {
       if (forloop->extent.As<ir::IntImm>()) {
        var_intervals.emplace(
            loopvar_name,
-            cinn::common::CasInterval{0, forloop->extent.as_int32() - 1});
+            cinn::common::CasInterval{static_cast<int64_t>(0),
+                                      forloop->extent.as_int64() - 1});
      } else {
        var_intervals.emplace(
            loopvar_name,
diff --git a/paddle/cinn/poly/dim.h b/paddle/cinn/poly/dim.h
index 5ae7ee7a897d6..15a3999cc7a3a 100644
--- a/paddle/cinn/poly/dim.h
+++ b/paddle/cinn/poly/dim.h
@@ -52,6 +52,10 @@ struct Dim {
   Dim(std::string id, uint32_t lower_bound, uint32_t upper_bound)
       : id(std::move(id)), lower_bound(lower_bound), upper_bound(upper_bound) {}
 
+  //! Construct a dimension with int64_t range.
+  Dim(std::string id, int64_t lower_bound, int64_t upper_bound)
+      : id(std::move(id)), lower_bound(lower_bound), upper_bound(upper_bound) {}
+
   //! Construct a dimension with expression range.
   Dim(std::string id, ir::Expr lower_bound, ir::Expr upper_bound);
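A note on why callers must now be explicit about integer width: with both a `uint32_t` and an `int64_t` constructor on `Dim`, a plain `int` argument converts equally well to either, so an unqualified call is ambiguous. That is why the tensor.cc hunk earlier spells out `static_cast<int64_t>(0)` and `static_cast<uint32_t>(0)`. Minimal reproduction with a stripped-down `Dim`:

```cpp
#include <cstdint>
#include <string>

// Stripped-down Dim mimicking the two constructors in paddle/cinn/poly/dim.h.
struct Dim {
  Dim(std::string, uint32_t, uint32_t) {}
  Dim(std::string, int64_t, int64_t) {}
};

int main() {
  // Dim d1("i", 0, 7);  // error: int -> uint32_t vs int -> int64_t is ambiguous
  Dim d2("i", static_cast<int64_t>(0), static_cast<int64_t>(7));    // 64-bit ctor
  Dim d3("i", static_cast<uint32_t>(0), static_cast<uint32_t>(7));  // 32-bit ctor
  (void)d2;
  (void)d3;
}
```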
diff --git a/paddle/cinn/poly/domain.cc b/paddle/cinn/poly/domain.cc
index c6f4479bf8bba..08b6da5ef0447 100644
--- a/paddle/cinn/poly/domain.cc
+++ b/paddle/cinn/poly/domain.cc
@@ -61,8 +61,35 @@ std::string Domain::__str__() const {
 }
 
 isl::set Domain::to_isl() const {
+  // TODO(6clc): will be removed in the future
   VLOG(3) << "isl::set " << __str__();
-  isl::set x(cinn::common::Context::isl_ctx(), __str__());
+  auto replace_substr = [](std::string& s,
+                           std::string const& toReplace,
+                           std::string const& replaceWith) {
+    std::string buf;
+    std::size_t pos = 0;
+    std::size_t prevPos = 0;
+
+    // Reserve a rough estimate of the final size of the string.
+    buf.reserve(s.size());
+
+    while (true) {
+      prevPos = pos;
+      pos = s.find(toReplace, pos);
+      if (pos == std::string::npos) break;
+      buf.append(s, prevPos, pos - prevPos);
+      buf += replaceWith;
+      pos += toReplace.size();
+    }
+
+    buf.append(s, prevPos, s.size() - prevPos);
+    s.swap(buf);
+  };
+
+  // ISL's parser rejects the "ll" suffix printed for int64 constants.
+  std::string isl_string = __str__();
+  replace_substr(isl_string, "(ll)", "");
+  replace_substr(isl_string, "ll", "");
+  isl::set x(cinn::common::Context::isl_ctx(), isl_string);
   return x;
 }
diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc
index c4f335603963b..a990192a1d1e6 100644
--- a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc
+++ b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -440,8 +440,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
       .SetRetType<void>()
       .AddInputType<int>()
       .AddInputType<int>()
-      .AddInputType<int32_t>()
-      .AddInputType<int32_t**>()
+      .AddInputType<int64_t>()
+      .AddInputType<int64_t**>()
       .End();
 
   using cinn::runtime::cuda::cinn_call_cuda_kernel;
diff --git a/paddle/cinn/runtime/cuda/cuda_util.cc b/paddle/cinn/runtime/cuda/cuda_util.cc
index 98ba1c52d7edc..a33427df4fce1 100644
--- a/paddle/cinn/runtime/cuda/cuda_util.cc
+++ b/paddle/cinn/runtime/cuda/cuda_util.cc
@@ -78,7 +78,7 @@ class CublasHandle {
   cublasHandle_t cuhandle;
 };
 
-int32_t cinn_get_value_in_cuda_kernel_args(void *v_args, int idx) {
+int64_t cinn_get_value_in_cuda_kernel_args(void *v_args, int idx) {
   cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
   return args[idx].operator int64_t();
 }
@@ -2748,7 +2748,7 @@ void cinn_gpu_cudnn_pool2d(const std::vector<int> &attrs,
   cudnnDestroyPoolingDescriptor(pooling_desc);
 }
 
-void infer_shape_set_value(int row, int col, int32_t value, int32_t **v) {
+void infer_shape_set_value(int row, int col, int64_t value, int64_t **v) {
   v[row][col] = value;
 }
 void cinn_gpu_cudnn_softmax(const std::vector<int> &attrs,
diff --git a/paddle/cinn/runtime/cuda/cuda_util.h b/paddle/cinn/runtime/cuda/cuda_util.h
index c7d9220e00688..3e8a93ecce4a8 100644
--- a/paddle/cinn/runtime/cuda/cuda_util.h
+++ b/paddle/cinn/runtime/cuda/cuda_util.h
@@ -95,8 +95,8 @@ void cinn_call_cuda_memcpy(void* v_args,
                            size_t count,
                            void* stream = nullptr);
 
-int32_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);
-void infer_shape_set_value(int row, int col, int32_t value, int32_t** v);
+int64_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);
+void infer_shape_set_value(int row, int col, int64_t value, int64_t** v);
 
 /**
  * Call a CUDA compiled kernel.
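What the `replace_substr` workaround in paddle/cinn/poly/domain.cc above is for: with 64-bit constants, the stringified domain contains an `ll` suffix (e.g. `1024ll`) that ISL's parser rejects, so the suffix is stripped before the `isl::set` is built. A standalone version of the helper with a tiny check (the sample domain string is hypothetical):

```cpp
#include <iostream>
#include <string>

// Free-function version of the replace_substr lambda in Domain::to_isl().
void ReplaceSubstr(std::string& s, const std::string& from, const std::string& to) {
  std::string buf;
  buf.reserve(s.size());  // rough estimate of the final size
  std::size_t pos = 0, prev = 0;
  while ((pos = s.find(from, prev)) != std::string::npos) {
    buf.append(s, prev, pos - prev);  // copy the part before the match
    buf += to;
    prev = pos + from.size();
  }
  buf.append(s, prev, s.size() - prev);  // copy the tail
  s.swap(buf);
}

int main() {
  std::string domain = "{ S[i] : 0 <= i < 1024ll }";
  ReplaceSubstr(domain, "(ll)", "");
  ReplaceSubstr(domain, "ll", "");
  std::cout << domain << "\n";  // { S[i] : 0 <= i < 1024 } -- now ISL-parsable
}
```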
diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc
index d8fd3db290b33..a88221bc23e8b 100644
--- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc
@@ -30,7 +30,7 @@ namespace paddle {
 namespace framework {
 
 typedef void (*lower_func_ptr_g)(void*, int32_t, void*);
-typedef void (*infer_shape_func_ptr_g)(void*, int32_t, int32_t**);
+typedef void (*infer_shape_func_ptr_g)(void*, int32_t, int64_t**);
 
 class CinnJitInstruction::FnPtrImpl {
   using CINNKernelInfo = cinn::hlir::framework::pir::CINNKernelInfo;
@@ -81,18 +81,18 @@ class CinnJitInstruction::FnPtrImpl {
     }
 
     // 3. Define an array of pointers to hold the output tensor shapes
-    int32_t* output_tensor_shapes[output_tensor_size];
+    std::vector<int64_t*> output_tensor_shapes(output_tensor_size);
     for (int i = 0; i < output_tensor_size; ++i) {
-      output_tensor_shapes[i] = reinterpret_cast<int32_t*>(
+      output_tensor_shapes[i] = reinterpret_cast<int64_t*>(
           malloc(kernel_args[input_tensor_size + i]->dims().size() *
-                 sizeof(int32_t*)));
+                 sizeof(int64_t*)));
     }
 
     // 4. Launch infer_shape_fn_ptr to infer shape of output tensor
     ((infer_shape_func_ptr_g)cinn_kernel_info_.infer_shape_fn_ptr)(
         static_cast<void*>(func_args_.data()),
         func_args_.size(),
-        output_tensor_shapes);
+        output_tensor_shapes.data());
 
     // 5. Resize shape of output tensor
     for (int i = 0; i < output_tensor_size; ++i) {
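Finally, the calling convention `cinn_jit_instruction.cc` uses after this change: the host allocates one `int64_t` array per output tensor, passes an `int64_t**` to the JIT-compiled infer-shape function, and reads the dims back out. A runnable mock of that flow — `MockInferShape` and its shape values are hypothetical; the real function pointer comes from `cinn_kernel_info_.infer_shape_fn_ptr`:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Same signature as the typedef in cinn_jit_instruction.cc.
typedef void (*infer_shape_func_ptr_g)(void*, int32_t, int64_t**);

// Stand-in for a JIT-compiled infer-shape function: writes the first
// output's dims into the caller-provided buffer.
void MockInferShape(void* /*args*/, int32_t /*num_args*/, int64_t** shapes) {
  shapes[0][0] = 16;
  shapes[0][1] = 128;
}

int main() {
  std::vector<int64_t> shape0(2);  // room for one output's 2 dims
  std::vector<int64_t*> output_tensor_shapes = {shape0.data()};
  infer_shape_func_ptr_g fn = &MockInferShape;
  fn(nullptr, 0, output_tensor_shapes.data());
  std::cout << shape0[0] << "x" << shape0[1] << "\n";  // 16x128
}
```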