refine
wanghuancoder committed Feb 2, 2024
2 parents 58391e1 + 936324f commit eaf7d9d
Showing 342 changed files with 14,848 additions and 3,219 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20240104")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20240125")
set(XPU_XHPC_BASE_DATE "20240129")
endif()
set(XPU_XCCL_BASE_VERSION "1.1.8.1")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
2 changes: 1 addition & 1 deletion paddle/CMakeLists.txt
@@ -22,7 +22,7 @@ add_subdirectory(fluid)
# (2) calling `cc_test()` in each `CMakeLists.txt` will not `exactly` add test, but
# record all tests and their source files; the action of adding tests is deferred to HERE.
# Why doing so? since the target of `libpaddle.so` is mostly the last target, and
# the tests should be added after that accroding to dependency.
# the tests should be added after that according to dependency.
# (3) the tests links dynamic libraries, `libpaddle.so`
# (4) the tests are generated to the same directory, i.e., `CC_TESTS_DIR` defined above.

11 changes: 7 additions & 4 deletions paddle/cinn/backends/codegen_cuda_host.cc
@@ -270,6 +270,8 @@ llvm::Value* CodeGenCUDA_Host::LowerCUDAKernelCall(const ir::Call* call_ir) {
args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
} else if (r_arg.as_var()->type().is_int(32)) {
args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
} else if (r_arg.as_var()->type().is_int(64)) {
args_type.push_back(CinnTypeToLLVMType(type_of<int64_t>(), m_));
} else {
CINN_NOT_IMPLEMENTED;
}
@@ -316,10 +318,11 @@ llvm::Value* CodeGenCUDA_Host::LowerCUDAKernelCall(const ir::Call* call_ir) {
b_->getInt8PtrTy());
call_args.push_back(b_->CreateLoad(
b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load"));
} else if (r_arg.as_var()->type().is_cpp_handle() ||
r_arg.as_var()->type().is_int(32)) {
} else if (r_arg.as_var()->type().is_cpp_handle()) {
CHECK(global_args.count(r_arg.as_var()->name));
call_args.push_back(global_args[r_arg.as_var()->name]);
} else if (r_arg.as_var()->type().is_int()) {
call_args.push_back(GetVar(r_arg.as_var()->name, false));
} else {
CINN_NOT_IMPLEMENTED;
}
@@ -331,9 +334,9 @@
} else if (r_arg.type().is_int(16)) {
call_args.push_back(b_->getInt16(r_arg.as_int16()));
} else if (r_arg.type().is_int(32)) {
call_args.push_back(b_->getInt32(r_arg.as_int32()));
call_args.push_back(CodeGenLLVM::Visit(&r_arg));
} else if (r_arg.type().is_int(64)) {
call_args.push_back(b_->getInt64(r_arg.as_int64()));
call_args.push_back(CodeGenLLVM::Visit(&r_arg));
} else if (r_arg.type().is_uint(8)) {
call_args.push_back(b_->getInt8(r_arg.as_uint8()));
} else if (r_arg.type().is_uint(16)) {
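
This hunk widens the host-side lowering of CUDA kernel calls so int64 scalar arguments are accepted alongside int32: the argument type now maps to an LLVM i64, and integer values are materialized by visiting the expression rather than assuming a 32-bit immediate. A minimal sketch of the widened dispatch, assuming a plain llvm::IRBuilder; LowerScalarArgType is an illustrative name, not a CINN function:

    #include <stdexcept>
    #include "llvm/IR/IRBuilder.h"

    // Map the bit width of an integer kernel argument to an LLVM type;
    // the 64-bit case mirrors the newly supported branch above.
    llvm::Type* LowerScalarArgType(llvm::IRBuilder<>& b, int bit_width) {
      switch (bit_width) {
        case 32:
          return b.getInt32Ty();
        case 64:
          return b.getInt64Ty();
        default:
          throw std::runtime_error("unsupported integer width");
      }
    }
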
7 changes: 5 additions & 2 deletions paddle/cinn/common/dim_expr_converter.cc
@@ -28,8 +28,11 @@ struct DimExprToIrExprVisitor {
ir::Expr operator()(const int64_t& dim) { return ir::Expr(dim); }

ir::Expr operator()(const std::string& dim_expr) {
Var x = ir::_Var_::Make(ir::Expr(static_cast<int64_t>(0)),
ir::Expr(INT64_MAX),
// The dimension must be greater than or equal to 1. Because int32 is used
// extensively in CAS, the upper bound here is temporarily INT32_MAX;
// otherwise there is a risk of overflow.
Var x = ir::_Var_::Make(ir::Expr(static_cast<int64_t>(1)),
ir::Expr(INT32_MAX),
dim_expr,
/* is_reduce = */ false,
/* is_symbolic_constant = */ true);
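
The bounds for a symbolic dimension change here from [0, INT64_MAX] to [1, INT32_MAX]. A self-contained illustration of the overflow risk the new comment describes, assuming (as the comment states) that CAS range arithmetic runs in int32; this is not CINN code:

    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      // An INT64_MAX upper bound no longer survives a pass through int32:
      std::int64_t old_upper = std::numeric_limits<std::int64_t>::max();
      std::int32_t truncated = static_cast<std::int32_t>(old_upper);
      std::cout << truncated << "\n";  // prints -1 (two's-complement wrap)

      // INT32_MAX round-trips unchanged, hence the temporary upper bound.
      std::int32_t new_upper = std::numeric_limits<std::int32_t>::max();
      std::cout << new_upper << "\n";  // 2147483647
      return 0;
    }
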
12 changes: 6 additions & 6 deletions paddle/cinn/common/dim_expr_converter_test.cc
@@ -33,8 +33,8 @@ TEST(Convert, AddExpr) {
ir::Add::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
ir::Expr dst_expr =
ir::Add::Make(expr1,
ir::_Var_::Make(ir::Expr(static_cast<int64_t>(0)),
ir::Expr(INT64_MAX),
ir::_Var_::Make(ir::Expr(static_cast<int64_t>(1)),
ir::Expr(INT32_MAX),
"sym_0",
/* is_reduce = */ false,
/* is_symbolic_constant = */ true));
@@ -47,8 +47,8 @@ TEST(Convert, SubExpr) {

ir::Expr expr1 =
ir::Sub::Make(ir::Expr(std::int64_t(0)),
ir::_Var_::Make(ir::Expr(static_cast<int64_t>(0)),
ir::Expr(INT64_MAX),
ir::_Var_::Make(ir::Expr(static_cast<int64_t>(1)),
ir::Expr(INT32_MAX),
"sym_0",
/* is_reduce = */ false,
/* is_symbolic_constant = */ true));
@@ -65,8 +65,8 @@ TEST(Convert, MulExpr) {
ir::Mul::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
ir::Expr dst_expr =
ir::Mul::Make(expr1,
ir::_Var_::Make(ir::Expr(static_cast<int64_t>(0)),
ir::Expr(INT64_MAX),
ir::_Var_::Make(ir::Expr(static_cast<int64_t>(1)),
ir::Expr(INT32_MAX),
"sym_0",
/* is_reduce = */ false,
/* is_symbolic_constant = */ true));
9 changes: 5 additions & 4 deletions paddle/cinn/frontend/interpreter.h
@@ -32,8 +32,9 @@ namespace frontend {
*/
class Interpreter final {
public:
Interpreter(const std::vector<std::string>& input_names,
const std::vector<hlir::framework::shape_t>& input_shapes);
TEST_API Interpreter(
const std::vector<std::string>& input_names,
const std::vector<hlir::framework::shape_t>& input_shapes);

/**
* Load a Paddle model.
@@ -49,15 +50,15 @@ class Interpreter final {
/**
* Run the executor.
*/
void Run();
TEST_API void Run();

frontend::Program GetProgram();

hlir::framework::Tensor GetTensor(const std::string& name);

std::shared_ptr<hlir::framework::Scope> GetScope();

~Interpreter();
TEST_API ~Interpreter();

private:
class Impl;
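
The constructor, Run(), and destructor of Interpreter are now marked TEST_API so test binaries that link against the shared library can reach these symbols. As a hedged sketch only, such a macro is typically a symbol-visibility attribute along these lines; Paddle's actual definition lives elsewhere and may differ:

    // Hypothetical definition for illustration purposes only.
    #if defined(_WIN32)
    #define TEST_API __declspec(dllexport)
    #else
    #define TEST_API __attribute__((visibility("default")))
    #endif
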
6 changes: 6 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -128,6 +128,12 @@ void FusionOp::Print(pir::IrPrinter& printer) {
os << " \n }";
}

bool ConcatOp::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis* shape_analysis) {
VLOG(4) << "Infer symbolic shape for cinn_op.concat";
return ConcatOpInferSymbolicShape(this->operation(), shape_analysis);
}

void ConcatOp::Build(pir::Builder& builder, // NOLINT
pir::OperationArgument& argument, // NOLINT
const std::vector<pir::Value>& inputs,
5 changes: 4 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -67,7 +67,8 @@ class IR_API FusionOp : public pir::Op<FusionOp> {
void Print(pir::IrPrinter &printer); // NOLINT
};

class IR_API ConcatOp : public pir::Op<ConcatOp> {
class IR_API ConcatOp
: public pir::Op<ConcatOp, paddle::dialect::InferSymbolicShapeInterface> {
public:
using Op::Op;

@@ -83,6 +84,8 @@ class IR_API ConcatOp : public pir::Op<ConcatOp> {
int axis);

void VerifySig() const {}

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
};

class IR_API SplitOp : public pir::Op<SplitOp> {
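
With ConcatOp now declaring paddle::dialect::InferSymbolicShapeInterface (and implementing the InferSymbolicShape hook shown in manual_op.cc above), symbolic shape analysis can dispatch to it generically. A hedged caller-side sketch; apart from the interface and ShapeConstraintIRAnalysis types taken from this diff, the surrounding names are illustrative and the exact PIR dispatch API may differ:

    // Illustrative only: how an analysis might invoke the new hook.
    bool TryInferSymbolicShape(pir::Operation* op,
                               pir::ShapeConstraintIRAnalysis* shape_analysis) {
      if (auto interface =
              op->dyn_cast<paddle::dialect::InferSymbolicShapeInterface>()) {
        return interface.InferSymbolicShape(shape_analysis);  // handles cinn_op.concat
      }
      return false;  // op does not implement the interface
    }
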
16 changes: 10 additions & 6 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
@@ -85,15 +85,19 @@ void OperatorDialect::PrintAttribute(pir::Attribute attr,
}
}

void OperatorDialect::PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const {
pir::OpPrintFn OperatorDialect::PrintOperation(pir::Operation *op) const {
if (auto group_op = op->dyn_cast<GroupOp>()) {
group_op.Print(printer);
return [](pir::Operation *op, pir::IrPrinter &printer) {
auto group_op = op->dyn_cast<GroupOp>();
group_op.Print(printer);
};
} else if (auto fusion_op = op->dyn_cast<FusionOp>()) {
fusion_op.Print(printer);
} else {
printer.PrintGeneralOperation(op);
return [](pir::Operation *op, pir::IrPrinter &printer) {
auto fusion_op = op->dyn_cast<FusionOp>();
fusion_op.Print(printer);
};
}
return nullptr;
}

} // namespace dialect
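
PrintOperation now returns a pir::OpPrintFn instead of printing directly, and returns nullptr when the op needs no custom printing. Under the assumption that the framework falls back to the generic printer on a null callback, the consuming side might look like the hedged sketch below (the real call sites are inside PIR and are not part of this diff):

    // Illustrative caller: prefer the dialect-provided callback, otherwise
    // fall back to the generic textual form.
    void PrintWithDialectHook(pir::Operation* op, pir::IrPrinter& printer) {
      pir::OpPrintFn print_fn = op->dialect()->PrintOperation(op);
      if (print_fn) {
        print_fn(op, printer);  // GroupOp / FusionOp custom printing
      } else {
        printer.PrintGeneralOperation(op);
      }
    }
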
3 changes: 1 addition & 2 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.h
@@ -27,8 +27,7 @@ class OperatorDialect : public ::pir::Dialect {

void PrintType(pir::Type type, std::ostream& os) const override;
void PrintAttribute(pir::Attribute type, std::ostream& os) const override;
void PrintOperation(pir::Operation* op,
pir::IrPrinter& printer) const override; // NOLINT
pir::OpPrintFn PrintOperation(pir::Operation* op) const override; // NOLINT

private:
void initialize();
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -71,6 +71,7 @@
func : ReshapeInferMeta
kernel :
func : reshape
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : scale
args : (Tensor x, float scale=1.0, float bias=0.0, bool bias_after_scale=true)
@@ -26,6 +26,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
#include "paddle/pir/dialect/shape/utils/shape_analysis.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
@@ -58,14 +59,24 @@ std::vector<pir::Value> FindSourceDenseTensorOfDimTensor(
Visit(owner->operand_source(i));
}
};
const auto& IsDimTensor = [&](pir::Value value) -> bool {
return ShapeOrDataDimExprs4Value(value).data().has_value();
const auto& IsDimTensorOrListDimExpr = symbol::Overloaded{
[](const symbol::TensorShapeOrDataDimExprs& dim_expr) {
return dim_expr.data().has_value();
},
[](const symbol::TensorListShapeOrDataDimExprs& dim_expr) {
return true;
}};
// For the TensorListShapeOrDataDimExprs case, we should recursively visit
// each of its dim_exprs, which happens automatically in the next step.
const auto& NeedTrackUpstream = [&](pir::Value value) -> bool {
const auto& sym_shape = ShapeOrDataDimExprs4Value(value);
return std::visit(IsDimTensorOrListDimExpr, sym_shape.variant());
};
const auto& ForEachInputDimTensor =
[&](pir::Value value, const std::function<void(pir::Value)>& Visit) {
// find input dimension tensor;
ForEachInputValue(value, [&](pir::Value input) {
if (IsDimTensor(input)) {
if (NeedTrackUpstream(input)) {
Visit(input);
}
});
@@ -75,7 +86,7 @@ std::vector<pir::Value> FindSourceDenseTensorOfDimTensor(
size_t input_cnt = 0;
ForEachInputValue(value, [&](pir::Value input) {
++input_cnt;
if (IsDimTensor(input)) return;
if (NeedTrackUpstream(input)) return;
Emplace(input);
});
if (input_cnt == 0) {
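
NeedTrackUpstream now visits the ShapeOrDataDimExprs variant with symbol::Overloaded, so upstream tracking also covers TensorListShapeOrDataDimExprs rather than only single dim tensors that carry data. symbol::Overloaded follows the standard C++17 overloaded-lambda idiom; a self-contained sketch of that idiom with stand-in types (not the Paddle ones) is:

    #include <variant>
    #include <vector>

    // The classic overloaded-lambda visitor helper.
    template <class... Ts>
    struct Overloaded : Ts... {
      using Ts::operator()...;
    };
    template <class... Ts>
    Overloaded(Ts...) -> Overloaded<Ts...>;

    // Stand-ins for TensorShapeOrDataDimExprs / TensorListShapeOrDataDimExprs.
    using SingleShape = std::vector<int>;
    using ShapeList = std::vector<std::vector<int>>;

    bool NeedTrackUpstream(const std::variant<SingleShape, ShapeList>& shape_or_data) {
      return std::visit(
          Overloaded{[](const SingleShape& single) { return !single.empty(); },
                     [](const ShapeList&) { return true; }},  // always track lists
          shape_or_data);
    }
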
@@ -0,0 +1,144 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/check_infer_symbolic_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
#include "paddle/utils/flags.h"

namespace cinn {
namespace dialect {
namespace ir {

namespace {

std::string SprintShape(const std::vector<std::int64_t>& shape) {
std::string str = "[";
for (std::size_t i = 0; i < shape.size(); ++i) {
str += std::to_string(shape[i]);
// Append the separator after every element except the last one.
if (i + 1 < shape.size()) {
str += ", ";
}
}
return str + "]";
}

void PrintProgram(pir::ModuleOp m, const std::string& msg) {
std::ostringstream print_stream;
print_stream << "\n\n";
m.program()->Print(print_stream);
print_stream << "\n\n";
VLOG(4) << "===================== " << mgs << " =====================\n"
<< print_stream.str();
}

std::vector<std::int64_t> GetStaticValueShape(pir::Value value) {
const auto& dim = value.type().dyn_cast<::pir::DenseTensorType>().dims();
return ::common::vectorize(dim);
}

std::optional<std::vector<std::int64_t>> GetDynamicValueShape(
pir::Value value, const pir::ShapeConstraintIRAnalysis& shape_analysis) {
if (!shape_analysis.HasShapeOrDataForValue(value)) {
return std::nullopt;
}
const auto& dim_expr_dynamic_shapes =
shape_analysis.GetShapeOrDataForValue(value).shape();
std::vector<std::int64_t> dynamic_shapes{};
for (const auto& dim_expr_shape : dim_expr_dynamic_shapes) {
CHECK(dim_expr_shape.Has<std::int64_t>());
dynamic_shapes.push_back(dim_expr_shape.Get<std::int64_t>());
}
return dynamic_shapes;
}

void CompareStaticAndDynamicValueShape(
pir::Value value,
const pir::ShapeConstraintIRAnalysis& shape_analysis,
int op_index,
pir::ModuleOp module_op) {
std::vector<std::int64_t> static_value_shape = GetStaticValueShape(value);
std::optional<std::vector<std::int64_t>> opt_dynamic_value_shape =
GetDynamicValueShape(value, shape_analysis);
if (opt_dynamic_value_shape.has_value()) {
if (static_value_shape != opt_dynamic_value_shape.value()) {
VLOG(4) << "CheckInferSymbolic failed, in the fellowing program, the "
<< op_index
<< "th op : the shape is not equal\nthe static shape is: "
<< SprintShape(static_value_shape)
<< ", and the dynamic shape is: "
<< SprintShape(opt_dynamic_value_shape.value());
PrintProgram(module_op, "CheckInferSymbolic");
}
} else {
VLOG(4) << "CheckInferSymbolic failed, in the fellowing program, the "
<< op_index << "th op infer symbolic failed";
PrintProgram(module_op, "CheckInferSymbolic");
}
}

void CheckInferSymbolic(pir::ModuleOp module_op) {
VLOG(4) << "CheckInferSymbolic start";
int op_index = 0;
const auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(module_op.program());
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
for (const auto& block : module_op->region(i)) {
for (const auto& op : block) {
for (std::size_t j = 0; j < op.num_operands(); ++j) {
CompareStaticAndDynamicValueShape(
op.operand_source(j), shape_analysis, op_index, module_op);
}
for (std::size_t j = 0; j < op.num_results(); ++j) {
CompareStaticAndDynamicValueShape(
op.result(j), shape_analysis, op_index, module_op);
}
op_index++;
}
}
}
VLOG(4) << "CheckInferSymbolic end";
}

class CheckInferSymbolicPass : public pir::Pass {
public:
CheckInferSymbolicPass() : pir::Pass("check_infer_symbolic_pass", 1) {}

void Run(pir::Operation* op) override {
pir::ModuleOp module_op = op->dyn_cast<pir::ModuleOp>();
CheckInferSymbolic(module_op);
}

bool CanApplyOn(pir::Operation* op) const override {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}
};

} // namespace

std::unique_ptr<::pir::Pass> CreateCheckInferSymbolicPass() {
return std::make_unique<CheckInferSymbolicPass>();
}

} // namespace ir
} // namespace dialect
} // namespace cinn
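
A hedged usage sketch for the new checker: registering CreateCheckInferSymbolicPass with a pass manager. The pir::PassManager calls below are assumptions about PIR's pass API rather than something this diff adds, and because the pass only reports through VLOG(4), mismatches become visible only with verbose logging enabled (e.g. GLOG_v=4):

    // Illustrative wiring; exact PassManager construction may differ.
    pir::IrContext* ctx = pir::IrContext::Instance();
    pir::PassManager pass_manager(ctx);
    pass_manager.AddPass(cinn::dialect::ir::CreateCheckInferSymbolicPass());
    pass_manager.Run(&program);  // program: the pir::Program being checked
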