add dense to dist inplace api (#62014)
LiYuRio authored Feb 26, 2024
1 parent 229d945 commit 044dfe1
Showing 4 changed files with 86 additions and 5 deletions.
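
The commit exposes a private in-place conversion, Tensor._to_dist_(placements, mesh), which turns a dense tensor into a distributed tensor by wrapping its existing DenseTensor storage in a DistTensor. A minimal usage sketch, adapted from the test added in this commit; the mesh and tensor shape are illustrative, and the script is assumed to run under a two-rank distributed launch:

import paddle
import paddle.distributed as dist

# Illustrative two-rank 1-D mesh.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Start from an ordinary dense tensor.
t = paddle.rand([4, 10])
assert t.is_dense()

# Convert in place: the existing dense storage is wrapped in a
# DistTensor with the given placements on the given mesh.
t._to_dist_([dist.Replicate()], mesh)
assert t.is_dist()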
25 changes: 25 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
@@ -973,6 +973,27 @@ static PyObject* tensor__zero_grads(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__to_dist(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
const auto& placements =
CastPyArg2VectorOfPlacement(PyTuple_GET_ITEM(args, 0), 0);
const auto& mesh = CastPyArg2ProcessMesh(PyTuple_GET_ITEM(args, 1), 1);

if (self->tensor.is_dense_tensor()) {
const auto& dense_tensor_ptr =
std::static_pointer_cast<phi::DenseTensor>(self->tensor.impl());
auto dist_tensor_ptr = std::make_shared<phi::distributed::DistTensor>(
dense_tensor_ptr, mesh, placements);
self->tensor.set_impl(dist_tensor_ptr);
}

RETURN_PY_NONE

EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__share_buffer_to(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
@@ -3218,6 +3239,10 @@ PyMethodDef variable_methods[] = { // NOLINT
(PyCFunction)(void (*)())tensor__zero_grads,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_to_dist_",
(PyCFunction)(void (*)())tensor__to_dist,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_share_buffer_to",
(PyCFunction)(void (*)())tensor__share_buffer_to,
METH_VARARGS | METH_KEYWORDS,
6 changes: 1 addition & 5 deletions python/paddle/distributed/auto_parallel/api.py
@@ -864,15 +864,11 @@ def __init__(self, mesh):
         self._mesh = mesh
 
     def _shard_parameter(self, param):
-        # TODO(liyurui): remove this trick dense to dist convert after adding
-        # dense_tensor.to_dist method.
         if param.is_dense():
-            zero_dense = paddle.zeros(param.shape)
             placements = []
             for _ in range(len(self._mesh.shape)):
                 placements.append(dist.Replicate())
-            zero_dist = dist.shard_tensor(zero_dense, self._mesh, placements)
-            res = param + zero_dist
+            param._to_dist_(placements, self._mesh)
 
         new_placements = get_placement_with_sharding(param)
         shard_param = dist.reshard(param, param.process_mesh, new_placements)
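For context, a standalone sketch of the flow that _shard_parameter now follows: the dense parameter is first marked replicated in place via _to_dist_, then resharded. This is only an illustration under assumptions; in particular, the explicit dist.Shard(0) placement below stands in for whatever get_placement_with_sharding would return, and the snippet is assumed to run under a two-rank launch.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
param = paddle.rand([8, 16])  # stands in for a dense parameter

if param.is_dense():
    # Replicate on every mesh dimension, then convert in place.
    placements = [dist.Replicate() for _ in range(len(mesh.shape))]
    param._to_dist_(placements, mesh)

# Reshard to a sharded placement; Shard(0) is an illustrative choice.
shard_param = dist.reshard(param, mesh, [dist.Shard(0)])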
50 changes: 50 additions & 0 deletions test/auto_parallel/semi_dense_tensor_to_dist_api.py
@@ -0,0 +1,50 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np

import paddle
import paddle.distributed as dist


class TestDenseTensorToDistAPI(unittest.TestCase):
def setUp(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seed = 2023
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
paddle.seed(self._seed)
np.random.seed(self._seed)

def run_test_dense_tensor_to_dist_api(self):
if self._backend == "cpu":
paddle.set_device("cpu")
place = paddle.CPUPlace()
elif self._backend == "gpu":
place = paddle.CUDAPlace(dist.get_rank())

dense_dist_tensor = paddle.rand([4, 10])
dense_dist_tensor._to_dist_([dist.Replicate()], self._mesh)
assert dense_dist_tensor.is_dist()

def test_case(self):
self.run_test_dense_tensor_to_dist_api()


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions test/auto_parallel/test_dist_tensor_api.py
@@ -50,6 +50,16 @@ def test_dtensor_from_local_api(self):
user_defined_envs=envs,
)

def test_dense_tensor_to_dist_api(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_dense_tensor_to_dist_api.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
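
The new case is driven the same way as test_dtensor_from_local_api: gen_product_envs_list expands the default and changeable environment dictionaries, and run_test_case launches semi_dense_tensor_to_dist_api.py once per combination. The actual dictionaries live in this class's setUp (not shown in this diff); the values below are purely hypothetical and only illustrate the shape of the configuration the child script reads via os.getenv:

# Hypothetical illustration only; the real values come from setUp().
default_envs = {"shape": "[4, 10]", "dtype": "float32"}
changeable_envs = {"backend": ["cpu", "gpu"]}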
