[pir+auto parallel] add reshard op for input when needed #63072

Merged · 4 commits · Mar 29, 2024
7 changes: 7 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.cc
@@ -59,5 +59,12 @@ pir::Value reshard(const pir::Value& x,
return reshard_op.result(0);
}

pir::Value reshard(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr) {
auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build<ReShardOp>(
x, tensor_dist_attr);
return reshard_op.result(0);
}

} // namespace dialect
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.h
@@ -16,6 +16,7 @@

#include <vector>

#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
@@ -31,5 +32,9 @@ pir::Value shard_tensor(const pir::Value& x,
pir::Value reshard(const pir::Value& x,
const phi::distributed::ProcessMesh& process_mesh,
const std::vector<int64_t>& dims_mapping);

pir::Value reshard(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr);

} // namespace dialect
} // namespace paddle
12 changes: 7 additions & 5 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
@@ -43,11 +43,13 @@ common::DDim InferLocalDDim(const common::DDim& global_ddim,
TensorDistAttribute dist_attr) {
auto& mesh_dim = dist_attr.process_mesh_attr().shape();
auto& dim_mapping = dist_attr.dims_mapping();
PADDLE_ENFORCE_EQ(
global_ddim.size(),
dim_mapping.size(),
::common::errors::PreconditionNotMet(
"The global ddim size must equal to dim_mapping's size!"));
PADDLE_ENFORCE_EQ(global_ddim.size(),
dim_mapping.size(),
::common::errors::PreconditionNotMet(
"The global ddim size must equal to dim_mapping's "
"size, but bot %d vs %d",
global_ddim.size(),
dim_mapping.size()));
common::DDim local_ddim(global_ddim);
for (size_t i = 0; i < dim_mapping.size(); ++i) {
if (dim_mapping[i] != -1) {
13 changes: 13 additions & 0 deletions paddle/fluid/pybind/dist_api.cc
@@ -15,6 +15,7 @@
#include <Python.h>
#include "pybind11/stl.h"

#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pybind/dist_api.h"
#include "paddle/fluid/pybind/dist_static_op_function.h"
@@ -60,6 +61,10 @@ void BindTensorDistAttribute(py::module *m) {
print_stream << self;
return print_stream.str();
})
.def("__eq__",
[](TensorDistAttribute &self, const TensorDistAttribute &other) {
return self == other;
})
.def_property_readonly("process_mesh",
[](TensorDistAttribute &self) {
return self.process_mesh_attr().process_mesh();
@@ -86,12 +91,20 @@ void BindDistOpsAPI(pybind11::module *module) {
}
}

void BindOpsFunction(py::module *m) {
m->def("reshard_v2",
[](const pir::Value &x, const TensorDistAttribute &dist_attr) {
return reshard(x, dist_attr);
});
}

void BindDistApi(pybind11::module *module) {
auto ir_module = module->def_submodule("pir");
BindOperationDistAttribute(&ir_module);
BindTensorDistAttribute(&ir_module);
auto ops_modules = ir_module.def_submodule("ops");
BindDistOpsAPI(&ops_modules);
BindOpsFunction(&ops_modules);
}

} // namespace pybind
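
For reference, a minimal Python-side usage sketch of the reshard_v2 binding registered above, mirroring how this PR calls it later in pir_pass.py; the names op, operand, and operand_dist_attr are illustrative, not part of the PR:

import paddle

# op is a pir operation whose input needs resharding; operand_dist_attr
# comes from op.dist_attr().operand_dist_attrs().
paddle.pir.set_insertion_point(op)
reshard_value = paddle._pir_ops.reshard_v2(operand.source(), operand_dist_attr)
operand.set_source(reshard_value)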
1 change: 0 additions & 1 deletion paddle/fluid/pybind/dist_static_op_function.h
@@ -89,7 +89,6 @@ static PyMethodDef DistOpsAPI[] = {
(PyCFunction)(void (*)(void))static_api_reshard,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for reshard."},

{nullptr, nullptr, 0, nullptr}};

} // namespace pybind
9 changes: 6 additions & 3 deletions paddle/phi/infermeta/spmd_rules/elementwise.cc
@@ -365,14 +365,17 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,

SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& out_grad) {
return {{out_grad.dist_attr(), out_grad.dist_attr()}, {out_grad.dist_attr()}};
auto dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr());
dist_attr.set_dims_mapping(out_grad.dist_attr().dims_mapping());
return {{dist_attr, dist_attr}, {dist_attr}};
}

SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& out,
const DistMetaTensor& out_grad) {
return {{out_grad.dist_attr(), out_grad.dist_attr(), out_grad.dist_attr()},
{out_grad.dist_attr()}};
auto dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr());
dist_attr.set_dims_mapping(out_grad.dist_attr().dims_mapping());
return {{dist_attr, dist_attr, dist_attr}, {dist_attr}};
}

bool DimsNotEqualOrHasBroadcastDim(const DistMetaTensor& x,
3 changes: 2 additions & 1 deletion python/paddle/distributed/auto_parallel/static/engine.py
@@ -54,6 +54,7 @@
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .parallelizer_v2 import Parallelizer
from .pir_pass import apply_partition_pass
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group

@@ -675,7 +676,7 @@ def _parallel_pir(self, mode):
# TODO(JZ-LIANG) Step 3.1: Partition Pass
# insert reshard op if the operand tensor's placements are different from what the consumer op needs.
# Partition the computation graph into different pipeline stage if need.
# dist_program = apply_partition_pass(dist_program)
dist_program = apply_partition_pass(dist_program)

# TODO(hitywt) Step 3.2: Reshard Pass
# resolute the reshard op into special collective operation.
40 changes: 40 additions & 0 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -0,0 +1,40 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


def apply_partition_pass(program):
new_program = program.clone()
with paddle.static.program_guard(new_program):
for op in new_program.global_block().ops:
# assert len(op.operands()) == len(op.dist_attr().operand_dist_attrs()), f'The number of operand and operand_dist_attrs are not equal in op: {op}'
for var, operand_dist_attr in zip(
op.operands(), op.dist_attr().operand_dist_attrs()
):
if (
var.source().is_dist_dense_tensor_type()
Contributor (review comment):
In scenarios where src_dist_attr and dst_dist_attr have different meshes (e.g. pipeline parallelism), it would be better to insert two reshard ops:
one reshard op's mesh = src_dist_attr's mesh,
the other's mesh = dst_dist_attr's mesh.

Therefore, in the following (pipeline stage) pruning pass, each stage keeps the reshard op whose mesh it needs and removes the other one. (A sketch of this idea follows the pir_pass.py diff below.)

Contributor Author (reply):
It could be refined in the next PR.

and var.source().dist_attr() != operand_dist_attr
):
paddle.pir.set_insertion_point(op)
# insert reshard
reshard_var = paddle._pir_ops.reshard_v2(
var.source(), operand_dist_attr
)
var.set_source(reshard_var)
return new_program


def apply_reshard_pass(program):
pass
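
The cross-mesh refinement suggested in the review thread above is not implemented in this PR. Below is a minimal sketch of that idea under stated assumptions: it reuses only calls that already appear in this diff (reshard_v2, set_insertion_point, set_source) plus the process_mesh property bound in dist_api.cc; the helper name and the exact dist_attrs passed to each reshard are assumptions, not the eventual implementation.

import paddle


def insert_cross_mesh_reshards(op, operand, src_dist_attr, dst_dist_attr):
    # Sketch: when producer and consumer meshes differ (pipeline parallelism),
    # chain two reshard ops so the later pipeline-stage pruning pass can keep
    # the one whose mesh matches its own stage and drop the other.
    paddle.pir.set_insertion_point(op)
    if src_dist_attr.process_mesh != dst_dist_attr.process_mesh:
        # First reshard keeps the source mesh (retained by the producer stage).
        mid_value = paddle._pir_ops.reshard_v2(operand.source(), src_dist_attr)
        # Second reshard carries the destination mesh (retained by the consumer stage).
        new_value = paddle._pir_ops.reshard_v2(mid_value, dst_dist_attr)
    else:
        # Same mesh: a single reshard suffices, as apply_partition_pass does above.
        new_value = paddle._pir_ops.reshard_v2(operand.source(), dst_dist_attr)
    operand.set_source(new_value)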
9 changes: 6 additions & 3 deletions test/auto_parallel/pir/test_to_static_pir_program.py
@@ -66,6 +66,7 @@ def __init__(self, mesh):
)

def forward(self, x):
x.stop_gradient = False
Contributor (review comment):
No need to make x require a gradient; the relu_grad in the backward pass will trigger the partial-->replicated allreduce.

Contributor Author (reply):
It is needed; otherwise, relu_grad is not executed.

out = self.relu_0(x)  # trigger backward partial allreduce
out = self.linear_0(out)
out = self.relu_1(out)
@@ -138,6 +139,8 @@ def test_to_static_program(self):
backward_op_list = [
"pd_op.sgd_",
"pd_op.sgd_",
"pd_op.relu_grad",
"dist_op.reshard",
"pd_op.matmul_grad",
"pd_op.relu_grad",
"pd_op.matmul_grad",
@@ -225,10 +228,10 @@ def test_to_static_program(self):
tensor._local_shape, [BATCH_SIZE, CLASS_NUM]
)
elif matmul_grad_idx == 1:
self.assertEqual(tensor.dist_attr().dims_mapping, [-1, 0])
self.assertEqual(tensor.dist_attr().partial_dims, set())
self.assertEqual(tensor.dist_attr().dims_mapping, [-1, -1])
self.assertEqual(tensor.dist_attr().partial_dims, {0})
self.assertEqual(
tensor._local_shape, [BATCH_SIZE, IMAGE_SIZE // 2]
tensor._local_shape, [BATCH_SIZE, IMAGE_SIZE]
)
matmul_grad_idx += 1
if op.name() == 'pd_op.sgd_':