Skip to content

Commit d452a3f

Browse files
authored
[CINN / Symbolic]Fix gather infer symbolic bugs (#63973)
* fix gather infersymbolic * fix * fix * fix
1 parent ef62c70 commit d452a3f

File tree

2 files changed

+53
-40
lines changed

2 files changed

+53
-40
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

+16-4
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,22 @@ bool GatherOpInferSymbolicShape(
235235
return numel;
236236
}();
237237

238-
const auto &axis_shape_or_data =
239-
shape_analysis->GetShapeOrDataForValue(op->operand_source(2));
238+
int axis = 0;
239+
const auto &attributes = op->attributes();
240+
if (op->HasAttribute("axis")) { // CINN Dialect
241+
axis = attributes.at("axis").dyn_cast<pir::Int32Attribute>().data();
242+
} else {
243+
PADDLE_ENFORCE_EQ(
244+
op->num_operands() == 3,
245+
true,
246+
phi::errors::InvalidArgument(
247+
"in GatherOpInferSymbolicShape: The number of operands should be "
248+
"3 when the axis is not set."));
249+
const auto &axis_shape_or_data =
250+
shape_analysis->GetShapeOrDataForValue(op->operand_source(2));
251+
axis =
252+
static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>());
253+
}
240254

241255
const std::vector<symbol::DimExpr> &input_sym_shape =
242256
input_shape_or_data.data().has_value()
@@ -248,8 +262,6 @@ bool GatherOpInferSymbolicShape(
248262
? index_shape_or_data.data().value()
249263
: index_shape_or_data.shape();
250264

251-
int axis =
252-
static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>());
253265
if (axis < 0) axis += input_sym_shape.size();
254266

255267
const auto &out_sym_shape = [&] {

test/ir/pir/cinn/symbolic/test_cinn_transform_symbolic.py

+37-36
Original file line numberDiff line numberDiff line change
@@ -79,42 +79,43 @@ def test_eval(self):
7979
)
8080

8181

82-
# class TestGatherAxisPosSymbolic(unittest.TestCase):
83-
# def setUp(self):
84-
# paddle.seed(2022)
85-
# self.prepare_data()
86-
#
87-
# def prepare_data(self):
88-
# self.shape = [None, 4 ]
89-
# self.x = paddle.randn(self.shape, dtype="float32")
90-
# self.x.stop_gradient = True
91-
# self.index = paddle.to_tensor([1])
92-
# self.index.stop_gradient = True
93-
#
94-
# def check_jit_kernel_info(self, static_fn):
95-
# utils.check_jit_kernel_number(static_fn, 1)
96-
# utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
97-
#
98-
# def eval(self, use_cinn):
99-
# net = GatherLayerAxisPos()
100-
# input_spec = [
101-
# InputSpec(shape=[None, 4], dtype='float32'),
102-
# InputSpec(shape=[1], dtype='int32'),
103-
# ]
104-
# net = utils.apply_to_static(net, use_cinn, input_spec)
105-
# net.eval()
106-
# out = net(self.x, self.index)
107-
# if use_cinn:
108-
# self.check_jit_kernel_info(net.forward)
109-
# return out
110-
#
111-
# def test_eval(self):
112-
# cinn_out = self.eval(use_cinn=True)
113-
# dy_out = self.eval(use_cinn=False)
114-
# np.testing.assert_allclose(
115-
# cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
116-
# )
117-
#
82+
class TestGatherAxisPosSymbolic(unittest.TestCase):
83+
def setUp(self):
84+
paddle.seed(2022)
85+
self.prepare_data()
86+
87+
def prepare_data(self):
88+
self.shape = [32, 4]
89+
self.x = paddle.randn(self.shape, dtype="float32")
90+
self.x.stop_gradient = True
91+
self.index = paddle.to_tensor([1])
92+
self.index.stop_gradient = True
93+
94+
def check_jit_kernel_info(self, static_fn):
95+
utils.check_jit_kernel_number(static_fn, 1)
96+
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
97+
98+
def eval(self, use_cinn):
99+
net = GatherLayerAxisPos()
100+
input_spec = [
101+
InputSpec(shape=[None, 4], dtype='float32'),
102+
InputSpec(shape=[1], dtype='int32'),
103+
]
104+
net = utils.apply_to_static(net, use_cinn, input_spec)
105+
net.eval()
106+
out = net(self.x, self.index)
107+
if use_cinn:
108+
self.check_jit_kernel_info(net.forward)
109+
return out
110+
111+
def test_eval(self):
112+
cinn_out = self.eval(use_cinn=True)
113+
dy_out = self.eval(use_cinn=False)
114+
np.testing.assert_allclose(
115+
cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
116+
)
117+
118+
118119
class TestGatherAxisNegStatic(unittest.TestCase):
119120
def setUp(self):
120121
paddle.seed(2022)

0 commit comments

Comments
 (0)