Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prim][PIR] support unbind op forward in prim pir #64430

Merged
merged 10 commits into from
Jun 18, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"squeeze",
"stack",
"unsqueeze",
"unbind",
"huber_loss",
]

Expand Down Expand Up @@ -96,6 +97,7 @@
"squeeze",
"stack",
"unsqueeze",
"unbind",
"huber_loss",
]

Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,23 @@ std::vector<Tensor> meshgrid_decomp(const std::vector<Tensor>& x) {
return res;
}

template <typename T>
std::vector<Tensor> unbind_decomp(const Tensor x, int axis) {
std::vector<Tensor> res;
if (axis < 0) {
axis = x.shape().size() + axis;
}
if (x.shape()[axis] == -1) {
PADDLE_THROW(phi::errors::Unimplemented("unbind axis must not be dynamic"));
}
size_t num = x.shape()[axis];
std::vector<Tensor> tmp = backend::split_with_num<T>(x, num, axis);
for (size_t i = 0; i < tmp.size(); i++) {
res.push_back(squeeze<T>(tmp[i], {axis}));
}
return res;
}

template <typename T>
std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
const Tensor& x,
Expand Down
40 changes: 36 additions & 4 deletions test/deprecated/legacy_test/test_unbind_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def setAxis(self):

def setUp(self):
self._set_op_type()
self.prim_op_type = "comp"
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
Expand All @@ -186,6 +187,7 @@ def setUp(self):
'Out': [('out%d' % i, self.out[i]) for i in range(len(self.out))]
}
self.python_api = paddle.unbind
self.public_python_api = paddle.unbind
self.python_out_sig = ['out%d' % i for i in range(len(self.out))]

def get_dtype(self):
Expand All @@ -195,10 +197,12 @@ def _set_op_type(self):
self.op_type = "unbind"

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1', 'out2'], check_pir=True)
self.check_grad(
['X'], ['out0', 'out1', 'out2'], check_pir=True, check_prim_pir=True
)


class TestUnbindOp1(TestUnbindOp):
Expand Down Expand Up @@ -263,47 +267,73 @@ class TestUnbindOp1_Complex64(TestUnbindOp1):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp2_Complex64(TestUnbindOp2):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp3_Complex64(TestUnbindOp3):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp4_Complex64(TestUnbindOp4):
def get_dtype(self):
return np.complex64

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp1_Complex128(TestUnbindOp1):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp2_Complex128(TestUnbindOp2):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp3_Complex128(TestUnbindOp3):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindOp4_Complex128(TestUnbindOp4):
def get_dtype(self):
return np.complex128

def test_check_output(self):
self.check_output(check_pir=True)


class TestUnbindFP16Op(OpTest):
def setUp(self):
paddle.disable_static()
self.op_type = "unbind"
self.prim_op_type = "comp"
self.python_api = paddle.unbind
self.public_python_api = paddle.unbind
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
Expand All @@ -326,14 +356,16 @@ def get_dtype(self):
return np.float16

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)


class TestUnbindBF16Op(OpTest):
def setUp(self):
paddle.disable_static()
self._set_op_type()
self.prim_op_type = "comp"
self.python_api = paddle.unbind
self.public_python_api = paddle.unbind
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
Expand Down Expand Up @@ -362,7 +394,7 @@ def _set_op_type(self):
self.op_type = "unbind"

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
pass
Expand Down
3 changes: 3 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import paddle
import paddle.nn.functional as F
from paddle.framework import core
from paddle.static import InputSpec

sys.path.append(dirname(dirname(__file__)))
Expand Down Expand Up @@ -807,6 +808,8 @@ def prepare_data(self):
]

def test_eval_symbolic(self):
core._set_prim_forward_blacklist("pd_op.unbind")

net = UnbindNet()

for i in range(len(self.cases)):
Expand Down
17 changes: 17 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def meshgrid_net(x, y):
return paddle.meshgrid(x, y)


def unbind_net(x):
return paddle.unbind(x)


class TestPrimBase(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
Expand Down Expand Up @@ -231,6 +235,19 @@ def setUp(self):
self.tol = 1e-6


class TestUnbind(TestPrimBase):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [4, 5, 6]
self.init_x_shape = [4, 5, None]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = unbind_net
self.necessary_ops = "pd_op.unbind"
self.enable_cinn = False
self.tol = 1e-6


class TestPrimFullLike(TestPrimBase):
def setUp(self):
np.random.seed(2023)
Expand Down