remove unuseful c_allgather op for pir auto parallel. (#64465)
winter-wang authored May 21, 2024
1 parent fd0af9f commit cc38216
Showing 3 changed files with 45 additions and 9 deletions.
@@ -122,6 +122,22 @@ void VerifyDenseBlock(pir::Block* block) {
   }
 }
 
+void RemoveUnusefulCallgatherOp(pir::Block* block) {
+  std::vector<pir::Operation*> del_ops;
+  for (auto& op : *block) {
+    if (op.isa<CAllgatherOp>()) {
+      auto nrank = op.attribute<pir::Int32Attribute>("nranks").data();
+      if (nrank == 1) {
+        op.result(0).ReplaceAllUsesWith(op.operand_source(0));
+        del_ops.emplace_back(&op);
+      }
+    }
+  }
+  for (auto op : del_ops) {
+    op->Erase();
+  }
+}
+
 std::shared_ptr<pir::Program> DistToDensePass(pir::Program* prog) {
   if (FLAGS_print_ir) {
     VLOG(0) << "IR before DistToDense Pass = " << *prog;
@@ -135,6 +151,7 @@ std::shared_ptr<pir::Program> DistToDensePass(pir::Program* prog) {
   ctx->GetOrRegisterDialect<DistDialect>();
 
   ProcessDistBlock(new_prog->block());
+  RemoveUnusefulCallgatherOp(new_prog->block());
   VLOG(6) << "IR before VerifyDenseBlock Pass = " << *new_prog;
   VerifyDenseBlock(new_prog->block());
 
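A c_allgather whose nranks attribute is 1 gathers from a single rank, so it simply returns its input; the new pass therefore rewires every use of the op's result to the op's input and erases the op. A minimal Python sketch of the same rewrite on a toy IR is below; the Op and Value classes are hypothetical stand-ins for pir::Operation and pir::Value, not the real PIR API.

# Toy illustration of the RemoveUnusefulCallgatherOp rewrite above.  Op and
# Value are hypothetical stand-ins for pir::Operation / pir::Value.

class Value:
    def __init__(self):
        self.users = []


class Op:
    def __init__(self, name, inputs, attrs=None):
        self.name = name
        self.inputs = inputs
        self.attrs = attrs or {}
        self.result = Value()
        for v in inputs:
            v.users.append(self)


def remove_trivial_allgather(ops):
    """Drop c_allgather ops whose nranks == 1: they are identities."""
    kept = []
    for op in ops:
        if op.name == "c_allgather" and op.attrs.get("nranks") == 1:
            src = op.inputs[0]
            # Mirror result(0).ReplaceAllUsesWith(operand_source(0)):
            # every user of the allgather result now reads the allgather input.
            for user in op.result.users:
                user.inputs = [src if v is op.result else v for v in user.inputs]
                src.users.append(user)
            continue  # the op is erased by not being kept
        kept.append(op)
    return kept


# Usage: producer value -> single-rank c_allgather -> consumer op.
x = Value()
allgather = Op("c_allgather", [x], {"nranks": 1})
consumer = Op("relu", [allgather.result])
block = remove_trivial_allgather([allgather, consumer])
assert block == [consumer] and consumer.inputs == [x]
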
35 changes: 27 additions & 8 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -28,7 +28,7 @@
 from paddle.distributed import fleet
 from paddle.framework import (
     IrGraph,
-    _current_expected_place as _get_device,
+    _current_expected_place_ as _get_device,
     core,
     in_dynamic_mode,
 )
@@ -714,6 +714,10 @@ def _prepare_program(self, mode, init_parameters=True):
         # TODO(zhiqiu): fit the processes below for pir
         if self._in_pir_mode:
             self._parallel_pir(mode)
+            # Init comm
+            self._init_comm()
+            # startup program
+            self._initialize(mode, init_parameters)
             self._has_prepared[mode] = True
             return
         # Do the planning process
@@ -1024,22 +1028,35 @@ def _init_comm(self):
         for process_group in all_process_groups:
             process_group.instantiate()
 
-    def _init_lr(self):
+    def _init_lr(self, main_program):
         # hack to find learning_rate op
         lr_name = None
-        for op in self.main_program.global_block().ops:
+        for op in main_program.global_block().ops:
             if (
                 op.name() == "pd_op.data"
                 and 'learning_rate' in op.attrs()["name"]
             ):
                 lr_name = op.attrs()["name"]
                 break
+            if (
+                op.name() == "builtin.parameter"
+                and 'learning_rate' in op.attrs()["parameter_name"]
+            ):
+                lr_name = op.attrs()["parameter_name"]
+                break
 
         if lr_name is not None:
             buffer_tensor = global_scope().var(lr_name).get_tensor()
-            buffer_tensor.set(
-                np.float32(self._optimizer._learning_rate), self._place
-            )
+            from paddle.optimizer.lr import LRScheduler
+
+            if isinstance(self._optimizer._learning_rate, float):
+                buffer_tensor.set(
+                    np.float32(self._optimizer._learning_rate), self._place
+                )
+            elif isinstance(self._optimizer._learning_rate, LRScheduler):
+                buffer_tensor.set(
+                    np.float32(self._optimizer._learning_rate()), self._place
+                )
 
     def _initialize(self, mode, init_parameters=True):
         self._place = _get_device()
@@ -1058,10 +1075,12 @@ def _initialize(self, mode, init_parameters=True):
             # 6. vpp init adaption
 
             self.program_helper.init_pir(
-                self._pir_dense_main_progs[mode], self._place
+                self._pir_dist_main_progs[mode], self._place
             )
 
-            self._init_lr()
+            self._init_lr(self._pir_dense_main_progs[mode])
+            if self._executor is None:
+                self._executor = paddle.static.Executor(self._place)
             return
 
         if self._strategy.seed:
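In the new _init_lr, the learning-rate buffer is filled from either a plain float or an LRScheduler; a scheduler instance is called to obtain the value for the current step. A small sketch of that branch, assuming a local Paddle install (StepDecay is used here only as an example scheduler):

# Mirrors the float / LRScheduler branch added in _init_lr above.
import numpy as np
import paddle
from paddle.optimizer.lr import LRScheduler

def current_lr(learning_rate):
    # A plain float is used as-is; a scheduler is *called* for the current step's value.
    if isinstance(learning_rate, float):
        return np.float32(learning_rate)
    elif isinstance(learning_rate, LRScheduler):
        return np.float32(learning_rate())
    raise TypeError(f"unsupported learning rate type: {type(learning_rate)}")

print(current_lr(0.01))                                    # plain float
sched = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=10)
print(current_lr(sched))                                   # scheduler's current lr
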
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -56,7 +56,7 @@ def reshard_combine_value(op, operand, attr):
     reshard_vars = []
     for inner_operand, inner_attr in zip(combine_op.operands(), array_attr):
         reshard_vars.append(reshard_single_value(op, inner_operand, inner_attr))
-    paddle.pir.set_insertion_point(combine_op)
+    paddle.pir.set_insertion_point(op)
     return paddle._C_ops.builtin_combine(reshard_vars)
 
 
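The one-line change above moves the insertion point from the old combine_op to the consuming op, so the replacement builtin_combine is created immediately before its user rather than before the original combine. A toy model of insertion-point placement; Builder here is a hypothetical stand-in, not the paddle.pir API:

# New ops are placed right before the chosen anchor op.
class Builder:
    def __init__(self, ops):
        self.ops = ops          # the block, as an ordered list of op names
        self.anchor = None      # new ops go immediately before this one

    def set_insertion_point(self, anchor):
        self.anchor = anchor

    def create(self, name):
        self.ops.insert(self.ops.index(self.anchor), name)

block = ["combine_op", "op"]
b = Builder(block)
b.set_insertion_point("op")     # this commit: anchor at the consuming op
b.create("builtin_combine")
print(block)                    # ['combine_op', 'builtin_combine', 'op']
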
