Skip to content

Commit

Permalink
[Prim][PIR] support unbind op forward in prim pir (#64430)
Browse files Browse the repository at this point in the history
* update unbind

* fix size_t

* update dynamic test

* update unbind

* add assert

* Update test_unbind_op.py

* prim test change

* fix code
  • Loading branch information
Eddie-Wang1120 authored Jun 18, 2024
1 parent 7725a4d commit b739ff0
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"squeeze",
"stack",
"unsqueeze",
"unbind",
"huber_loss",
]

Expand Down Expand Up @@ -103,6 +104,7 @@
"squeeze",
"stack",
"unsqueeze",
"unbind",
"huber_loss",
]

Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,23 @@ std::vector<Tensor> meshgrid_decomp(const std::vector<Tensor>& x) {
return res;
}

template <typename T>
std::vector<Tensor> unbind_decomp(const Tensor x, int axis) {
std::vector<Tensor> res;
if (axis < 0) {
axis = x.shape().size() + axis;
}
if (x.shape()[axis] == -1) {
PADDLE_THROW(phi::errors::Unimplemented("unbind axis must not be dynamic"));
}
size_t num = x.shape()[axis];
std::vector<Tensor> tmp = backend::split_with_num<T>(x, num, axis);
for (size_t i = 0; i < tmp.size(); i++) {
res.push_back(squeeze<T>(tmp[i], {axis}));
}
return res;
}

template <typename T>
std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
const Tensor& x,
Expand Down
40 changes: 36 additions & 4 deletions test/deprecated/legacy_test/test_unbind_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def setAxis(self):

def setUp(self):
self._set_op_type()
self.prim_op_type = "comp"
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
Expand All @@ -186,6 +187,7 @@ def setUp(self):
'Out': [('out%d' % i, self.out[i]) for i in range(len(self.out))]
}
self.python_api = paddle.unbind
self.public_python_api = paddle.unbind
self.python_out_sig = ['out%d' % i for i in range(len(self.out))]

def get_dtype(self):
Expand All @@ -195,10 +197,12 @@ def _set_op_type(self):
self.op_type = "unbind"

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1', 'out2'], check_pir=True)
self.check_grad(
['X'], ['out0', 'out1', 'out2'], check_pir=True, check_prim_pir=True
)


class TestUnbindOp1(TestUnbindOp):
Expand Down Expand Up @@ -263,47 +267,73 @@ class TestUnbindOp1_Complex64(TestUnbindOp1):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp2_Complex64(TestUnbindOp2):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp3_Complex64(TestUnbindOp3):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp4_Complex64(TestUnbindOp4):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp1_Complex128(TestUnbindOp1):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp2_Complex128(TestUnbindOp2):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp3_Complex128(TestUnbindOp3):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp4_Complex128(TestUnbindOp4):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindFP16Op(OpTest):
def setUp(self):
paddle.disable_static()
self.op_type = "unbind"
self.prim_op_type = "comp"
self.python_api = paddle.unbind
self.public_python_api = paddle.unbind
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
Expand All @@ -326,14 +356,16 @@ def get_dtype(self):
return np.float16

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)


class TestUnbindBF16Op(OpTest):
def setUp(self):
paddle.disable_static()
self._set_op_type()
self.prim_op_type = "comp"
self.python_api = paddle.unbind
self.public_python_api = paddle.unbind
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
Expand Down Expand Up @@ -362,7 +394,7 @@ def _set_op_type(self):
self.op_type = "unbind"

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
pass
Expand Down
3 changes: 3 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import paddle
import paddle.nn.functional as F
from paddle.framework import core
from paddle.static import InputSpec

sys.path.append(dirname(dirname(__file__)))
Expand Down Expand Up @@ -810,6 +811,8 @@ def prepare_data(self):
]

def test_eval_symbolic(self):
core._set_prim_forward_blacklist("pd_op.unbind")

net = UnbindNet()

for i in range(len(self.cases)):
Expand Down
17 changes: 17 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def meshgrid_net(x, y):
return paddle.meshgrid(x, y)


def unbind_net(x):
return paddle.unbind(x)


class TestPrimBase(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
Expand Down Expand Up @@ -247,6 +251,19 @@ def setUp(self):
self.tol = 1e-6


class TestUnbind(TestPrimBase):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [4, 5, 6]
self.init_x_shape = [4, 5, None]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = unbind_net
self.necessary_ops = "pd_op.unbind"
self.enable_cinn = False
self.tol = 1e-6


class TestPrimFullLike(TestPrimBase):
def setUp(self):
np.random.seed(2023)
Expand Down

0 comments on commit b739ff0

Please sign in to comment.