Commit 8280767

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lstm

Asthestarsfalll committed Apr 12, 2024
2 parents b683660 + 22fb3b1 commit 8280767
Showing 693 changed files with 11,969 additions and 5,426 deletions.
1 change: 1 addition & 0 deletions .gitmodules
@@ -121,3 +121,4 @@
 [submodule "third_party/nlohmann_json"]
 	path = third_party/nlohmann_json
 	url = https://github.com/nlohmann/json.git
+	ignore = dirty
5 changes: 5 additions & 0 deletions cmake/flags.cmake
@@ -212,6 +212,11 @@ if(NOT WIN32)
      -Wno-error=unused-function # Warnings in Numpy Header.
      -Wno-error=array-bounds # Warnings in Eigen::array
   )
+
+  if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+    set(GPU_COMMON_FLAGS -ccbin=${CMAKE_CXX_COMPILER} ${GPU_COMMON_FLAGS})
+  endif()
+
   if(NOT WITH_NV_JETSON
      AND NOT WITH_ARM
      AND NOT WITH_SW
32 changes: 26 additions & 6 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -82,7 +82,10 @@ pir::Block* GroupOp::block() {
 
 pir::Block* GroupOp::block() const {
   pir::Region& region = (*this)->region(0);
-  CHECK(!region.empty());
+  PADDLE_ENFORCE_EQ(region.empty(),
+                    false,
+                    ::common::errors::Unavailable(
+                        "Required GroupOp's region must not be empty."));
   return &region.front();
 }
 
@@ -155,7 +158,16 @@ pir::Block* FusionOp::block() {
   return &region.front();
 }
 
-std::vector<pir::Operation*> FusionOp::GetOperators() {
+pir::Block* FusionOp::block() const {
+  pir::Region& region = (*this)->region(0);
+  PADDLE_ENFORCE_EQ(region.empty(),
+                    false,
+                    ::common::errors::Unavailable(
+                        "Required FusionOp's region must not be empty."));
+  return &region.front();
+}
+
+std::vector<pir::Operation*> FusionOp::GetOperators() const {
   std::vector<pir::Operation*> rt_ops;
   for (auto& op : *block()) {
     rt_ops.push_back(&op);
@@ -304,7 +316,9 @@ void GenerateShapeOp::Build(
   if (inputs.empty()) {
     VLOG(3) << "GenerateShapeOp inputs is empty";
     for (const auto& attr : output_dim_exprs) {
-      CHECK(attr.isa<pir::Int64Attribute>());
+      PADDLE_ENFORCE(attr.isa<pir::Int64Attribute>(),
+                     ::common::errors::PreconditionNotMet(
+                         "Required attr must be Int64Attribute."));
     }
   }
   argument.AddInputs(inputs);
@@ -466,11 +480,15 @@ bool GenerateShapeOp::InferSymbolicShape(
   const auto attr_dim_exprs = [&] {
     std::vector<symbol::DimExpr> dim_exprs{};
     pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
-    CHECK(dim_expr_attr.isa<pir::ArrayAttribute>());
+    PADDLE_ENFORCE(dim_expr_attr.isa<pir::ArrayAttribute>(),
+                   ::common::errors::PreconditionNotMet(
+                       "Required dim_expr_attr is ArrayAttribute."));
    auto array = dim_expr_attr.dyn_cast<pir::ArrayAttribute>();
    for (int i = 0; i < array.size(); ++i) {
      const auto& dim_expr = ConvertAttributeToDimExpr(array.at(i));
-      CHECK(dim_expr.has_value());
+      PADDLE_ENFORCE(dim_expr.has_value(),
+                     ::common::errors::PreconditionNotMet(
+                         "Required dim_expr.has_value()==true."));
      dim_exprs.push_back(dim_expr.value());
    }
    return dim_exprs;
@@ -480,7 +498,9 @@ bool GenerateShapeOp::InferSymbolicShape(
         this->attributes().at("symbol_bindings");
     auto symbol_bindings =
         ConvertAttributeToSymbolBindings(symbol_bindings_attr);
-    CHECK(symbol_bindings.has_value());
+    PADDLE_ENFORCE(symbol_bindings.has_value(),
+                   ::common::errors::PreconditionNotMet(
+                       "Required symbol_bindings.has_value()==true."));
     return symbol_bindings.value();
   }();
   auto DimExprs4InputDim =
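The pattern running through this file replaces glog-style CHECK macros with PADDLE_ENFORCE / PADDLE_ENFORCE_EQ, which report failures through Paddle's typed, message-carrying errors instead of aborting the process. A minimal self-contained analogue of the idea, with a toy macro and names (not Paddle's actual implementation):

// Toy analogue of the CHECK -> PADDLE_ENFORCE_EQ migration shown above.
#include <sstream>
#include <stdexcept>

#define TOY_ENFORCE_EQ(lhs, rhs, msg)                            \
  do {                                                           \
    if ((lhs) != (rhs)) {                                        \
      std::ostringstream oss;                                    \
      oss << "Enforce failed: " #lhs " != " #rhs ". " << (msg);  \
      throw std::runtime_error(oss.str());                       \
    }                                                            \
  } while (0)

int main() {
  bool region_empty = true;
  try {
    TOY_ENFORCE_EQ(region_empty, false, "Required region must not be empty.");
  } catch (const std::runtime_error& e) {
    // The caller can log and recover; a failed CHECK() would have
    // terminated the whole process with no chance to handle the error.
  }
  return 0;
}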
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -77,7 +77,8 @@ class IR_API FusionOp : public pir::Op<FusionOp> {
                     const cinn::dialect::GroupInfo &group_info);
 
   pir::Block *block();
-  std::vector<pir::Operation *> GetOperators();
+  pir::Block *block() const;
+
+  std::vector<pir::Operation *> GetOperators() const;
 
   void VerifySig();
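The header now pairs block() with a const overload so read-only passes holding a const FusionOp can still traverse its block. A small sketch of the shallow-const handle pattern involved, using hypothetical types (pir op classes are thin views over a raw Operation*, so constness of the view does not propagate to the underlying IR):

// Hypothetical types sketching why a const op view can return a mutable Block*.
struct Block {};
struct Operation {
  Block block;
};

class OpView {
 public:
  explicit OpView(Operation* op) : op_(op) {}
  Block* block() { return &op_->block; }
  // const qualifies the view, not the Operation it points at, so the
  // returned Block* stays mutable -- mirroring the new block() const above.
  Block* block() const { return &op_->block; }

 private:
  Operation* op_;
};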
@@ -16,4 +16,4 @@ cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
 cc_library(
   add_cinn_pass
   SRCS add_cinn_pass.cc
-  DEPS op_dialect pir cinn_op_dialect cinnapi pir_transforms cinn_transforms)
+  DEPS pir_transforms cinn_transforms)
@@ -49,56 +49,6 @@ void VisitEachValue(const pir::Operation* op, const DoEachT& DoEach) {
   }
 }
 
-symbol::TensorShapeOrDataDimExprs SubstituteTensorShapeOrData(
-    const symbol::TensorShapeOrDataDimExprs& shape_or_data,
-    const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
-        substitution_pattern) {
-  auto SubstituteOneDimExpr =
-      [](const std::vector<symbol::DimExpr>& original_dim_expr,
-         const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
-             substitution_pattern) -> std::vector<symbol::DimExpr> {
-    std::vector<symbol::DimExpr> substituted_dim_expr{};
-    for (const symbol::DimExpr& dim_expr : original_dim_expr) {
-      const auto& tmp_dim_expr =
-          symbol::SubstituteDimExpr(dim_expr, substitution_pattern);
-      substituted_dim_expr.push_back(symbol::SimplifyDimExpr(tmp_dim_expr));
-    }
-    return substituted_dim_expr;
-  };
-
-  std::vector<symbol::DimExpr> substituted_shape =
-      SubstituteOneDimExpr(shape_or_data.shape(), substitution_pattern);
-  if (!shape_or_data.data().has_value()) {
-    return symbol::ShapeOrData<symbol::DimExpr>(substituted_shape);
-  } else {
-    std::vector<symbol::DimExpr> substituted_data = SubstituteOneDimExpr(
-        shape_or_data.data().value(), substitution_pattern);
-    return symbol::ShapeOrData<symbol::DimExpr>(substituted_shape,
-                                                substituted_data);
-  }
-}
-
-symbol::ShapeOrDataDimExprs SubstituteShapeOrData(
-    const symbol::ShapeOrDataDimExprs& shape_or_data,
-    const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
-        substitution_pattern) {
-  auto lambdas = symbol::Overloaded{
-      [&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) {
-        return symbol::ShapeOrDataDimExprs(SubstituteTensorShapeOrData(
-            tensor_shape_or_data, substitution_pattern));
-      },
-      [&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) {
-        symbol::TensorListShapeOrDataDimExprs substituted_tensor_list;
-        for (symbol::TensorShapeOrDataDimExprs tensor_shape_or_data :
-             tensor_list) {
-          substituted_tensor_list.push_back(SubstituteTensorShapeOrData(
-              tensor_shape_or_data, substitution_pattern));
-        }
-        return symbol::ShapeOrDataDimExprs(substituted_tensor_list);
-      }};
-  return std::visit(lambdas, shape_or_data.variant());
-}
-
 std::unordered_map<symbol::DimExpr, symbol::DimExpr> GetDimExprSubstitution(
     pir::ShapeConstraintIRAnalysis* shape_analysis) {
   const std::vector<symbol::DimExprConstraint>& dim_expr_constraints =
@@ -155,7 +105,8 @@ void SubstituteDimExprBasedOnConstraints(pir::Operation* region_op) {
         VLOG(8) << op->name()
                 << " origin_shape_or_data: " << origin_shape_or_data;
         const symbol::ShapeOrDataDimExprs& substituted_shape_or_data =
-            SubstituteShapeOrData(origin_shape_or_data, substitution_pattern);
+            symbol::SubstituteShapeOrData(origin_shape_or_data,
+                                          substitution_pattern);
         VLOG(8) << op->name()
                 << " substituted_shape_or_data: " << substituted_shape_or_data;
         shape_analysis->SetShapeOrDataForValue(value,
@@ -136,58 +136,13 @@ bool IsShapeOrDataNeedSubstitute(
   return ret;
 }
 
-symbol::TensorShapeOrDataDimExprs SubstituteTensorShapeOrData(
-    const symbol::TensorShapeOrDataDimExprs& shape_or_data,
-    const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
-  const auto& SimplifyDimExpr =
-      [&](const std::vector<symbol::DimExpr>& original_dim_expr)
-      -> std::vector<symbol::DimExpr> {
-    std::vector<symbol::DimExpr> simplified_dim_expr{};
-    for (const symbol::DimExpr& dim_expr : original_dim_expr) {
-      simplified_dim_expr.push_back(symbol::SimplifyDimExpr(
-          symbol::SubstituteDimExpr(dim_expr, dim_expr_map)));
-    }
-    return simplified_dim_expr;
-  };
-
-  std::vector<symbol::DimExpr> simplified_shape =
-      SimplifyDimExpr(shape_or_data.shape());
-  if (!shape_or_data.data().has_value()) {
-    return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape);
-  }
-  std::vector<symbol::DimExpr> simplified_data =
-      SimplifyDimExpr(shape_or_data.data().value());
-  return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape,
-                                              simplified_data);
-}
-
-symbol::ShapeOrDataDimExprs SubstituteShapeOrData(
-    const symbol::ShapeOrDataDimExprs& shape_or_data,
-    const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
-  auto lambdas = symbol::Overloaded{
-      [&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) {
-        return symbol::ShapeOrDataDimExprs(
-            SubstituteTensorShapeOrData(tensor_shape_or_data, dim_expr_map));
-      },
-      [&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) {
-        symbol::TensorListShapeOrDataDimExprs simplified_tensor_list;
-        for (symbol::TensorShapeOrDataDimExprs tensor_shape_or_data :
-             tensor_list) {
-          simplified_tensor_list.push_back(
-              SubstituteTensorShapeOrData(tensor_shape_or_data, dim_expr_map));
-        }
-        return symbol::ShapeOrDataDimExprs(simplified_tensor_list);
-      }};
-  return std::visit(lambdas, shape_or_data.variant());
-}
-
 symbol::ShapeOrDataDimExprs TrySubstitute(
     const symbol::ShapeOrDataDimExprs& shape_or_data,
     const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
   if (!IsShapeOrDataNeedSubstitute(shape_or_data, dim_expr_map)) {
     return shape_or_data;
   }
-  return SubstituteShapeOrData(shape_or_data, dim_expr_map);
+  return symbol::SubstituteShapeOrData(shape_or_data, dim_expr_map);
 }
 
 void InferSymbolicShapeForOperation(
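Both passes had been carrying private copies of SubstituteShapeOrData; the commit deletes the duplicates in favor of the shared symbol::SubstituteShapeOrData. A simplified, self-contained sketch of what the substitution does in the plain-shape case, with toy string-keyed dims standing in for the real symbol::DimExpr expression trees:

// Toy version of the shared substitution helper: apply a constraint-derived
// map over each symbolic dimension, leaving unmapped dims untouched.
#include <string>
#include <unordered_map>
#include <vector>

using ToyDimExpr = std::string;

std::vector<ToyDimExpr> SubstituteShape(
    const std::vector<ToyDimExpr>& shape,
    const std::unordered_map<ToyDimExpr, ToyDimExpr>& substitution) {
  std::vector<ToyDimExpr> result;
  result.reserve(shape.size());
  for (const ToyDimExpr& dim : shape) {
    auto it = substitution.find(dim);
    result.push_back(it == substitution.end() ? dim : it->second);
  }
  return result;
}

// A constraint S1 == S0 yields the map {"S1" -> "S0"}, so
// SubstituteShape({"S0", "S1", "64"}, {{"S1", "S0"}}) -> {"S0", "S0", "64"}.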
@@ -78,7 +78,8 @@ CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) {
 
 std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
     const OpLoweringGroupPtr& group) {
-  auto kernel_info = CompilationCache::Instance().GetKernelInfo(group);
+  hlir::framework::pir::FusionInfo fusion_info(*group);
+  auto kernel_info = CompilationCache::Instance().GetKernelInfo(fusion_info);
   std::unordered_map<std::string, ::pir::Attribute> attrs{
       {cinn::dialect::JitKernelOp::kAttrName,
        cinn::dialect::CINNKernelInfoAttribute::get(pir::IrContext::Instance(),
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/CMakeLists.txt
@@ -11,4 +11,5 @@ gather_srcs(
   trivial_op_impl.cc
   trivial_op_util.cc
   compilation_task.cc
-  compilation_cache.cc)
+  compilation_cache.cc
+  fusion_info.cc)
43 changes: 10 additions & 33 deletions paddle/cinn/hlir/framework/pir/compilation_cache.cc
@@ -39,38 +39,20 @@ void* BackendResource::GetInferFuncPtr() const {
   return ptr;
 }
 
-std::shared_ptr<backends::Compiler>& BackendResource::GetBackendCompiler() {
-  return backend_compiler_;
-}
-
-const std::shared_ptr<backends::Compiler>& BackendResource::GetBackendCompiler()
-    const {
-  return backend_compiler_;
-}
-
-void BackendResource::SetHostFnName(const std::string& name) {
-  host_fn_name_ = name;
-}
-
-void BackendResource::SetInferFnName(const std::string& name) {
-  infer_fn_name_ = name;
-}
-
-pir::CINNKernelInfo BackendResource::GernerateKernelInfo(
-    const std::shared_ptr<pir::OpLoweringGroup>& group) const {
+pir::CINNKernelInfo BackendResource::GenerateKernelInfo() const {
   pir::CINNKernelInfo kernel_info;
   kernel_info.fn_name = host_fn_name_;
   kernel_info.fn_ptr = GetHostFuncPtr();
   kernel_info.infer_shape_fn_ptr = GetInferFuncPtr();
-  kernel_info.int_args_map = group->int_args_map();
+  kernel_info.int_args_map = GetIntArgsMap();
   return kernel_info;
 }
 }  // namespace pir
 
 bool CompilationCache::Has(const CacheKey& key) const {
-  const bool has_existed = cache_.find(KeyHash(key)) != cache_.end();
-  VLOG(6) << "Check IsExisted in CompilationCache: " << key->FuncName() << " "
-          << has_existed;
+  const bool has_existed = cache_.find(key) != cache_.end();
+  VLOG(6) << "Check IsExisted in CompilationCache: " << has_existed << " - "
+          << key;
   return has_existed;
 }
 
@@ -79,24 +61,19 @@ const CompilationCache::CacheValue& CompilationCache::Get(
   PADDLE_ENFORCE_EQ(
       Has(key),
       true,
-      phi::errors::NotFound("%s is not in CompilationCache.", key->FuncName()));
-  return cache_.at(KeyHash(key));
+      phi::errors::NotFound("%s is not in CompilationCache.", key));
+  return cache_.at(key);
 }
 
 pir::CINNKernelInfo CompilationCache::GetKernelInfo(const CacheKey& key) const {
-  return Get(key)->GetKernelInfo(key);
+  return Get(key)->GetKernelInfo();
 }
 
 void CompilationCache::Insert(const CacheKey& key, const CacheValue& value) {
-  VLOG(6) << "Insert CompilationCache for: " << key->FuncName();
-  cache_.insert({KeyHash(key), value});
+  VLOG(6) << "Insert CompilationCache for: " << key;
+  cache_.insert({key, value});
}
 
 void CompilationCache::Clear() { cache_.clear(); }
 
-size_t CompilationCache::KeyHash(const CacheKey& key) const {
-  // TODO(Aurelius84): use a better hash function in next pr.
-  return std::hash<std::string>{}(key->FuncName());
-}
-
 }  // namespace cinn::hlir::framework
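The deleted KeyHash reduced every key to std::hash of the group's function name, so cache_ was keyed by size_t and two groups sharing a name could alias one cache slot. The map is now keyed by the CacheKey itself (a FusionInfo, per GetJitKernelAttr above), letting equality resolve hash collisions. A sketch of that keying pattern, with toy stand-ins for the CINN types:

// Toy stand-ins: the unordered_map hashes the key value and falls back to
// operator== on collision, so no hand-rolled KeyHash is needed.
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

struct ToyFusionInfo {
  std::string signature;  // stands in for the structural fusion signature
  bool operator==(const ToyFusionInfo& other) const {
    return signature == other.signature;
  }
};

template <>
struct std::hash<ToyFusionInfo> {
  std::size_t operator()(const ToyFusionInfo& info) const noexcept {
    return std::hash<std::string>{}(info.signature);
  }
};

// Distinct fusion groups that happen to share a kernel name can no longer
// silently alias the same cached entry.
using ToyCache = std::unordered_map<ToyFusionInfo, int>;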