Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:gouzil/Paddle into noqa_clean_17
Browse files Browse the repository at this point in the history
# Conflicts:
#	python/paddle/distributed/__init__.py
#	python/paddle/distributed/auto_parallel/api.py
  • Loading branch information
gouzil committed Dec 28, 2023
2 parents 23d0398 + fdc38b2 commit fff16b3
Show file tree
Hide file tree
Showing 326 changed files with 9,486 additions and 3,817 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231203")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20231225")
set(XPU_XHPC_BASE_DATE "20231226")
endif()
set(XPU_XCCL_BASE_VERSION "1.1.8.1")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
5 changes: 4 additions & 1 deletion paddle/cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ gather_srcs(
python_interpreter_guard.cc
nvgpu_dev_info.cc
integer_set.cc
dim_expr_simplify.cc)
dim_expr_simplify.cc
dim_expr_converter.cc)

cinn_cc_test(test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog)
Expand All @@ -49,4 +50,6 @@ endif()
if(NOT CINN_ONLY)
cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
cinncore)
cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
cinncore)
endif()
101 changes: 101 additions & 0 deletions paddle/cinn/common/dim_expr_converter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/ir_util.h"

namespace cinn::common {
using namespace symbol; // NOLINT

namespace {

// Visitor that lowers a symbolic shape expression (symbol::DimExpr) into an
// equivalent CINN ir::Expr. Dispatch over the DimExpr variant happens via
// std::visit; each operator() overload handles one alternative.
struct DimExprToIrExprVisitor {
  // Entry point: converts any DimExpr by visiting its active variant.
  ir::Expr ConvertToIrExpr(const DimExpr& dim_expr) {
    return std::visit(*this, dim_expr.variant());
  }

  // Integer constant -> 64-bit integer literal.
  ir::Expr operator()(const int64_t& dim) { return ir::Expr(dim); }

  // Symbolic dimension name -> int64 IR variable of the same name.
  ir::Expr operator()(const std::string& dim_expr) {
    Var x = ir::_Var_::Make(dim_expr, Int(64));
    return x;
  }

  // Negation is lowered as (0 - operand).
  ir::Expr operator()(const Negative<DimExpr>& dim_expr) {
    const auto& [operand] = *dim_expr;
    return ir::Sub::Make(ir::Expr(std::int64_t(0)), ConvertToIrExpr(operand));
  }

  // Reciprocal is lowered as (1 / operand).
  ir::Expr operator()(const Reciprocal<DimExpr>& dim_expr) {
    const auto& [operand] = *dim_expr;
    return ir::Div::Make(ir::Expr(std::int64_t(1)), ConvertToIrExpr(operand));
  }

  // Left-folds a non-empty operand list with the given binary node builder:
  // make_binary(make_binary(op0, op1), op2) ...
  template <typename OperandsT, typename MakeBinaryT>
  ir::Expr FoldOperands(const OperandsT& operands, MakeBinaryT&& make_binary) {
    CHECK(!operands->empty());
    ir::Expr acc = ConvertToIrExpr(operands->at(0));
    for (std::size_t i = 1; i < operands->size(); ++i) {
      acc = make_binary(acc, ConvertToIrExpr(operands->at(i)));
    }
    return acc;
  }

  // n-ary addition; an empty operand list denotes the additive identity 0.
  ir::Expr operator()(const Add<DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    if (operands->empty()) {
      return ir::Expr(std::int64_t(0));
    }
    return FoldOperands(operands, [](const ir::Expr& lhs, const ir::Expr& rhs) {
      return ir::Add::Make(lhs, rhs);
    });
  }

  // n-ary multiplication; an empty operand list denotes the identity 1.
  ir::Expr operator()(const Mul<DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    if (operands->empty()) {
      return ir::Expr(std::int64_t(1));
    }
    return FoldOperands(operands, [](const ir::Expr& lhs, const ir::Expr& rhs) {
      return ir::Mul::Make(lhs, rhs);
    });
  }

  // n-ary maximum; operands must be non-empty (checked in FoldOperands).
  ir::Expr operator()(const Max<DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    return FoldOperands(operands, [](const ir::Expr& lhs, const ir::Expr& rhs) {
      return ir::Max::Make(lhs, rhs);
    });
  }

  // n-ary minimum; operands must be non-empty (checked in FoldOperands).
  ir::Expr operator()(const Min<DimExpr>& dim_expr) {
    const auto& [operands] = dim_expr;
    return FoldOperands(operands, [](const ir::Expr& lhs, const ir::Expr& rhs) {
      return ir::Min::Make(lhs, rhs);
    });
  }

  // Broadcast has no ir::Expr counterpart; reaching here is a hard error.
  ir::Expr operator()(const Broadcast<DimExpr>& dim_expr) {
    LOG(FATAL)
        << "no support for converting from Broadcast<DimExpr> to ir::Expr";
    // Unreachable: LOG(FATAL) aborts. Keeps -Wreturn-type quiet when the
    // macro is not marked [[noreturn]] by the glog build in use.
    return ir::Expr();
  }
};

} // namespace

// Public entry point declared in dim_expr_converter.h: lowers a symbolic
// DimExpr into an equivalent ir::Expr by delegating to the file-local visitor.
ir::Expr DimExprConverter::ConvertToIrExpr(const DimExpr& dim_expr) const {
  DimExprToIrExprVisitor visitor{};
  return visitor.ConvertToIrExpr(dim_expr);
}

} // namespace cinn::common
26 changes: 26 additions & 0 deletions paddle/cinn/common/dim_expr_converter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/ir/ir.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace cinn::common {

// Converts a symbolic shape expression (symbol::DimExpr) into an equivalent
// CINN ir::Expr. Stateless and cheap to construct, so it can be created on
// the fly at each call site.
// NOTE(review): the implementation appears not to support Broadcast<DimExpr>
// inputs (it aborts at runtime) — confirm against dim_expr_converter.cc.
struct DimExprConverter final {
  ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const;
};

} // namespace cinn::common
79 changes: 79 additions & 0 deletions paddle/cinn/common/dim_expr_converter_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sstream>

