From a9b0aae09511bcca60fe42e0e38bf25549140a3c Mon Sep 17 00:00:00 2001
From: zhiqiu
Date: Sun, 28 Apr 2024 17:43:46 +0800
Subject: [PATCH] add to engine

---
 python/paddle/base/executor.py                    | 13 ++++++++++++-
 .../distributed/auto_parallel/static/engine.py    |  1 -
 .../auto_parallel/static/parallelizer_v2.py       | 14 +++++++++-----
 .../distributed/auto_parallel/static/pir_pass.py  |  2 +-
 4 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py
index 2328d99d6fd72..7311991b021cd 100755
--- a/python/paddle/base/executor.py
+++ b/python/paddle/base/executor.py
@@ -1089,7 +1089,18 @@ def _get_program_and_executor(self, cached_data):
             ):
                 pm = pir.PassManager()
                 for p in new_program._pass_opt['pass_list']:
-                    pm.add_pass(p, {})
+                    # Temporary implementation; it will be refined when auto_parallel is refactored.
+                    if p == 'eliminate_transpose':
+                        from paddle.distributed.auto_parallel.static.pir_pass import (
+                            eliminate_transpose_by_reshape,
+                        )
+
+                        for job_type in plan.job_types():
+                            ir_program = plan.ir_program(job_type)
+                            eliminate_transpose_by_reshape(ir_program)
+                    else:
+                        pm.add_pass(p, {})
+
                 for job_type in plan.job_types():
                     ir_program = plan.ir_program(job_type)
                     pm.run(ir_program)
diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py
index f88ba0c769e86..4fe6aa0632605 100644
--- a/python/paddle/distributed/auto_parallel/static/engine.py
+++ b/python/paddle/distributed/auto_parallel/static/engine.py
@@ -632,7 +632,6 @@ def _parallel_pir(self, mode):
         # Part 1: Complete program
         # Step 1.1: Mix2Dense Pass
         # TODO(JZ-LIANG) regulization pass with pass management.
-        print(mix_fw_program)
         dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
             mix_fw_program
         )
diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
index e47c128c822a1..6109c0197cc42 100644
--- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -32,12 +32,16 @@
     use_new_executor,
 )
 
-NEW_IR_PASS = [
+PIR_PASS = [
     'fused_gemm_epilogue_pass',
     'fused_linear_param_grad_add_pass',
     'fused_dropout_add_pass',
 ]
 
+PIR_PYTHON_PASS = [
+    'eliminate_transpose',
+]
+
 
 class Parallelizer:
     def __init__(self, mode, completer, dist_context):
@@ -513,13 +513,13 @@ def _apply_post_optimization(
         ir_pass_list = []
         if self.is_train and self._strategy.fused_passes.enable:
             if len(self._strategy.fused_passes.fused_passes_list) > 0:
-                new_pass_list = []
+                program_pass_list = []
                 for p in self._strategy.fused_passes.fused_passes_list:
-                    if p in NEW_IR_PASS and enable_ir:
+                    if enable_ir and p in (PIR_PASS + PIR_PYTHON_PASS):
                         ir_pass_list.append(p)
                     else:
-                        new_pass_list.append(new_pass(p))
-                pass_manager = PassManager(new_pass_list)
+                        program_pass_list.append(new_pass(p))
+                pass_manager = PassManager(program_pass_list)
                 pass_manager.apply([main_program], [startup_program])
 
         main_program._pass_opt = {}
diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py
index 49a9208937e0d..c592a8c3518c5 100644
--- a/python/paddle/distributed/auto_parallel/static/pir_pass.py
+++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -164,10 +164,10 @@ def apply_reshard_pass(program):
 # We found that, when bs=1, which is the common case in LLM
 # training, the transpose is equivalent to a reshape.
reshape. # So, this pass is to haddle the specific case. -def replace_transpose_by_reshape(program): +def eliminate_transpose_by_reshape(program): with paddle.static.program_guard(program): for op in program.global_block().ops: - if op.name() == 'pd_op.transpose': + if op.name() == 'pd_op.transpose' or op.name() == 'pd_op.transpose': var = op.operand(0).source() rank = len(var.shape) perm = op.attrs()['perm']