Add global to sub mesh reshard func for auto parallel. #64418

Merged 1 commit on May 20, 2024
@@ -937,6 +937,10 @@ def gen_infermeta_func_str(args, op_info):
         spmd_params = op_info.kernel_map['param']
     else:
         spmd_params = op_info.input_name_list
+    # TODO(GhostScreaming): specialized case for reshape_grad.
+    # xshape is not a kernel param, but InferSpmd needs it.
+    if "reshape_grad" in op_info.kernel_map['func'][0]:
+        spmd_params = ["xshape"] + spmd_params
     op_info.spmd_params = spmd_params

     infermeta_inputs_str = get_infermeta_inputs_str(
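For concreteness, a small sketch of what this special case does to spmd_params for reshape_grad; the kernel_map literal is illustrative, reconstructed from the yaml entry below:

    kernel_map = {'func': ['reshape_grad'], 'param': ['out_grad']}  # assumed shape
    spmd_params = kernel_map['param']              # ['out_grad']
    if "reshape_grad" in kernel_map['func'][0]:
        spmd_params = ["xshape"] + spmd_params     # ['xshape', 'out_grad']
    # The resulting order matches infer_meta's param [xshape, out_grad] in
    # ops_backward.yaml, so the generated InferSpmd call receives xshape first.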
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -701,6 +701,7 @@
   infer_meta :
     func : KernelWithXShapeInferMeta
     param : [xshape, out_grad]
+    spmd_rule : StaticReshapeGradInferSpmd
   kernel :
     func : reshape_grad
     param : [out_grad]
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/dist_api.cc
@@ -21,6 +21,7 @@
 #include "paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.h"
 #include "paddle/fluid/pybind/dist_api.h"
 #include "paddle/fluid/pybind/dist_static_op_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
 #include "paddle/phi/core/enforce.h"

 namespace py = pybind11;
@@ -40,6 +41,7 @@ struct type_caster<paddle::flat_hash_map<Key, Value, Hash, Equal, Alloc>>
 }  // namespace pybind11

 using paddle::dialect::OperationDistAttribute;
+using paddle::dialect::ProcessMeshAttribute;
 using paddle::dialect::TensorDistAttribute;

 namespace paddle {
@@ -122,6 +124,7 @@ OperationDistAttribute CreateOperationDistAttribute(
 void BindDistUtils(pybind11::module *m) {
   m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
   m->def("create_op_dist_attribute", CreateOperationDistAttribute);
+  m->def("get_sub_meshes", phi::distributed::GetSubMeshes);
   m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
 }
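The new get_sub_meshes binding is what the Python-side reshard pass consumes (see global_to_sub_mesh_func.py below). A hedged usage sketch, where mesh_attr stands for a pir process-mesh attribute taken from a dist program:

    import paddle

    # `mesh_attr` is assumed to describe a 2-D mesh, e.g. shape [2, 2]
    # over ranks [[0, 1], [2, 3]].
    sub_meshes = paddle.base.libpaddle.pir.get_sub_meshes(mesh_attr)
    # Presumably the meshes obtained by splitting along the leading mesh
    # dimension, here the 1-D meshes [0, 1] and [2, 3];
    # GlobaleToSubMeshFunction checks membership of the destination mesh
    # in this list.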
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -1121,6 +1121,14 @@ void BindValue(py::module *m) {
            return_value_policy::reference)
       .def("numel", [](Value self) { return phi::product(GetValueDims(self)); })
       .def("type", &Value::type)
+      .def("index",
+           [](Value self) -> uint32_t {
+             if (auto op_result = self.dyn_cast<OpResult>()) {
+               return op_result.index();
+             }
+             PADDLE_THROW(phi::errors::InvalidArgument(
+                 "Only support accessing index from an OpResult."));
+           })
       .def("is_dense_tensor_type",
           [](Value self) { return self.type().isa<DenseTensorType>(); })
       .def("is_selected_row_type",
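This mirrors how the reshard pass uses the new binding (names follow global_to_sub_mesh_func.py below; `value` and `new_dist_attr` are assumed):

    # index() gives the value's position among its defining op's results,
    # so the matching slot can be replaced when rebuilding the op's
    # dist attribute.
    op = value.get_defining_op()
    results = op.dist_attr.results()
    results[value.index()] = new_dist_attr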
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -351,5 +351,12 @@ SpmdInfo ReshapeGradInferSpmd(const DistMetaTensor& x_shape,
   return {{out_grad_dist_dst}, {x_shape_dist_dst}};
 }

+SpmdInfo StaticReshapeGradInferSpmd(const DistMetaTensor& x_shape,
+                                    const DistMetaTensor& out_grad) {
+  auto spmd_info = ReshapeGradInferSpmd(x_shape, out_grad);
+  spmd_info.first.insert(spmd_info.first.begin(), x_shape.dist_attr());
+  return spmd_info;
+}
+
 }  // namespace distributed
 }  // namespace phi
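The static variant exists because the static-graph infer_meta for reshape_grad takes [xshape, out_grad] (see the yaml change above), so the inferred input dist attrs need a leading xshape entry. A hedged sketch of the difference, using the names from the source:

    # ReshapeGradInferSpmd(x_shape, out_grad) returns, per the code above:
    #     ([out_grad_dist_dst], [x_shape_dist_dst])
    # StaticReshapeGradInferSpmd prepends x_shape.dist_attr() to the first
    # (inputs) list:
    #     ([x_shape.dist_attr(), out_grad_dist_dst], [x_shape_dist_dst])
    # which lines up with infer_meta's param order [xshape, out_grad].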
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reshape.h
@@ -35,5 +35,8 @@ SpmdInfo ReshapeInferSpmdDynamic(const DistMetaTensor& x,
 SpmdInfo ReshapeGradInferSpmd(const DistMetaTensor& x_shape,
                               const DistMetaTensor& out_grad);

+SpmdInfo StaticReshapeGradInferSpmd(const DistMetaTensor& x_shape,
+                                    const DistMetaTensor& out_grad);
+
 }  // namespace distributed
 }  // namespace phi
4 changes: 4 additions & 0 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -132,6 +132,10 @@ def apply_reshard_pass(program):
             not var.initialized() or var.dist_attr() == src_dist_attr
         ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"

+        if src_dist_attr == dst_dist_attr:
+            op.result(0).replace_all_uses_with(var)
+            op.erase()
+            continue
         reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
         assert (
             reshard_func is not None
56 additions: global_to_sub_mesh_func.py (new file)
@@ -0,0 +1,56 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from .base_reshard_func import ReshardFunction
+
+
+class GlobaleToSubMeshFunction(ReshardFunction):
+    def is_suitable(self, src_dist_attr, dst_dist_attr):
+        if 0 in src_dist_attr.dims_mapping or 0 in src_dist_attr.partial_status:
+            return False
+        in_mesh = src_dist_attr.process_mesh
+        out_mesh = dst_dist_attr.process_mesh
+        if in_mesh.ndim != out_mesh.ndim + 1:
+            return False
+        sub_meshes = paddle.base.libpaddle.pir.get_sub_meshes(in_mesh)
+        return out_mesh in sub_meshes
+
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
+        if src_value.has_one_use():
+            src_value.update_dist_attr(dst_dist_attr)
+            prev_op = src_value.get_defining_op()
+            op_dist_attr = prev_op.dist_attr
+            op_mesh = op_dist_attr.process_mesh
+            operands = op_dist_attr.operands()
+            results = op_dist_attr.results()
+            results[src_value.index()] = dst_dist_attr
+            prev_op.dist_attr = (
+                paddle.base.libpaddle.pir.create_op_dist_attribute(
+                    op_mesh, operands, results
+                )
+            )
+            return src_value
+        else:
+            dst_value = paddle._C_ops.share_data_(src_value)
+            share_data_op = dst_value.get_defining_op()
+            # set dist type and dist attr
+            dst_value.set_type(dst_type)
+            share_data_op.dist_attr = (
+                paddle.base.libpaddle.pir.create_op_dist_attribute(
+                    dst_dist_attr.process_mesh, [src_dist_attr], [dst_dist_attr]
+                )
+            )
+            return dst_value
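A hedged illustration of the pair of meshes this function targets, using the public ProcessMesh API (the dim names are illustrative):

    import paddle.distributed as dist

    # A global 2-D mesh and one of its 1-D sub-meshes (e.g. one pipeline stage).
    global_mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])
    stage0_mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

    # is_suitable should accept this pair: the rank condition holds and
    # stage0_mesh is one of global_mesh's sub-meshes, provided the value is
    # neither sharded nor partial on mesh dim 0.
    assert global_mesh.ndim == stage0_mesh.ndim + 1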
@@ -13,6 +13,7 @@
 # limitations under the License.

 from .base_reshard_func import register_reshard_func
+from .global_to_sub_mesh_func import GlobaleToSubMeshFunction
 from .nd_mesh_reshard_func import (
     NdMeshReshardFunction,
     NdMeshReshardFunctionCrossMesh,
@@ -42,6 +43,7 @@ def register_reshard_funcs():
     register_reshard_func(SToRReshardFunctionCrossMesh())
     register_reshard_func(NdMeshReshardFunction())
     register_reshard_func(NdMeshReshardFunctionCrossMesh())
+    register_reshard_func(GlobaleToSubMeshFunction())


 register_reshard_funcs()
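For context, a minimal sketch (not Paddle's actual implementation) of how such a registry is typically dispatched; choose_reshard_func, used in pir_pass.py above, plays this role:

    _reshard_funcs = []

    def register_reshard_func(func):
        _reshard_funcs.append(func)

    def choose_reshard_func(src_dist_attr, dst_dist_attr):
        # First registered function whose is_suitable() accepts the pair wins.
        for func in _reshard_funcs:
            if func.is_suitable(src_dist_attr, dst_dist_attr):
                return func
        return None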
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -91,6 +91,7 @@
         'all_used_ops',
         'append',
         'first_use',
+        'index',
         'get_defining_op',
         'has_one_use',
         'has_name',