Skip to content

Commit

Permalink
add to engine
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Apr 28, 2024
1 parent 1dceeb5 commit a9b0aae
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
13 changes: 12 additions & 1 deletion python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,18 @@ def _get_program_and_executor(self, cached_data):
):
pm = pir.PassManager()
for p in new_program._pass_opt['pass_list']:
pm.add_pass(p, {})
# Temporary implementation; this will be refined when auto_parallel is refactored
if p == 'eliminate_transpose':
from paddle.distributed.auto_parallel.static.pir_pass import (
eliminate_transpose_by_reshape,
)

for job_type in plan.job_types():
ir_program = plan.ir_program(job_type)
eliminate_transpose_by_reshape(ir_program)
else:
pm.add_pass(p, {})

for job_type in plan.job_types():
ir_program = plan.ir_program(job_type)
pm.run(ir_program)
Expand Down
1 change: 0 additions & 1 deletion python/paddle/distributed/auto_parallel/static/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,6 @@ def _parallel_pir(self, mode):
# Part 1: Complete program
# Step 1.1: Mix2Dense Pass
# TODO(JZ-LIANG): regularization pass with pass management.
print(mix_fw_program)
dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
mix_fw_program
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@
use_new_executor,
)

NEW_IR_PASS = [
PIR_PASS = [
'fused_gemm_epilogue_pass',
'fused_linear_param_grad_add_pass',
'fused_dropout_add_pass',
]

PIR_PYTHON_PASS = [
'eliminate_transpose',
]


class Parallelizer:
def __init__(self, mode, completer, dist_context):
Expand Down Expand Up @@ -513,13 +517,13 @@ def _apply_post_optimization(
ir_pass_list = []
if self.is_train and self._strategy.fused_passes.enable:
if len(self._strategy.fused_passes.fused_passes_list) > 0:
new_pass_list = []
program_pass_list = []
for p in self._strategy.fused_passes.fused_passes_list:
if p in NEW_IR_PASS and enable_ir:
if enable_ir and p in (PIR_PASS + PIR_PYTHON_PASS):
ir_pass_list.append(p)
else:
new_pass_list.append(new_pass(p))
pass_manager = PassManager(new_pass_list)
program_pass_list.append(new_pass(p))
pass_manager = PassManager(program_pass_list)
pass_manager.apply([main_program], [startup_program])

main_program._pass_opt = {}
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def apply_reshard_pass(program):
# We found that, when bs=1, which is the common case in LLM
# training, the transpose is equivalent to a reshape.
# So, this pass handles that specific case.
def replace_transpose_by_reshape(program):
def eliminate_transpose_by_reshape(program):
with paddle.static.program_guard(program):
for op in program.global_block().ops:
if op.name() == 'pd_op.transpose':
if op.name() == 'pd_op.transpose':
var = op.operand(0).source()
rank = len(var.shape)
perm = op.attrs()['perm']
Expand Down

0 comments on commit a9b0aae

Please sign in to comment.