Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… support_expand_as_v2_npu
  • Loading branch information
rainyfly committed Aug 6, 2021
2 parents f4bd351 + fa16c21 commit 4826669
Show file tree
Hide file tree
Showing 17 changed files with 958 additions and 213 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ELSE ()
ENDIF()

SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210729")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210804")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
Expand Down
44 changes: 44 additions & 0 deletions paddle/fluid/operators/activation_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,39 @@ class CosGradNPUKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class AtanNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
const auto& runner = NpuOpRunner("Atan", {*x}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};

template <typename DeviceContext, typename T>
class AtanGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<Tensor>("X");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& runner_dx = NpuOpRunner("AtanGrad", {*x, *dout}, {*dx}, {});
runner_dx.Run(stream);
}
};

} // namespace operators
} // namespace paddle

Expand Down Expand Up @@ -648,3 +681,14 @@ REGISTER_OP_NPU_KERNEL(
cos_grad, ops::CosGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::CosGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
atan, ops::AtanNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::AtanNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
atan_grad,
ops::AtanGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::AtanGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
59 changes: 59 additions & 0 deletions paddle/fluid/operators/eye_op_npu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/eye_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class EyeNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto num_rows = ctx.Attr<int64_t>("num_rows");

auto d_nums = ctx.Attr<int>("dtype");
auto dtype =
ConvertToNpuDtype(static_cast<framework::proto::VarType::Type>(d_nums));

auto num_columns = ctx.Attr<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;

framework::NPUAttributeMap attr_input = {
{"num_rows", num_rows}, {"num_columns", num_columns}, {"dtype", dtype}};

auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());

const auto& runner = NpuOpRunner("Eye", {}, {*out}, attr_input);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
eye, ops::EyeNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::EyeNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::EyeNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
101 changes: 101 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_prod_op_npu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the Licnse. */

#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class ReduceProdNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto dims = ctx.Attr<std::vector<int>>("dim");
bool keep_dim = ctx.Attr<bool>("keep_dim");
bool reduce_all = ctx.Attr<bool>("reduce_all");
int out_dtype = ctx.Attr<int>("out_dtype");

auto place = ctx.GetPlace();

framework::Tensor cast_out(x->type());
cast_out.Resize(out->dims());
cast_out.mutable_data<T>(place);

auto cast_out_dtype = x->type();
if (out_dtype != -1) {
cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
}

if (x->type() != cast_out_dtype) {
if (cast_out_dtype == framework::proto::VarType::FP32) {
out->mutable_data<float>(place);
} else if (cast_out_dtype == framework::proto::VarType::FP16) {
out->mutable_data<paddle::platform::float16>(place);
} else if (cast_out_dtype == framework::proto::VarType::INT16) {
out->mutable_data<int16_t>(place);
} else if (cast_out_dtype == framework::proto::VarType::INT32) {
out->mutable_data<int32_t>(place);
} else if (cast_out_dtype == framework::proto::VarType::INT64) {
out->mutable_data<int64_t>(place);
} else if (cast_out_dtype == framework::proto::VarType::FP64) {
out->mutable_data<double>(place);
} else if (cast_out_dtype == framework::proto::VarType::BOOL) {
out->mutable_data<bool>(place);
}
} else {
out->ShareDataWith(cast_out);
}

framework::NPUAttributeMap attr_input = {{"axes", dims},
{"keep_dims", keep_dim}};

if (reduce_all) {
std::vector<int> dim_vec;
for (int i = 0; i < x->dims().size(); i++) {
dim_vec.push_back(i);
}

attr_input = {{"axes", dim_vec}, {"keep_dims", keep_dim}};
}

auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

const auto& runner =
NpuOpRunner("ReduceProdD", {*x}, {cast_out}, attr_input);
runner.Run(stream);

if (x->type() != cast_out_dtype) {
auto dst_dtype = ConvertToNpuDtype(cast_out_dtype);
const auto& runner_cast =
NpuOpRunner("Cast", {cast_out}, {*out},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast.Run(stream);
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
reduce_prod, ops::ReduceProdNPUKernel<plat::NPUDeviceContext, float>,
ops::ReduceProdNPUKernel<plat::NPUDeviceContext, plat::float16>);
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,13 @@ All parameter, weight, gradient are variables in Paddle.
.def("__repr__", string::to_string<const platform::XPUPlace &>)
.def("__str__", string::to_string<const platform::XPUPlace &>);
#ifdef PADDLE_WITH_XPU
py::enum_<platform::XPUVersion>(m, "XPUVersion", py::arithmetic())
.value("XPU1", platform::XPUVersion::XPU1)
.value("XPU2", platform::XPUVersion::XPU2)
.export_values();
m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
m.def("get_xpu_device_version",
[](int device_id) { return platform::get_xpu_version(device_id); });
#endif

py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,15 @@ def send_partial(tensor,
ring_id = 0 if group is None else group.id

if _is_valid_send_recv_partial(tensor, nranks):
return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', dst, 'num',
nranks, 'id', rank_id)
return _C_ops.partial_send(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', nranks, 'id', rank_id)
else:
return paddle.distributed.send(
tensor, dst=dst, group=group, use_calc_stream=use_calc_stream)
tensor.detach(),
dst=dst,
group=group,
use_calc_stream=use_calc_stream)


def recv_partial(tensor,
Expand All @@ -180,13 +183,16 @@ def recv_partial(tensor,
ring_id = 0 if group is None else group.id

if _is_valid_send_recv_partial(tensor, nranks):
_C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream,
_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', src, 'num', nranks,
'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
else:
paddle.distributed.recv(
tensor, src=src, group=group, use_calc_stream=use_calc_stream)
tensor.detach(),
src=src,
group=group,
use_calc_stream=use_calc_stream)


def allgather_partial(tensor,
Expand All @@ -200,9 +206,9 @@ def allgather_partial(tensor,
return
ring_id = 0 if group is None else group.id

return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks,
'rank', rank_id)
return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id,
'nranks', nranks, 'rank', rank_id)


def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
Expand Down
16 changes: 8 additions & 8 deletions python/paddle/fluid/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,18 @@ class ClipGradByNorm(ClipGradBase):
.. math::
Out =
\\left \{
\\begin{aligned}
& X & & if (norm(X) \\leq clip\_norm) \\\\
& \\frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\\\
\\end{aligned}
\\right.
\left\{
\begin{array}{ccl}
X & & if (norm(X) \leq clip\_norm) \\
\frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\
\end{array}
\right.
where :math:`norm(X)` represents the L2 norm of :math:`X`.
.. math::
norm(X) = ( \\sum_{i=1}^{n}|x\_i|^2)^{ \\frac{1}{2}}
norm(X) = ( \sum_{i=1}^{n}|x\_i|^2)^{ \frac{1}{2}}
Note:
``need_clip`` of ``ClipGradByNorm`` HAS BEEN DEPRECATED since 2.0.
Expand Down Expand Up @@ -389,7 +389,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
.. math::
t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)}
t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}
where:
Expand Down
Loading

1 comment on commit 4826669

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.