Skip to content

Commit

Permalink
[Fix] fix slice negative axis (PaddlePaddle#59272)
Browse files Browse the repository at this point in the history
  • Loading branch information
megemini authored and SecretXV committed Nov 28, 2023
1 parent 81d03c6 commit f93f63f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 17 deletions.
32 changes: 15 additions & 17 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,28 +313,26 @@ def slice(input, axes, starts, ends):
>>> sliced_2 = paddle.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends)
>>> # sliced_2 is input[1:3, 0:2, 2:4].
"""
if isinstance(axes, (list, tuple)):
axes = list(axes)
if len(axes) == 0:
raise ValueError("Input axes should not be an empty list/tuple.")
for i in range(len(axes)):
if axes[i] < 0:
axes[i] = max(0, axes[i] + len(input.shape))
else:
axes[i] = min(len(input.shape) - 1, axes[i])

else:
raise ValueError(
f"Input axes must be a python list or tuple, but reveived {type(axes)}"
)

if in_dynamic_mode():
attrs = ()
starts_tensor = None
ends_tensor = None

if isinstance(axes, (list, tuple)):
axes = list(axes)
if len(axes) == 0:
raise ValueError(
"Input axes should not be an empty list/tuple."
)
for i in range(len(axes)):
if axes[i] < 0:
axes[i] = max(0, axes[i] + len(input.shape))
else:
axes[i] = min(len(input.shape) - 1, axes[i])

else:
raise ValueError(
f"Input axes must be a python list or tuple, but reveived {type(axes)}"
)

infer_flags = [1 for i in range(len(axes))]

if isinstance(starts, (list, tuple)):
Expand Down
67 changes: 67 additions & 0 deletions test/legacy_test/test_slice_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,73 @@ def test_pir(self):
np.testing.assert_array_equal(res_6, input[-3:3, 0:100, :, 2:-1])
# np.testing.assert_array_equal(res_7, input[-1, 0:100, :, 2:-1])

# Test negative axis
def test_negative_axis_dygraph(self):
with paddle.base.dygraph.guard():
input = np.random.random([3, 4, 5, 6]).astype("float64")

res = paddle.slice(
paddle.to_tensor(input), axes=[-2], starts=[2], ends=[3]
)
np.testing.assert_array_equal(res, input[:, :, 2:3, :])

def test_negative_axis_static(self):
with paddle_static_guard(), paddle.static.program_guard(
paddle.static.Program()
):
input = np.random.random([3, 4, 5, 6]).astype("float64")
x = paddle.static.data(
name="x",
shape=[3, 4, 5, 6],
dtype="float64",
)

out = paddle.slice(
x,
axes=[-2],
starts=[2],
ends=[3],
)

exe = base.Executor(place=base.CPUPlace())
res = exe.run(
feed={
"x": input,
},
fetch_list=[out],
)[0]

np.testing.assert_array_equal(res, input[:, :, 2:3, :])

def test_negative_axis_pir(self):
with paddle.pir_utils.IrGuard(), paddle.static.program_guard(
paddle.static.Program()
):
input = np.random.random([3, 4, 5, 6]).astype("float64")
x = paddle.static.data(
name="x",
shape=[3, 4, 5, 6],
dtype="float64",
)

out = paddle.slice(
x,
axes=[-2],
starts=[2],
ends=[3],
)

exe = base.Executor(place=base.CPUPlace())
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[out],
)[0]

np.testing.assert_array_equal(res, input[:, :, 2:3, :])


class TestSliceApiWithTensor(unittest.TestCase):
def test_starts_ends_is_tensor(self):
Expand Down

0 comments on commit f93f63f

Please sign in to comment.