@@ -79,42 +79,43 @@ def test_eval(self):
79
79
)
80
80
81
81
82
- # class TestGatherAxisPosSymbolic(unittest.TestCase):
83
- # def setUp(self):
84
- # paddle.seed(2022)
85
- # self.prepare_data()
86
- #
87
- # def prepare_data(self):
88
- # self.shape = [None, 4 ]
89
- # self.x = paddle.randn(self.shape, dtype="float32")
90
- # self.x.stop_gradient = True
91
- # self.index = paddle.to_tensor([1])
92
- # self.index.stop_gradient = True
93
- #
94
- # def check_jit_kernel_info(self, static_fn):
95
- # utils.check_jit_kernel_number(static_fn, 1)
96
- # utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
97
- #
98
- # def eval(self, use_cinn):
99
- # net = GatherLayerAxisPos()
100
- # input_spec = [
101
- # InputSpec(shape=[None, 4], dtype='float32'),
102
- # InputSpec(shape=[1], dtype='int32'),
103
- # ]
104
- # net = utils.apply_to_static(net, use_cinn, input_spec)
105
- # net.eval()
106
- # out = net(self.x, self.index)
107
- # if use_cinn:
108
- # self.check_jit_kernel_info(net.forward)
109
- # return out
110
- #
111
- # def test_eval(self):
112
- # cinn_out = self.eval(use_cinn=True)
113
- # dy_out = self.eval(use_cinn=False)
114
- # np.testing.assert_allclose(
115
- # cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
116
- # )
117
- #
82
+ class TestGatherAxisPosSymbolic (unittest .TestCase ):
83
+ def setUp (self ):
84
+ paddle .seed (2022 )
85
+ self .prepare_data ()
86
+
87
+ def prepare_data (self ):
88
+ self .shape = [32 , 4 ]
89
+ self .x = paddle .randn (self .shape , dtype = "float32" )
90
+ self .x .stop_gradient = True
91
+ self .index = paddle .to_tensor ([1 ])
92
+ self .index .stop_gradient = True
93
+
94
+ def check_jit_kernel_info (self , static_fn ):
95
+ utils .check_jit_kernel_number (static_fn , 1 )
96
+ utils .check_jit_kernel_structure (static_fn , {utils .JIT_KERNEL_NAME : 1 })
97
+
98
+ def eval (self , use_cinn ):
99
+ net = GatherLayerAxisPos ()
100
+ input_spec = [
101
+ InputSpec (shape = [None , 4 ], dtype = 'float32' ),
102
+ InputSpec (shape = [1 ], dtype = 'int32' ),
103
+ ]
104
+ net = utils .apply_to_static (net , use_cinn , input_spec )
105
+ net .eval ()
106
+ out = net (self .x , self .index )
107
+ if use_cinn :
108
+ self .check_jit_kernel_info (net .forward )
109
+ return out
110
+
111
+ def test_eval (self ):
112
+ cinn_out = self .eval (use_cinn = True )
113
+ dy_out = self .eval (use_cinn = False )
114
+ np .testing .assert_allclose (
115
+ cinn_out .numpy (), dy_out .numpy (), atol = 1e-6 , rtol = 1e-6
116
+ )
117
+
118
+
118
119
class TestGatherAxisNegStatic (unittest .TestCase ):
119
120
def setUp (self ):
120
121
paddle .seed (2022 )
0 commit comments