[Prim][PIR] support unbind op forward in prim pir #64430

Merged 10 commits on Jun 18, 2024
@@ -56,6 +56,7 @@
    "squeeze",
    "stack",
    "unsqueeze",
    "unbind",
    "huber_loss",
]

@@ -96,6 +97,7 @@
    "squeeze",
    "stack",
    "unsqueeze",
    "unbind",
    "huber_loss",
]

17 changes: 17 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -613,6 +613,23 @@ std::vector<Tensor> meshgrid_decomp(const std::vector<Tensor>& x) {
  return res;
}

template <typename T>
std::vector<Tensor> unbind_decomp(const Tensor x, int axis) {
  std::vector<Tensor> res;
  // Normalize a negative axis to its positive equivalent.
  if (axis < 0) {
    axis = x.shape().size() + axis;
  }
  // The size of the unbind axis must be known statically.
  if (x.shape()[axis] == -1) {
    PADDLE_THROW(
        phi::errors::Unimplemented("unbind axis must not be dynamic"));
  }
  // Split along the axis into one slice per index, then squeeze that axis
  // away from each slice.
  size_t num = x.shape()[axis];
  std::vector<Tensor> tmp = backend::split_with_num<T>(x, num, axis);
  for (size_t i = 0; i < tmp.size(); i++) {
    res.push_back(squeeze<T>(tmp[i], {axis}));
  }
  return res;
}

template <typename T>
std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
    const Tensor& x,
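Note: the decomposition above implements unbind as a split along the target axis followed by squeezing that axis out of each slice. A minimal NumPy sketch of the same equivalence, for illustration only (it is not part of this PR, and the helper name unbind_ref is made up here):

```python
import numpy as np

def unbind_ref(x: np.ndarray, axis: int = 0):
    """Reference semantics of unbind: split along `axis`, then drop `axis`."""
    if axis < 0:
        axis += x.ndim  # normalize a negative axis, as the decomposition does
    slices = np.split(x, x.shape[axis], axis=axis)  # one slice per index
    return [np.squeeze(s, axis=axis) for s in slices]

x = np.arange(24).reshape(2, 3, 4)
outs = unbind_ref(x, axis=1)
assert len(outs) == 3 and outs[0].shape == (2, 4)
```

Here np.split and np.squeeze stand in for backend::split_with_num and squeeze in the composite above.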
36 changes: 33 additions & 3 deletions test/deprecated/legacy_test/test_unbind_op.py
@@ -167,6 +167,7 @@ def setAxis(self):

    def setUp(self):
        self._set_op_type()
        self.prim_op_type = "comp"
        self.dtype = self.get_dtype()
        self.axis = 0
        self.num = 3
@@ -186,6 +187,7 @@ def setUp(self):
            'Out': [('out%d' % i, self.out[i]) for i in range(len(self.out))]
        }
        self.python_api = paddle.unbind
        self.public_python_api = paddle.unbind
        self.python_out_sig = ['out%d' % i for i in range(len(self.out))]

    def get_dtype(self):
@@ -195,7 +197,7 @@ def _set_op_type(self):
        self.op_type = "unbind"

    def test_check_output(self):
        self.check_output(check_pir=True)
        self.check_output(check_pir=True, check_prim_pir=True)

    def test_check_grad(self):
        self.check_grad(['X'], ['out0', 'out1', 'out2'], check_pir=True)
Review comment (Contributor): ", check_prim_pir=True"

Reply (Contributor Author): done

@@ -263,47 +265,73 @@ class TestUnbindOp1_Complex64(TestUnbindOp1):
    def get_dtype(self):
        return np.complex64

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp2_Complex64(TestUnbindOp2):
    def get_dtype(self):
        return np.complex64

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp3_Complex64(TestUnbindOp3):
    def get_dtype(self):
        return np.complex64

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp4_Complex64(TestUnbindOp4):
    def get_dtype(self):
        return np.complex64

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp1_Complex128(TestUnbindOp1):
    def get_dtype(self):
        return np.complex128

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp2_Complex128(TestUnbindOp2):
    def get_dtype(self):
        return np.complex128

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp3_Complex128(TestUnbindOp3):
    def get_dtype(self):
        return np.complex128

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindOp4_Complex128(TestUnbindOp4):
    def get_dtype(self):
        return np.complex128

    def test_check_output(self):
        self.check_output(check_pir=True)


class TestUnbindFP16Op(OpTest):
    def setUp(self):
        paddle.disable_static()
        self.op_type = "unbind"
        self.prim_op_type = "comp"
        self.python_api = paddle.unbind
        self.public_python_api = paddle.unbind
        self.dtype = self.get_dtype()
        self.axis = 0
        self.num = 3
@@ -326,14 +354,16 @@ def get_dtype(self):
        return np.float16

    def test_check_output(self):
        self.check_output(check_pir=True)
        self.check_output(check_pir=True, check_prim_pir=True)


class TestUnbindBF16Op(OpTest):
    def setUp(self):
        paddle.disable_static()
        self._set_op_type()
        self.prim_op_type = "comp"
        self.python_api = paddle.unbind
        self.public_python_api = paddle.unbind
        self.dtype = self.get_dtype()
        self.axis = 0
        self.num = 3
@@ -362,7 +392,7 @@ def _set_op_type(self):
        self.op_type = "unbind"

    def test_check_output(self):
        self.check_output(check_pir=True)
        self.check_output(check_pir=True, check_prim_pir=True)

    def test_check_grad(self):
        pass
17 changes: 17 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -156,6 +156,10 @@ def meshgrid_net(x, y):
    return paddle.meshgrid(x, y)


def unbind_net(x):
    return paddle.unbind(x)


class TestPrimBase(unittest.TestCase):
    def setUp(self):
        np.random.seed(2023)
@@ -231,6 +235,19 @@ def setUp(self):
        self.tol = 1e-6


class TestUnbind(TestPrimBase):
    def setUp(self):
        np.random.seed(2023)
        self.dtype = "float32"
        self.x_shape = [4, 5, 6]
        self.init_x_shape = [4, 5, None]
        self.x = np.random.random(self.x_shape).astype(self.dtype)
        self.net = unbind_net
        self.necessary_ops = "pd_op.unbind"
        self.enable_cinn = False
        self.tol = 1e-6


class TestPrimFullLike(TestPrimBase):
    def setUp(self):
        np.random.seed(2023)
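Note: unbind_net unbinds along the default axis 0, and the dynamic dimension in init_x_shape ([4, 5, None]) is the last one, so the decomposition's requirement that the unbind axis be static is met. A short eager-mode sketch of the resulting shapes, for illustration only (not part of the PR; assumes a working Paddle install):

```python
import paddle

x = paddle.rand([4, 5, 6], dtype="float32")
outs = paddle.unbind(x)  # default axis=0, which stays static in the test above
assert len(outs) == 4
assert list(outs[0].shape) == [5, 6]  # axis 0 is removed from each output
```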