From acec9b7a18cc976d1f2bead6f89630113cdfd1df Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Tue, 9 Apr 2024 16:46:18 +0800 Subject: [PATCH 1/3] add vjp interface for reshard op. --- .../pir/dialect/distributed/ir/dist_api.cc | 17 ++- .../pir/dialect/distributed/ir/dist_api.h | 8 +- .../dialect/distributed/ir/dist_dialect.cc | 2 +- .../pir/dialect/distributed/ir/dist_op.cc | 60 +++++++++-- .../pir/dialect/distributed/ir/dist_op.h | 15 ++- .../operator/utils/op_yaml_info_util.h | 5 +- paddle/fluid/pybind/dist_static_op_function.h | 17 ++- .../paddle/distributed/auto_parallel/api.py | 9 +- .../auto_parallel/static/engine.py | 1 + test/auto_parallel/pir/CMakeLists.txt | 1 + test/auto_parallel/pir/test_reshard.py | 101 ++++++++++++++++++ test/cpp/pir/distributed/dist_dialect_test.cc | 12 +-- 12 files changed, 215 insertions(+), 33 deletions(-) create mode 100644 test/auto_parallel/pir/test_reshard.py diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc index 3382fa18b9090..6ba2b16d00df2 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc @@ -45,23 +45,20 @@ pir::Value shard_tensor(const pir::Value& x, return shard_tensor_op.out(); } -pir::Value reshard(const pir::Value& x, - const phi::distributed::ProcessMesh& process_mesh, - const std::vector& dims_mapping) { +pir::Value reshard( + const pir::Value& x, + const phi::distributed::ProcessMesh& process_mesh, + const std::vector& dims_mapping, + const flat_hash_map& partial_status) { pir::IrContext* ctx = pir::IrContext::Instance(); - // TODO(ywt01) get partial_status by func parameter - paddle::flat_hash_map partial_status; TensorDistAttribute tensor_dist_attr = TensorDistAttribute::get(ctx, process_mesh, dims_mapping, partial_status); - - auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build( - x, tensor_dist_attr); - return reshard_op.result(0); + return reshard(x, tensor_dist_attr); } pir::Value reshard(const pir::Value& x, const TensorDistAttribute& tensor_dist_attr) { - auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build( + auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build( x, tensor_dist_attr); return reshard_op.result(0); } diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_api.h b/paddle/fluid/pir/dialect/distributed/ir/dist_api.h index 18aa1bb32ca64..5706afa63c165 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_api.h +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_api.h @@ -29,9 +29,11 @@ pir::Value shard_tensor(const pir::Value& x, const phi::distributed::ProcessMesh& process_mesh, const std::vector& dims_mapping); -pir::Value reshard(const pir::Value& x, - const phi::distributed::ProcessMesh& process_mesh, - const std::vector& dims_mapping); +pir::Value reshard( + const pir::Value& x, + const phi::distributed::ProcessMesh& process_mesh, + const std::vector& dims_mapping, + const flat_hash_map& partial_status = {}); pir::Value reshard(const pir::Value& x, const TensorDistAttribute& tensor_dist_attr); diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc index 0ea42bf6e093d..5834ba6262f3f 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc @@ -35,7 +35,7 @@ void DistDialect::initialize() { TensorDistAttribute, OperationDistAttribute>(); RegisterTypes(); - RegisterOps(); + 
RegisterOps(); } void DistDialect::PrintType(pir::Type type, std::ostream &os) const { diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc index cc06461e66d55..d419ea7d4d165 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h" #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -28,7 +29,7 @@ namespace paddle { namespace dialect { const char* ShardTensorOp::attributes_name[1] = {"op_dist_attr"}; -const char* ReShardOp::attributes_name[1] = {"op_dist_attr"}; +const char* ReshardOp::attributes_name[1] = {"op_dist_attr"}; void ShardTensorOp::VerifySig() { VLOG(4) @@ -159,8 +160,54 @@ void ShardTensorOp::Build(pir::Builder& builder, ::pir::PassStopGradientsDefaultly(argument); } -void ReShardOp::VerifySig() { - VLOG(4) << "Start Verifying inputs, outputs and attributes for: ReShardOp."; +OpInfoTuple ReshardOp::GetOpInfo() { + return OpInfoTuple( + {OpInputInfo()}, {}, {OpOutputInfo()}, OpRunTimeInfo(), "reshard"); +} + +std::vector> ReshardOp::Vjp( + pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + VLOG(6) << "Start call vjp for reshard op."; + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1, + common::errors::InvalidArgument("reshard op's inputs' size should be 1")); + PADDLE_ENFORCE_EQ(inputs_[0].size(), + 1, + common::errors::InvalidArgument( + "reshard op's inputs[0]'s size should be 1")); + auto dist_type = inputs_[0][0].type().dyn_cast(); + + PADDLE_ENFORCE_NOT_NULL( + dist_type, + common::errors::InvalidArgument( + "Currently, reshard op's inputs type must be dist type.")); + + PADDLE_ENFORCE_EQ(out_grads.size(), + 1, + common::errors::InvalidArgument( + "reshard op's outputs grad size should be 1")); + + PADDLE_ENFORCE_EQ(out_grads[0].size(), + 1, + common::errors::InvalidArgument( + "reshard op's outputs grad[0] size should be 1")); + + auto& builder = *ApiBuilder::Instance().GetBuilder(); + + auto grad_op = + builder.Build(out_grads[0][0], dist_type.tensor_dist_attr()); + + VLOG(6) << "End call vjp for reshard op."; + + return {std::vector{grad_op->result(0)}}; +} +void ReshardOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: ReshardOp."; VLOG(4) << "Verifying inputs:"; { auto input_size = num_operands(); @@ -224,11 +271,11 @@ void ReShardOp::VerifySig() { VLOG(4) << "End Verifying for: ShardTensorOp."; } -void ReShardOp::Build(pir::Builder& builder, +void ReshardOp::Build(pir::Builder& builder, pir::OperationArgument& argument, pir::Value input, TensorDistAttribute tensor_dist_attr) { - VLOG(4) << "Start build ReShardOp"; + VLOG(4) << "Start build ReshardOp"; paddle::dialect::DistDenseTensorType input_tensor_type; if (input.type().isa()) { @@ -270,10 +317,11 @@ void ReShardOp::Build(pir::Builder& builder, tensor_dist_attr, local_shape); argument.AddOutput(out_dist_tensor_type); + ::pir::PassStopGradientsDefaultly(argument); } } // namespace dialect } // namespace paddle 
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp) -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReShardOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_op.h b/paddle/fluid/pir/dialect/distributed/ir/dist_op.h index 7ae81a0040702..638fb430eaf4e 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_op.h +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_op.h @@ -15,6 +15,8 @@ #pragma once #include +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/vjp.h" #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/op_base.h" @@ -39,7 +41,7 @@ class ShardTensorOp : public pir::Op { void VerifySig(); }; -class ReShardOp : public pir::Op { +class ReshardOp : public pir::Op { public: using Op::Op; static const char* name() { return "dist_op.reshard"; } @@ -49,10 +51,19 @@ class ReShardOp : public pir::Op { pir::OperationArgument& argument, // NOLINT pir::Value input, TensorDistAttribute tensor_dist_attr); + + static OpInfoTuple GetOpInfo(); + static std::vector> Vjp( + pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, + const std::vector>& out_grads, + const std::vector>& stop_gradients); + void VerifySig(); }; } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReShardOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp) diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h index 86370dd0cc6c1..e8719d4adb73e 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h @@ -98,8 +98,9 @@ struct OpRunTimeInfo { std::vector skip_transform_inputs; pir::AttributeMap extra_args_default_value; std::vector data_format_tensors; - bool is_onednn_only; - bool dynamic_fallback; + bool is_onednn_only = false; + bool dynamic_fallback = false; + OpRunTimeInfo() = default; OpRunTimeInfo(const std::string& infer_meta_func, const std::vector& infer_meta_param, diff --git a/paddle/fluid/pybind/dist_static_op_function.h b/paddle/fluid/pybind/dist_static_op_function.h index afd71b7521567..5d5ef704a5284 100644 --- a/paddle/fluid/pybind/dist_static_op_function.h +++ b/paddle/fluid/pybind/dist_static_op_function.h @@ -69,9 +69,22 @@ static PyObject *static_api_reshard(PyObject *self, PyObject *dims_mapping_obj = PyTuple_GET_ITEM(args, 2); auto dims_mapping = CastPyArg2VectorOfInt64(dims_mapping_obj, 2); + PyObject *placements_obj = PyTuple_GET_ITEM(args, 3); + auto placements = CastPyArg2VectorOfPlacement(placements_obj, 3); + + paddle::flat_hash_map partial_status; + for (size_t i = 0; i < placements.size(); ++i) { + auto &p = placements[i]; + if (p->is_partial()) { + partial_status.insert( + {i, + dynamic_cast(*p).get_reduce_type()}); + } + } + // Call ir static api - auto static_api_out = - paddle::dialect::reshard(input, process_mesh, dims_mapping); + auto static_api_out = paddle::dialect::reshard( + input, process_mesh, dims_mapping, partial_status); return ToPyObject(static_api_out); } catch (...) 
{ diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index f7c69e1fe6464..4c675946373a6 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -28,12 +28,16 @@ EagerParamBase, Variable, default_main_program, + in_pir_mode, ) from paddle.distributed.auto_parallel import Engine, strategy as auto_strategy from paddle.distributed.auto_parallel.interface import ( shard_tensor as shard_tensor_static, ) -from paddle.distributed.auto_parallel.placement_type import to_placements +from paddle.distributed.auto_parallel.placement_type import ( + to_dim_map, + to_placements, +) from paddle.distributed.auto_parallel.static.completion import ( mark_as_sharding_propagation_skip_op, ) @@ -388,6 +392,9 @@ def reshard(dist_tensor, mesh, placements): dist_attr._set_partial_dims(partial_dims) return paddle.base.core.reshard(dist_tensor, dist_attr) + elif in_pir_mode(): + dim_map = to_dim_map(placements, dist_tensor.ndim) + return paddle._pir_ops.reshard(dist_tensor, mesh, dim_map, placements) else: assert isinstance( dist_tensor, Variable diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index b5907242cacf8..2bda746dfa92b 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -712,6 +712,7 @@ def _prepare_program(self, mode, init_parameters=True): # TODO(zhiqiu): fit the processes below for pir if self._in_pir_mode: self._parallel_pir(mode) + self._has_prepared[mode] = True return # Do the planning process self._plan(mode) diff --git a/test/auto_parallel/pir/CMakeLists.txt b/test/auto_parallel/pir/CMakeLists.txt index a6de706d70871..c292a78517de0 100644 --- a/test/auto_parallel/pir/CMakeLists.txt +++ b/test/auto_parallel/pir/CMakeLists.txt @@ -12,6 +12,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_pir_mse_spmd MODULES test_mse_spmd_rule ENVS FLAGS_enable_pir_api=1) py_test_modules(test_mlp MODULES test_mlp ENVS FLAGS_enable_pir_api=1) + py_test_modules(test_reshard MODULES test_reshard ENVS FLAGS_enable_pir_api=1) set_tests_properties(test_mlp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60) endif() diff --git a/test/auto_parallel/pir/test_reshard.py b/test/auto_parallel/pir/test_reshard.py new file mode 100644 index 0000000000000..dbde9b8865a3a --- /dev/null +++ b/test/auto_parallel/pir/test_reshard.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pathlib +import sys +import unittest + +import paddle +import paddle.distributed as dist +from paddle import nn + +sys.path.append(str(pathlib.Path(__file__).resolve().parents[0])) +from test_to_static_pir_program import ( + BATCH_SIZE, + IMAGE_SIZE, + DemoNet, + create_data_loader, +) + + +class ReshardDemoNet(DemoNet): + def __init__(self, mesh, shard=True): + super().__init__(mesh, shard=True) + + def forward(self, x): + out = DemoNet.forward(self, x) + out = dist.reshard(out, self._mesh, [dist.Shard(0)]) + return out + + +class TestToStaticPirProgramTrain(unittest.TestCase): + def test_to_static_program(self): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + layer = ReshardDemoNet(mesh) + opt = paddle.optimizer.SGD( + learning_rate=0.1, parameters=layer.parameters() + ) + loss_fn = nn.MSELoss() + loader = create_data_loader() + dist_loader = dist.shard_dataloader(loader, meshes=[mesh]) + dist_model = dist.to_static(layer, dist_loader, loss_fn, opt) + engine = dist_model._engine + engine._build("train") + dist_program = engine._fwd_main_progs["train"] + dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass( + dist_program + ) + loss = dist_program.get_output_value_by_name(engine._loss_names[0]) + with paddle.static.program_guard(dist_program): + params_grads = paddle.autograd.ir_backward.append_backward(loss) + engine._optimizer._apply_optimize( + loss, startup_program=None, params_grads=params_grads + ) + + index = 0 + for op in dist_program.global_block().ops: + if op.name() == 'dist_op.reshard': + if index == 0: + # forward reshard op + self.fwd_input = op.operand_source(0) + self.assertEqual( + self.fwd_input.dist_attr().dims_mapping, [-1, -1] + ) + self.assertEqual( + self.fwd_input.dist_attr().partial_dims, set() + ) + self.assertEqual( + self.fwd_input._local_shape, + [BATCH_SIZE, IMAGE_SIZE // 2], + ) + self.fwd_output = op.result(0) + self.assertEqual( + self.fwd_output.dist_attr().dims_mapping, [0, -1] + ) + self.assertEqual( + self.fwd_output.dist_attr().partial_dims, set() + ) + self.assertEqual( + self.fwd_output._local_shape, + [BATCH_SIZE / 2, IMAGE_SIZE // 2], + ) + elif index == 1: + # backward reshard op + self.assertEqual(op.result(0).type(), self.fwd_input.type()) + index += 1 + self.assertEqual(index, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/cpp/pir/distributed/dist_dialect_test.cc b/test/cpp/pir/distributed/dist_dialect_test.cc index 4a0e477b09ae3..8399abc30cb0b 100644 --- a/test/cpp/pir/distributed/dist_dialect_test.cc +++ b/test/cpp/pir/distributed/dist_dialect_test.cc @@ -345,8 +345,8 @@ TEST(shard_tensor_op_replicate_test, base) { auto dst_mesh_attr = ProcessMeshAttribute::get(ctx, dst_process_mesh); auto dst_tensor_dist_attr = TensorDistAttribute::get( ctx, dst_mesh_attr, dst_dims_mapping, partial_status); - paddle::dialect::ReShardOp reshard_op = - builder.Build(shard_op.out(), + paddle::dialect::ReshardOp reshard_op = + builder.Build(shard_op.out(), dst_tensor_dist_attr); EXPECT_TRUE(reshard_op.result(0).type().isa()); @@ -428,8 +428,8 @@ TEST(shard_tensor_op_shard_row_test, base) { auto dst_mesh_attr = ProcessMeshAttribute::get(ctx, dst_process_mesh); auto dst_tensor_dist_attr = TensorDistAttribute::get( ctx, dst_mesh_attr, dims_mapping, partial_status); - paddle::dialect::ReShardOp reshard_op = - builder.Build(shard_op.out(), + paddle::dialect::ReshardOp reshard_op = + builder.Build(shard_op.out(), dst_tensor_dist_attr); EXPECT_TRUE(reshard_op.result(0).type().isa()); @@ -511,8 +511,8 @@ 
TEST(shard_tensor_op_shard_col_test, base) { auto dst_mesh_attr = ProcessMeshAttribute::get(ctx, dst_process_mesh); auto dst_tensor_dist_attr = TensorDistAttribute::get( ctx, dst_mesh_attr, dst_dims_mapping, partial_status); - paddle::dialect::ReShardOp reshard_op = - builder.Build(shard_op.out(), + paddle::dialect::ReshardOp reshard_op = + builder.Build(shard_op.out(), dst_tensor_dist_attr); EXPECT_TRUE(reshard_op.result(0).type().isa()); From cdbcc1ac56fd3a19e5846f264312392805733bcd Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Thu, 11 Apr 2024 11:21:51 +0800 Subject: [PATCH 2/3] fix pr comment --- paddle/fluid/pybind/dist_static_op_function.h | 34 ++++++++++++++----- .../paddle/distributed/auto_parallel/api.py | 4 +-- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/pybind/dist_static_op_function.h b/paddle/fluid/pybind/dist_static_op_function.h index 5d5ef704a5284..c23a16bca2730 100644 --- a/paddle/fluid/pybind/dist_static_op_function.h +++ b/paddle/fluid/pybind/dist_static_op_function.h @@ -18,6 +18,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" +#include "paddle/fluid/pybind/pir.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -66,12 +67,29 @@ static PyObject *static_api_reshard(PyObject *self, PyObject *process_mesh_obj = PyTuple_GET_ITEM(args, 1); auto process_mesh = CastPyArg2ProcessMesh(process_mesh_obj, 1); - PyObject *dims_mapping_obj = PyTuple_GET_ITEM(args, 2); - auto dims_mapping = CastPyArg2VectorOfInt64(dims_mapping_obj, 2); - - PyObject *placements_obj = PyTuple_GET_ITEM(args, 3); - auto placements = CastPyArg2VectorOfPlacement(placements_obj, 3); - + PyObject *placements_obj = PyTuple_GET_ITEM(args, 2); + auto placements = CastPyArg2VectorOfPlacement(placements_obj, 2); + + int64_t ndim = GetValueDims(input).size(); + std::vector dim_map(ndim, -1); + for (size_t i = 0; i < placements.size(); i++) { + auto &placement = placements[i]; + if (placement->is_shard()) { + auto shard_dim = + dynamic_cast(*placement).get_dim(); + PADDLE_ENFORCE_EQ( + dim_map[shard_dim], + -1, + common::errors::InvalidArgument( + "Tensor dim %lld is already sharded on mesh dim %lld," + " DistTensor operator implementation does not support things " + "like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])", + shard_dim, + dim_map[shard_dim])); + dim_map[shard_dim] = i; + } + } paddle::flat_hash_map partial_status; for (size_t i = 0; i < placements.size(); ++i) { auto &p = placements[i]; @@ -83,8 +101,8 @@ static PyObject *static_api_reshard(PyObject *self, } // Call ir static api - auto static_api_out = paddle::dialect::reshard( - input, process_mesh, dims_mapping, partial_status); + auto static_api_out = + paddle::dialect::reshard(input, process_mesh, dim_map, partial_status); return ToPyObject(static_api_out); } catch (...) 
{ diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 4c675946373a6..ce35af4404e5a 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -35,7 +35,6 @@ shard_tensor as shard_tensor_static, ) from paddle.distributed.auto_parallel.placement_type import ( - to_dim_map, to_placements, ) from paddle.distributed.auto_parallel.static.completion import ( @@ -393,8 +392,7 @@ def reshard(dist_tensor, mesh, placements): return paddle.base.core.reshard(dist_tensor, dist_attr) elif in_pir_mode(): - dim_map = to_dim_map(placements, dist_tensor.ndim) - return paddle._pir_ops.reshard(dist_tensor, mesh, dim_map, placements) + return paddle._pir_ops.reshard(dist_tensor, mesh, placements) else: assert isinstance( dist_tensor, Variable From 438108cabf1026b671d6b1596ac55149b4732a25 Mon Sep 17 00:00:00 2001 From: winter-wang <1030748926@qq.com> Date: Thu, 11 Apr 2024 14:31:53 +0800 Subject: [PATCH 3/3] fix pr comment. --- python/paddle/distributed/auto_parallel/api.py | 2 +- test/auto_parallel/pir/test_reshard.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index ce35af4404e5a..49a8bd4b30e3e 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -392,7 +392,7 @@ def reshard(dist_tensor, mesh, placements): return paddle.base.core.reshard(dist_tensor, dist_attr) elif in_pir_mode(): - return paddle._pir_ops.reshard(dist_tensor, mesh, placements) + return paddle._C_ops.reshard(dist_tensor, mesh, placements) else: assert isinstance( dist_tensor, Variable diff --git a/test/auto_parallel/pir/test_reshard.py b/test/auto_parallel/pir/test_reshard.py index dbde9b8865a3a..0163b4f853e2d 100644 --- a/test/auto_parallel/pir/test_reshard.py +++ b/test/auto_parallel/pir/test_reshard.py @@ -23,7 +23,7 @@ sys.path.append(str(pathlib.Path(__file__).resolve().parents[0])) from test_to_static_pir_program import ( BATCH_SIZE, - IMAGE_SIZE, + CLASS_NUM, DemoNet, create_data_loader, ) @@ -77,7 +77,7 @@ def test_to_static_program(self): ) self.assertEqual( self.fwd_input._local_shape, - [BATCH_SIZE, IMAGE_SIZE // 2], + [BATCH_SIZE, CLASS_NUM], ) self.fwd_output = op.result(0) self.assertEqual( @@ -88,7 +88,7 @@ def test_to_static_program(self): ) self.assertEqual( self.fwd_output._local_shape, - [BATCH_SIZE / 2, IMAGE_SIZE // 2], + [BATCH_SIZE / 2, CLASS_NUM], ) elif index == 1: # backward reshard op
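
The pybind entry point reworked in PATCH 2/3 (`static_api_reshard` in `paddle/fluid/pybind/dist_static_op_function.h`) folds the user-supplied placement list into a dims mapping plus a partial-status map before calling `paddle::dialect::reshard`. Below is a minimal Python sketch of that conversion, not part of the patch: placements are modelled as plain tuples rather than the real `paddle.distributed.Placement` objects, and the helper name is illustrative.

```python
# Illustrative sketch of the placement -> (dim_map, partial_status) conversion
# performed in static_api_reshard. Placements are modelled as simple tuples:
#   ("shard", tensor_dim), ("replicate",), or ("partial", reduce_type).
def placements_to_dist_attr(placements, ndim):
    dim_map = [-1] * ndim      # tensor dim -> mesh dim; -1 means not sharded
    partial_status = {}        # mesh dim -> reduce type (e.g. "sum")
    for mesh_dim, placement in enumerate(placements):
        kind = placement[0]
        if kind == "shard":
            tensor_dim = placement[1]
            # Hybrid sharding such as [Shard(0), Shard(0)] is rejected,
            # mirroring the PADDLE_ENFORCE_EQ check in the C++ code.
            assert dim_map[tensor_dim] == -1, "tensor dim already sharded"
            dim_map[tensor_dim] = mesh_dim
        elif kind == "partial":
            partial_status[mesh_dim] = placement[1]
    return dim_map, partial_status


# Example: a 2-D tensor on a 2-D mesh, sharded along mesh dim 0 and
# partial (sum-reduced) along mesh dim 1.
print(placements_to_dist_attr([("shard", 0), ("partial", "sum")], ndim=2))
# -> ([0, -1], {1: 'sum'})
```

The resulting `dim_map` and `partial_status` feed `TensorDistAttribute::get`, and the `dist_op.reshard` op built from that attribute is what `ReshardOp::Vjp` (added in PATCH 1/3) later mirrors in the backward pass by emitting a second reshard back to the input's dist attribute.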