Skip to content

Commit

Permalink
fix include for all pass
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Jan 15, 2024
1 parent bf06163 commit fe3aae1
Show file tree
Hide file tree
Showing 22 changed files with 66 additions and 78 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/drr/include/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <unordered_map>
#include <unordered_set>
#include <variant>
#include <vector>

#include "paddle/fluid/pir/drr/include/drr_match_context.h"

Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "paddle/fluid/pir/transforms/transform_general_functions.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h"

#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"

#include "paddle/common/ddim.h"

namespace {

class Conv2dAddActFusePattern
Expand Down
12 changes: 3 additions & 9 deletions paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"

namespace {

Expand Down
11 changes: 5 additions & 6 deletions paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h"

#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"

namespace {

class Conv2dBnFusePattern
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

// add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
6 changes: 1 addition & 5 deletions paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/common/ddim.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
9 changes: 1 addition & 8 deletions paddle/fluid/pir/transforms/identity_op_clean_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,13 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/identity_op_clean_pass.h"
#include <memory>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/common/ddim.h"

#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

Expand All @@ -39,24 +36,15 @@ class ReplaceFetchWithShadowOutputPattern
}
};

class ReplaceFetchWithShadowOutputPass : public pir::Pass {
class ReplaceFetchWithShadowOutputPass : public pir::PatternRewritePass {
public:
ReplaceFetchWithShadowOutputPass()
: pir::Pass("replace_fetch_with_shadow_output_pass", 0) {}
: pir::PatternRewritePass("replace_fetch_with_shadow_output_pass", 0) {}

bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<ReplaceFetchWithShadowOutputPattern>(context);
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void Run(pir::Operation* op) override {
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
auto [_, num_rewrites] = pir::ApplyPatternsGreedily(op, patterns_, cfg);
AddStatistics(num_rewrites);
return ps;
}

bool CanApplyOn(pir::Operation* op) const override {
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ void Walk(Operation *op,

template <WalkOrder Order = WalkOrder::PostOrder, typename FuncTy>
void Walk(Operation *op, FuncTy &&callback) {
return detail::Walk(op, callback, Order);
return Walk(op, callback, Order);
}

} // namespace detail

} // namespace pir
2 changes: 2 additions & 0 deletions paddle/pir/pass/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "paddle/pir/pass/pass_adaptor.h"
#include "paddle/pir/pass/pass_instrumentation.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace pir {

Expand Down
5 changes: 2 additions & 3 deletions paddle/pir/pass/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
#include <vector>

#include "paddle/common/enforce.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/pass/analysis_manager.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace pir {

Expand Down Expand Up @@ -197,7 +196,7 @@ class IR_API Pass {
std::unordered_map<std::string, std::function<void(void)>> attr_dels_;
};

class PatternRewritePass : public Pass {
class IR_API PatternRewritePass : public Pass {
public:
PatternRewritePass(const std::string& name,
uint8_t opt_level,
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/pattern_rewrite/pattern_applicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

#include <algorithm>

#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"

#include "paddle/pir/pattern_rewrite/pattern_match.h"

namespace pir {
Expand Down
5 changes: 4 additions & 1 deletion paddle/pir/pattern_rewrite/pattern_applicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@

#include "paddle/pir/core/op_info.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"

namespace pir {

class FrozenRewritePatternSet;
class RewritePattern;
class Pattern;

class PatternApplicator {
public:
using CostModel = std::function<PatternBenefit(const Pattern&)>;
Expand Down
16 changes: 16 additions & 0 deletions paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,20 @@ std::pair<bool, int64_t> ApplyPatternsGreedily(
return std::make_pair(converged, num_rewrites);
}

IR_API std::pair<bool, int64_t> ApplyPatternsGreedily(
Operation* op,
const FrozenRewritePatternSet& patterns,
GreedyRewriteConfig config) {
bool sum_converged = true;
int64_t sum_num_rewrites = 0;
for (uint32_t i = 0; i < op->num_regions(); ++i) {
Region& region = op->region(i);
auto [converged, num_rewrites] =
ApplyPatternsGreedily(region, patterns, config);
sum_converged &= converged;
sum_num_rewrites += num_rewrites;
}
return std::make_pair(sum_converged, sum_num_rewrites);
}

} // namespace pir
19 changes: 4 additions & 15 deletions paddle/pir/pattern_rewrite/pattern_rewrite_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/region.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"

namespace pir {

class FrozenRewritePatternSet;

/// This enum will control which ops will be added to the worklist during the
/// match rewrite process
enum class IR_API GreedyRewriteStrictness {
Expand Down Expand Up @@ -73,20 +73,9 @@ ApplyPatternsGreedily(Region& region, // NOLINT
GreedyRewriteConfig config = GreedyRewriteConfig());

/// Perform a match and rewrite process for all regions of a given op.
inline IR_API std::pair<bool, int64_t> ApplyPatternsGreedily(
IR_API std::pair<bool, int64_t> ApplyPatternsGreedily(
Operation* op,
const FrozenRewritePatternSet& patterns,
GreedyRewriteConfig config = GreedyRewriteConfig()) {
bool sum_converged = true;
int64_t sum_num_rewrites = 0;
for (uint32_t i = 0; i < op->num_regions(); ++i) {
Region& region = op->region(i);
auto [converged, num_rewrites] =
ApplyPatternsGreedily(region, patterns, config);
sum_converged &= converged;
sum_num_rewrites += num_rewrites;
}
return std::make_pair(sum_converged, sum_num_rewrites);
}
GreedyRewriteConfig config = GreedyRewriteConfig());

} // namespace pir

0 comments on commit fe3aae1

Please sign in to comment.