Skip to content

Commit

Permalink
Add strided_slice_grad op for NPU
Browse files Browse the repository at this point in the history
  • Loading branch information
baoachun committed Aug 28, 2021
1 parent f802bc6 commit 70fa9ab
Showing 1 changed file with 53 additions and 12 deletions.
65 changes: 53 additions & 12 deletions paddle/fluid/operators/strided_slice_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<framework::Tensor>("Input");
auto input_dims = input->dims();
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));
dx->mutable_data<T>(input_dims, place);

Expand Down Expand Up @@ -317,13 +316,19 @@ class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
strides = GetDataFromTensor<int64_t>(strides_tensor);
}

std::vector<int64_t> out_dims_vector(input_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);

std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), input_dims, infer_flags,
decrease_axis, starts.size());

std::vector<int64_t> starts_indices_vector(D, 0);
std::vector<int64_t> ends_indices_vector(D, 0);
std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(),
out_dims_vector.end());
std::vector<int64_t> strides_indices_vector(D, 1);

for (size_t axis = 0; axis < axes.size(); axis++) {
Expand Down Expand Up @@ -352,17 +357,53 @@ class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
Tensor input_dims_tensor;
TensorFromVector(input_dims_vector, dev_ctx, &input_dims_tensor);

bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}

auto stream = dev_ctx.stream();
const auto& runner =
NpuOpRunner("StridedSliceGrad",
{input_dims_tensor, starts_indices_tensor,
ends_indices_tensor, strides_indices_tensor, *dout},
{*dx}, {{"begin_mask", 0},
{"end_mask", 0},
{"ellipsis_mask", 0},
{"new_axis_mask", 0},
{"shrink_axis_mask", 0}});
runner.Run(stream);
framework::NPUAttributeMap attr_input = {{"begin_mask", 0},
{"end_mask", 0},
{"ellipsis_mask", 0},
{"new_axis_mask", 0},
{"shrink_axis_mask", 0}};

if (need_reverse) {
Tensor reverse_axis;
std::vector<int> reverse_axis_vector;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
reverse_axis_vector.push_back(axes[axis]);
}
}
reverse_axis.mutable_data<int>(
{static_cast<int>(reverse_axis_vector.size())}, place);
TensorFromVector(reverse_axis_vector, dev_ctx, &reverse_axis);

Tensor dout_tmp;
dout_tmp.mutable_data<T>(dout->dims(), place);
const auto& runner_reverse =
NpuOpRunner("ReverseV2", {*dout, reverse_axis}, {dout_tmp});
runner_reverse.Run(stream);

const auto& runner =
NpuOpRunner("StridedSliceGrad",
{input_dims_tensor, starts_indices_tensor,
ends_indices_tensor, strides_indices_tensor, dout_tmp},
{*dx}, attr_input);
runner.Run(stream);
} else {
const auto& runner =
NpuOpRunner("StridedSliceGrad",
{input_dims_tensor, starts_indices_tensor,
ends_indices_tensor, strides_indices_tensor, *dout},
{*dx}, attr_input);
runner.Run(stream);
}
}
};

Expand Down

1 comment on commit 70fa9ab

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can now ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.