#include "gtest/gtest.h"

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn::common::test {

using namespace symbol; // NOLINT

// An n-ary Add over two constants and a symbol must lower to left-nested
// ir::Add nodes equivalent to ((4 + 5) + sym_0).
TEST(Convert, AddExpr) {
  List<DimExpr> operands{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
  DimExpr sum{Add<DimExpr>{operands}};
  ir::Expr converted = DimExprConverter().ConvertToIrExpr(sum);

  ir::Expr constant_sum =
      ir::Add::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
  ir::Expr expected =
      ir::Add::Make(constant_sum, ir::_Var_::Make("sym_0", Int(64)));
  ASSERT_TRUE(MathEqual(converted, expected));
}

// Subtraction (4 - sym_0) must lower as an addition of a negated operand:
// 4 + (0 - sym_0).
TEST(Convert, SubExpr) {
  DimExpr difference = DimExpr(4) - DimExpr("sym_0");
  ir::Expr converted = DimExprConverter().ConvertToIrExpr(difference);

  ir::Expr negated = ir::Sub::Make(ir::Expr(std::int64_t(0)),
                                   ir::_Var_::Make("sym_0", Int(64)));
  ir::Expr expected = ir::Add::Make(ir::Expr(std::int64_t(4)), negated);
  ASSERT_TRUE(MathEqual(converted, expected));
}

// An n-ary Mul over two constants and a symbol must lower to left-nested
// ir::Mul nodes equivalent to ((4 * 5) * sym_0).
TEST(Convert, MulExpr) {
  List<DimExpr> operands{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
  DimExpr product{Mul<DimExpr>{operands}};
  ir::Expr converted = DimExprConverter().ConvertToIrExpr(product);

  ir::Expr constant_product =
      ir::Mul::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
  ir::Expr expected =
      ir::Mul::Make(constant_product, ir::_Var_::Make("sym_0", Int(64)));
  ASSERT_TRUE(MathEqual(converted, expected));
}

// Lowered Max expressions are verified via their printed IR form:
// a left-nested chain of cinn_max calls.
TEST(Convert, MaxExpr) {
  List<DimExpr> operands{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
  DimExpr max_expr{Max<DimExpr>{operands}};
  ir::Expr converted = DimExprConverter().ConvertToIrExpr(max_expr);

  std::ostringstream printed;
  printed << converted;
  ASSERT_EQ(printed.str(), "cinn_max(cinn_max(4ll, 5ll), sym_0)");
}

// Lowered Min expressions are verified via their printed IR form:
// a left-nested chain of cinn_min calls.
TEST(Convert, MinExpr) {
  List<DimExpr> operands{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
  DimExpr min_expr{Min<DimExpr>{operands}};
  ir::Expr converted = DimExprConverter().ConvertToIrExpr(min_expr);

  std::ostringstream printed;
  printed << converted;
  ASSERT_EQ(printed.str(), "cinn_min(cinn_min(4ll, 5ll), sym_0)");
}

} // namespace cinn::common::test
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
func : SliceRawInferMeta
kernel :
func : slice
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : uniform_random
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
Expand Down
40 changes: 20 additions & 20 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ namespace cinn {
namespace dialect {
namespace ir {

class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -48,7 +48,7 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_sum =
res.Op(cinn::dialect::ReduceSumOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -57,11 +57,11 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
}
};

class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
class MaxOpPattern : public paddle::drr::DrrPatternBase<MaxOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -73,7 +73,7 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMaxOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -82,11 +82,11 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
}
};

class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
class MinOpPattern : public paddle::drr::DrrPatternBase<MinOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -98,7 +98,7 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMinOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -107,11 +107,11 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
}
};

class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
class ProdOpPattern : public paddle::drr::DrrPatternBase<ProdOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -123,7 +123,7 @@ class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceProdOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand Down Expand Up @@ -552,11 +552,11 @@ class SplitWithNumOpPattern
}
};

class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand Down Expand Up @@ -585,7 +585,7 @@ class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
// int64_t[] shape, float min, float max, int seed, DataType dtype, int
// diag_num, int diag_step, float diag_val)
// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_uniform =
res.Op(cinn::dialect::UniformRandomOp::name(),
{{"shape", pattern.Attr("axis_info")},
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/ir/schedule/impl/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
for (auto factor : factors) prod_size = prod_size * Expr(factor);
std::for_each(factors.begin(), factors.end(), [&](int factor) {
if (factor == -1) {
process_factors.push_back(tot_extent / prod_size + Expr(1));
process_factors.push_back(
cinn::common::AutoSimplify(tot_extent / prod_size + Expr(1)));
} else {
process_factors.push_back(Expr(factor));
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/eager/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ endif()
cc_library(
eager_nan_inf_utils
SRCS nan_inf_utils.cc
DEPS phi common nan_inf_utils enforce)
DEPS phi common enforce)
cc_library(
grad_node_info
SRCS grad_node_info.cc
Expand Down
Loading

0 comments on commit fff16b3

Please sign in to comment.