diff --git a/test/dygraph_to_static/test_basic_api_transformation.py b/test/dygraph_to_static/test_basic_api_transformation.py
index e0998b8fe1e67f..51ddbe6e11a1cb 100644
--- a/test/dygraph_to_static/test_basic_api_transformation.py
+++ b/test/dygraph_to_static/test_basic_api_transformation.py
@@ -16,10 +16,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

 import paddle
 from paddle import base, to_tensor
@@ -72,8 +69,7 @@ def dyfunc_bool_to_tensor(x):
     return paddle.to_tensor(True)


-@dy2static_unittest
-class TestDygraphBasicApi_ToVariable(unittest.TestCase):
+class TestDygraphBasicApi_ToVariable(Dy2StTestBase):
     def setUp(self):
         self.input = np.ones(5).astype("int32")
         self.test_funcs = [
@@ -96,7 +92,7 @@ def get_dygraph_output(self):
             res = self.dygraph_func(self.input).numpy()
             return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         main_program = base.Program()
         main_program.random_seed = SEED
@@ -234,8 +230,7 @@ def dyfunc_Prelu(input):
     return res


-@dy2static_unittest
-class TestDygraphBasicApi(unittest.TestCase):
+class TestDygraphBasicApi(Dy2StTestBase):
     # Compare results of dynamic graph and transformed static graph function which only
     # includes basic Api.

@@ -252,7 +247,7 @@ def get_dygraph_output(self):

             return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         startup_program = base.Program()
         startup_program.random_seed = SEED
@@ -286,7 +281,7 @@ def get_dygraph_output(self):
             res = self.dygraph_func(self.input1, self.input2).numpy()
             return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         startup_program = base.Program()
         startup_program.random_seed = SEED
@@ -401,8 +396,7 @@ def dyfunc_PolynomialDecay():
     return paddle.to_tensor(lr)


-@dy2static_unittest
-class TestDygraphBasicApi_CosineDecay(unittest.TestCase):
+class TestDygraphBasicApi_CosineDecay(Dy2StTestBase):
     def setUp(self):
         self.dygraph_func = dyfunc_CosineDecay

@@ -413,7 +407,7 @@ def get_dygraph_output(self):
         res = self.dygraph_func().numpy()
         return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         startup_program = base.Program()
         startup_program.random_seed = SEED
@@ -444,7 +438,7 @@ def get_dygraph_output(self):
         res = self.dygraph_func()
         return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         startup_program = base.Program()
         startup_program.random_seed = SEED
@@ -471,7 +465,7 @@ def get_dygraph_output(self):
         res = self.dygraph_func()
         return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         startup_program = base.Program()
         startup_program.random_seed = SEED
@@ -498,7 +492,7 @@ def get_dygraph_output(self):
         res = self.dygraph_func()
         return res

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def get_static_output(self):
         startup_program = base.Program()
         startup_program.random_seed = SEED
@@ -545,8 +539,7 @@ def _dygraph_fn():
         np.random.random(1)


-@dy2static_unittest
-class TestDygraphApiRecognition(unittest.TestCase):
+class TestDygraphApiRecognition(Dy2StTestBase):
     def setUp(self):
         self.src = inspect.getsource(_dygraph_fn)
         self.root = gast.parse(self.src)
diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py
index ba8e2350794aad..7c6a2c1b4d42a4 100644
--- a/test/dygraph_to_static/test_bert.py
+++ b/test/dygraph_to_static/test_bert.py
@@ -20,10 +20,10 @@
 import numpy as np
 from bert_dygraph_model import PretrainModelLayer
 from bert_utils import get_bert_config, get_feed_data_reader
-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    test_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_pir_only,
 )
 from predictor_utils import PredictorTools

@@ -78,8 +78,7 @@ def __len__(self):
         return len(self.src_ids)


-@dy2static_unittest
-class TestBert(unittest.TestCase):
+class TestBert(Dy2StTestBase):
     def setUp(self):
         self.bert_config = get_bert_config()
         self.data_reader = get_feed_data_reader(self.bert_config)
@@ -266,7 +265,7 @@ def predict_analysis_inference(self, data):
         out = output()
         return out

-    @test_with_new_ir
+    @test_pir_only
     def test_train_new_ir(self):
         static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
@@ -277,7 +276,7 @@ def test_train_new_ir(self):
         np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05)
         np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05)

-    @ast_only_test
+    @test_ast_only
     def test_train(self):
         static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
diff --git a/test/dygraph_to_static/test_bmn.py b/test/dygraph_to_static/test_bmn.py
index f5f8d357598695..11afe6100d79f2 100644
--- a/test/dygraph_to_static/test_bmn.py
+++ b/test/dygraph_to_static/test_bmn.py
@@ -18,7 +18,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
+from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
 from predictor_utils import PredictorTools

 import paddle
@@ -637,8 +637,7 @@ def val_bmn(model, args):
     return loss_data


-@dy2static_unittest
-class TestTrain(unittest.TestCase):
+class TestTrain(Dy2StTestBase):
     def setUp(self):
         self.args = Args()
         self.place = (
@@ -751,7 +750,7 @@ def train_bmn(self, args, place, to_static):
                     break
         return np.array(loss_data)

-    @test_with_new_ir
+    @test_pir_only
     def test_train_new_ir(self):
         static_res = self.train_bmn(self.args, self.place, to_static=True)
         dygraph_res = self.train_bmn(self.args, self.place, to_static=False)
diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py
index a803c1d4bf49ed..e1df868435e8fa 100644
--- a/test/dygraph_to_static/test_break_continue.py
+++ b/test/dygraph_to_static/test_break_continue.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

 import paddle
 from paddle import base
@@ -26,14 +26,13 @@
 np.random.seed(SEED)


-@dy2static_unittest
-class TestDy2staticException(unittest.TestCase):
+class TestDy2staticException(Dy2StTestBase):
     def setUp(self):
         self.x = np.random.random([10, 16]).astype('float32')
         self.dyfunc = None
         self.error = "Your if/else have different number of return value."

-    @ast_only_test
+    @test_ast_only
     def test_error(self):
         if self.dyfunc:
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
@@ -205,8 +204,7 @@ def test_optim_break_in_while(x):
     return x


-@dy2static_unittest
-class TestContinueInFor(unittest.TestCase):
+class TestContinueInFor(Dy2StTestBase):
     def setUp(self):
         self.input = np.zeros(1).astype('int64')
         self.place = (
diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py
index 85e934afb020bb..ee19dad5842f9c 100644
--- a/test/dygraph_to_static/test_build_strategy.py
+++ b/test/dygraph_to_static/test_build_strategy.py
@@ -15,14 +15,13 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only
 from test_resnet import ResNetHelper

 import paddle


-@dy2static_unittest
-class TestResnetWithPass(unittest.TestCase):
+class TestResnetWithPass(Dy2StTestBase):
     def setUp(self):
         self.build_strategy = paddle.static.BuildStrategy()
         self.build_strategy.fuse_elewise_add_act_ops = True
@@ -62,7 +61,7 @@ def verify_predict(self):
             err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
         )

-    @ast_only_test
+    @test_ast_only
     def test_resnet(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
@@ -74,7 +73,7 @@ def test_resnet(self):
         )
         self.verify_predict()

-    @ast_only_test
+    @test_ast_only
     def test_in_static_mode_mkldnn(self):
         paddle.base.set_flags({'FLAGS_use_mkldnn': True})
         try:
@@ -84,8 +83,7 @@ def test_in_static_mode_mkldnn(self):
             paddle.base.set_flags({'FLAGS_use_mkldnn': False})


-@dy2static_unittest
-class TestError(unittest.TestCase):
+class TestError(Dy2StTestBase):
     def test_type_error(self):
         def foo(x):
             out = x + 1
diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py
index 199c3e980e20c9..9683afb05bdda0 100644
--- a/test/dygraph_to_static/test_cache_program.py
+++ b/test/dygraph_to_static/test_cache_program.py
@@ -16,7 +16,7 @@
 from collections import Counter

 import numpy as np
-from dygraph_to_static_util import dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase
 from test_fetch_feed import Linear, Pool2D

 import paddle
@@ -25,8 +25,7 @@
 from paddle.jit.dy2static import convert_to_static


-@dy2static_unittest
-class TestCacheProgram(unittest.TestCase):
+class TestCacheProgram(Dy2StTestBase):
     def setUp(self):
         self.batch_num = 5
         self.dygraph_class = Pool2D
@@ -76,8 +75,7 @@ def setUp(self):
         self.data = np.random.random((4, 10)).astype('float32')


-@dy2static_unittest
-class TestCacheProgramWithOptimizer(unittest.TestCase):
+class TestCacheProgramWithOptimizer(Dy2StTestBase):
     def setUp(self):
         self.dygraph_class = Linear
         self.data = np.random.random((4, 10)).astype('float32')
@@ -126,8 +124,7 @@ def simple_func(x):
     return mean


-@dy2static_unittest
-class TestConvertWithCache(unittest.TestCase):
+class TestConvertWithCache(Dy2StTestBase):
     def test_cache(self):
         static_func = convert_to_static(simple_func)
         # Get transformed function from cache.
@@ -157,8 +154,7 @@ def sum_under_while(limit):
     return ret_sum


-@dy2static_unittest
-class TestToOutputWithCache(unittest.TestCase):
+class TestToOutputWithCache(Dy2StTestBase):
     def test_output(self):
         with base.dygraph.guard():
             ret = sum_even_until_limit(80, 10)
diff --git a/test/dygraph_to_static/test_cinn.py b/test/dygraph_to_static/test_cinn.py
index 84e619149c8009..0f8f5c962934cb 100644
--- a/test/dygraph_to_static/test_cinn.py
+++ b/test/dygraph_to_static/test_cinn.py
@@ -15,10 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

 import paddle

@@ -45,8 +42,7 @@ def apply_to_static(net, use_cinn):
     return paddle.jit.to_static(net, build_strategy=build_strategy)


-@dy2static_unittest
-class TestCINN(unittest.TestCase):
+class TestCINN(Dy2StTestBase):
     def setUp(self):
         self.x = paddle.randn([2, 4])
         self.x.stop_gradient = False
@@ -83,7 +79,7 @@ def train(self, use_cinn):

         return res

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_cinn(self):
         dy_res = self.train(use_cinn=False)
         cinn_res = self.train(use_cinn=True)
diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py
index 2ed5326f7b9d00..95df5d498c6fb9 100644
--- a/test/dygraph_to_static/test_cinn_prim.py
+++ b/test/dygraph_to_static/test_cinn_prim.py
@@ -15,10 +15,10 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_legacy_and_pir,
 )

 import paddle
@@ -43,8 +43,7 @@ def forward(self, x):
         return out


-@dy2static_unittest
-class TestPrimForward(unittest.TestCase):
+class TestPrimForward(Dy2StTestBase):
     """
     This case only tests prim_forward + to_static + cinn. Thus we need to
     set this flag as False to avoid prim_backward.
@@ -94,7 +93,7 @@ def check_prim(self, net, use_prim):
         # Ensure that softmax is splitted into small ops
         self.assertTrue('softmax' not in fwd_ops)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim_forward(self):
         dy_res = self.train(use_prim=False)
         cinn_res = self.train(use_prim=True)
@@ -105,8 +104,7 @@ def test_cinn_prim_forward(self):
         )


-@dy2static_unittest
-class TestPrimForwardAndBackward(unittest.TestCase):
+class TestPrimForwardAndBackward(Dy2StTestBase):
     """
     Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
     """
@@ -161,7 +159,7 @@ def check_prim(self, net, use_prim):
             if op != "matmul_v2_grad":
                 self.assertTrue("_grad" not in op)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim(self):
         dy_res = self.train(use_prim=False)
         cinn_res = self.train(use_prim=True)
@@ -172,9 +170,8 @@ def test_cinn_prim(self):
         )


-@dy2static_unittest
-class TestBackend(unittest.TestCase):
-    @test_and_compare_with_new_ir(False)
+class TestBackend(Dy2StTestBase):
+    @test_legacy_and_pir
     def test_backend(self):
         x = paddle.randn([2, 4])
         out1 = self.forward(x, 'CINN')
diff --git a/test/dygraph_to_static/test_cinn_prim_gelu.py b/test/dygraph_to_static/test_cinn_prim_gelu.py
index be2e8f67c1e988..ab9b3697eba620 100644
--- a/test/dygraph_to_static/test_cinn_prim_gelu.py
+++ b/test/dygraph_to_static/test_cinn_prim_gelu.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

 import paddle
 import paddle.nn.functional as F
@@ -53,8 +53,7 @@ def forward(self, x):
         return out


-@dy2static_unittest
-class TestPrimForwardAndBackward(unittest.TestCase):
+class TestPrimForwardAndBackward(Dy2StTestBase):
     """
     Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
     """
@@ -106,7 +105,7 @@ def check_prim(self, net, use_prim):
         # Ensure that gelu is splitted into small ops
         self.assertTrue('gelu' not in fwd_ops)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim(self):
         for shape in self.shapes:
             for dtype in self.dtypes:
diff --git a/test/dygraph_to_static/test_cinn_prim_layer_norm.py b/test/dygraph_to_static/test_cinn_prim_layer_norm.py
index 42bf36d731eca6..94186bb1bff39b 100644
--- a/test/dygraph_to_static/test_cinn_prim_layer_norm.py
+++ b/test/dygraph_to_static/test_cinn_prim_layer_norm.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

 import paddle
 import paddle.nn.functional as F
@@ -52,8 +52,7 @@ def forward(self, x, w, b):
         return out[0]


-@dy2static_unittest
-class TestPrimForward(unittest.TestCase):
+class TestPrimForward(Dy2StTestBase):
     """
     This case only tests prim_forward + to_static + cinn. Thus we need to
     set this flag as False to avoid prim_backward.
@@ -103,7 +102,7 @@ def check_prim(self, net, use_prim):
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim_forward(self):
         for dtype in self.dtypes:
             if paddle.device.get_device() == "cpu":
@@ -125,8 +124,7 @@ def test_cinn_prim_forward(self):
                 )


-@dy2static_unittest
-class TestPrimForwardAndBackward(unittest.TestCase):
+class TestPrimForwardAndBackward(Dy2StTestBase):
     """
     Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
     """
@@ -172,7 +170,7 @@ def check_prim(self, net, use_prim):
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim(self):
         for dtype in self.dtypes:
             if paddle.device.get_device() == "cpu":
diff --git a/test/dygraph_to_static/test_cinn_prim_mean.py b/test/dygraph_to_static/test_cinn_prim_mean.py
index cb32f5b466035e..fe82e9cfe0a5b3 100644
--- a/test/dygraph_to_static/test_cinn_prim_mean.py
+++ b/test/dygraph_to_static/test_cinn_prim_mean.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

 import paddle
 from paddle import tensor
@@ -55,8 +55,7 @@ def forward(self, x):
         return out


-@dy2static_unittest
-class TestPrimForward(unittest.TestCase):
+class TestPrimForward(Dy2StTestBase):
     """
     This case only tests prim_forward + to_static + cinn. Thus we need to
     set this flag as False to avoid prim_backward.
@@ -112,7 +111,7 @@ def check_prim(self, net, use_prim):
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim_forward(self):
         for shape in self.shapes:
             for dtype in self.dtypes:
@@ -134,8 +133,7 @@ def test_cinn_prim_forward(self):
                 )


-@dy2static_unittest
-class TestPrimForwardAndBackward(unittest.TestCase):
+class TestPrimForwardAndBackward(Dy2StTestBase):
     """
     Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
     """
@@ -187,7 +185,7 @@ def check_prim(self, net, use_prim):
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)

-    @ast_only_test
+    @test_ast_only
     def test_cinn_prim(self):
         for shape in self.shapes:
             for dtype in self.dtypes:
diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py
index 723d3f910debdd..bd21698579d93b 100644
--- a/test/dygraph_to_static/test_convert_call.py
+++ b/test/dygraph_to_static/test_convert_call.py
@@ -16,7 +16,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import ast_only_test, dy2static_unittest
+from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

 import paddle
 import paddle.jit.dy2static as _jst
@@ -77,8 +77,7 @@ def dyfunc_with_staticmethod(x_v):
     return a.add(x_v, x_v)


-@dy2static_unittest
-class TestRecursiveCall1(unittest.TestCase):
+class TestRecursiveCall1(Dy2StTestBase):
     def setUp(self):
         self.input = np.random.random([10, 16]).astype('float32')
         self.place = (
@@ -169,8 +168,7 @@ def forward(self, inputs):
         return self.act(out)


-@dy2static_unittest
-class TestRecursiveCall2(unittest.TestCase):
+class TestRecursiveCall2(Dy2StTestBase):
     def setUp(self):
         self.input = np.random.random((1, 3, 3, 5)).astype('float32')
         self.place = (
@@ -253,7 +251,6 @@ def test_code(self):
         )


-@dy2static_unittest
 class TestNotToConvert2(TestRecursiveCall2):
     def set_func(self):
         self.net = NotToStaticHelper()
@@ -266,7 +263,7 @@ def test_conversion_options(self):
         self.assertIsNotNone(options)
         self.assertTrue(options.not_convert)

-    @ast_only_test
+    @test_ast_only
     def test_code(self):
         self.dygraph_func = paddle.jit.to_static(self.net.sum)
         # check 'if statement' is not converted
@@ -281,23 +278,22 @@ def forward(self, x):
     return x


-@dy2static_unittest
-class TestConvertPaddleAPI(unittest.TestCase):
-    @ast_only_test
+class TestConvertPaddleAPI(Dy2StTestBase):
+    @test_ast_only
     def test_functional_api(self):
         func = paddle.nn.functional.relu
         func = paddle.jit.to_static(func)
         self.assertNotIn("_jst.IfElse", func.code)
         self.assertIn("if in_dynamic_or_pir_mode()", func.code)

-    @ast_only_test
+    @test_ast_only
     def test_class_api(self):
         bn = paddle.nn.SyncBatchNorm(2)
         paddle.jit.to_static(bn)
         self.assertNotIn("_jst.IfElse", bn.forward.code)
         self.assertIn("if in_dynamic_mode()", bn.forward.code)

-    @ast_only_test
+    @test_ast_only
     def test_class_patch_api(self):
         paddle.nn.SyncBatchNorm.forward = forward
         bn = paddle.nn.SyncBatchNorm(2)
diff --git a/test/dygraph_to_static/test_convert_call_generator.py b/test/dygraph_to_static/test_convert_call_generator.py
index dd9d93c907c552..b3793fa22d289c 100644
--- a/test/dygraph_to_static/test_convert_call_generator.py
+++ b/test/dygraph_to_static/test_convert_call_generator.py
@@ -14,10 +14,10 @@

 import unittest

-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_legacy_and_pir,
 )

 import paddle
@@ -36,11 +36,10 @@ def main_func():
         print(i)


-@dy2static_unittest
-class TestConvertGenerator(unittest.TestCase):
+class TestConvertGenerator(Dy2StTestBase):
     # fallback will ok.
-    @ast_only_test
-    @test_and_compare_with_new_ir(False)
+    @test_ast_only
+    @test_legacy_and_pir
     def test_raise_error(self):
         translator_logger.verbosity_level = 1
         with self.assertLogs(
diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py
index 02d0c09a70857c..05a6d4de9c7d9f 100644
--- a/test/dygraph_to_static/test_convert_operators.py
+++ b/test/dygraph_to_static/test_convert_operators.py
@@ -15,10 +15,10 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_legacy_and_pir,
 )

 import paddle
@@ -44,10 +44,9 @@ def forward(self):
 net.forward = "A string so that convert forward will fail"


-@dy2static_unittest
-class TestConvertCall(unittest.TestCase):
+class TestConvertCall(Dy2StTestBase):
     # fallback mode will raise a InnerError, it's ok.
-    @ast_only_test
+    @test_ast_only
     def test_class_exception(self):
         @paddle.jit.to_static
         def call_not_exist():
@@ -73,8 +72,7 @@ def callable_list(x, y):
         self.assertEqual(callable_list(1, 2), 3)


-@dy2static_unittest
-class TestConvertShapeCompare(unittest.TestCase):
+class TestConvertShapeCompare(Dy2StTestBase):
     def test_non_variable(self):
         self.assertEqual(
             paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True
@@ -136,7 +134,7 @@ def error_func():
             False,
         )

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_variable(self):
         paddle.enable_static()
         with paddle.static.program_guard(
@@ -210,9 +208,8 @@ def forward(self, x):
         return out


-@dy2static_unittest
-class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase):
-    @test_and_compare_with_new_ir(False)
+class TestChooseShapeAttrOrApiWithLayer(Dy2StTestBase):
+    @test_legacy_and_pir
     def test_tensor_shape(self):
         x = paddle.zeros(shape=[4, 1], dtype='float32')
         net = ShapeLayer()
@@ -221,9 +218,8 @@ def test_tensor_shape(self):
         np.testing.assert_array_equal(out.numpy(), x.numpy())


-@dy2static_unittest
-class TestIfElseNoValue(unittest.TestCase):
-    @test_and_compare_with_new_ir(False)
+class TestIfElseNoValue(Dy2StTestBase):
+    @test_legacy_and_pir
     def test_else_ret_none(self):
         input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

@@ -253,7 +249,7 @@ def without_common_value(x, use_cache=False):
         out = without_common_value(input_x, False)
         self.assertIsNone(out)

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_else_ret_c(self):
         input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

@@ -286,7 +282,7 @@ def without_common_value(x, use_cache=False):
         self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1))
         self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2))

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_else_ret_cz(self):
         input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
diff --git a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py
index b6e55b8900c1e8..1d199dc8138df1 100644
--- a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py
+++ b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py
@@ -15,18 +15,16 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    ast_only_test,
-    dy2static_unittest,
-    sot_only_test,
-    test_and_compare_with_new_ir,
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_legacy_and_pir,
 )

 import paddle


-@dy2static_unittest
-class TestCpuCuda(unittest.TestCase):
+class TestCpuCuda(Dy2StTestBase):
     def test_cpu_cuda(self):
         def func(x):
             x = paddle.to_tensor([1, 2, 3, 4])
@@ -39,9 +37,8 @@ def func(x):
         # print(paddle.jit.to_static(func)(x))


-@dy2static_unittest
-class TestToTensor(unittest.TestCase):
-    @test_and_compare_with_new_ir(False)
+class TestToTensor(Dy2StTestBase):
+    @test_legacy_and_pir
     def test_to_tensor_with_variable_list(self):
         def func(x):
             ones = paddle.to_tensor(1)
@@ -58,10 +55,9 @@ def func(x):
         )


-@dy2static_unittest
-class TestToTensor1(unittest.TestCase):
-    @ast_only_test
-    @test_and_compare_with_new_ir(False)
+class TestToTensor1(Dy2StTestBase):
+    @test_ast_only
+    @test_legacy_and_pir
     def test_to_tensor_with_variable_list(self):
         def func(x):
             ones = paddle.to_tensor([1])
@@ -79,8 +75,8 @@ def func(x):
             rtol=1e-05,
         )

-    @sot_only_test
-    @test_and_compare_with_new_ir(False)
+    @test_ast_only
+    @test_legacy_and_pir
     def test_to_tensor_with_variable_list_sot(self):
         def func(x):
             ones = paddle.to_tensor([1])
@@ -99,10 +95,9 @@ def func(x):
         )


-@dy2static_unittest
-class TestToTensor2(unittest.TestCase):
-    @ast_only_test
-    @test_and_compare_with_new_ir(False)
+class TestToTensor2(Dy2StTestBase):
+    @test_ast_only
+    @test_legacy_and_pir
     def test_to_tensor_with_variable_list(self):
         def func(x):
             x = paddle.to_tensor([[1], [2], [3], [4]])
@@ -115,8 +110,8 @@ def func(x):
             rtol=1e-05,
         )

-    @sot_only_test
-    @test_and_compare_with_new_ir(False)
+    @test_ast_only
+    @test_legacy_and_pir
     def test_to_tensor_with_variable_list_sot(self):
         def func(x):
             x = paddle.to_tensor([[1], [2], [3], [4]])
diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py
index fb06a52407ec61..d069a630b73fe1 100644
--- a/test/dygraph_to_static/test_cycle_gan.py
+++ b/test/dygraph_to_static/test_cycle_gan.py
@@ -26,10 +26,7 @@
 # Use GPU:0 to elimate the influence of other tasks.
 os.environ["CUDA_VISIBLE_DEVICES"] = "1"

-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

 import paddle
 from paddle.base.dygraph import to_variable
@@ -679,8 +676,7 @@ def train(args, to_static):
     return np.array(loss_data)


-@dy2static_unittest
-class TestCycleGANModel(unittest.TestCase):
+class TestCycleGANModel(Dy2StTestBase):
     def setUp(self):
         self.args = Args()

@@ -688,7 +684,7 @@ def train(self, to_static):
         out = train(self.args, to_static)
         return out

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_train(self):
         st_out = self.train(to_static=True)
         dy_out = self.train(to_static=False)
diff --git a/test/dygraph_to_static/test_dict.py b/test/dygraph_to_static/test_dict.py
index 99364c1343a7d6..c88496fd86b3e1 100644
--- a/test/dygraph_to_static/test_dict.py
+++ b/test/dygraph_to_static/test_dict.py
@@ -15,10 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

 import paddle
 from paddle import base
@@ -119,8 +116,7 @@ def update_cache(cache):
     return cache


-@dy2static_unittest
-class TestNetWithDict(unittest.TestCase):
+class TestNetWithDict(Dy2StTestBase):
     """
     TestCase for the transformation from control flow `if/else`
     dependent on tensor in Dygraph into Static `base.layers.cond`.
@@ -130,7 +126,7 @@ def setUp(self):
         self.x = np.random.random([10, 16]).astype('float32')
         self.batch_size = self.x.shape[0]

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def _run_static(self):
         return self.train(to_static=True)

@@ -173,8 +169,7 @@ def test_dic_pop_2(x):
     return out


-@dy2static_unittest
-class TestDictPop(unittest.TestCase):
+class TestDictPop(Dy2StTestBase):
     def setUp(self):
         self.input = np.random.random(3).astype('int32')
         self.place = (
@@ -187,7 +182,7 @@ def setUp(self):
     def _set_test_func(self):
         self.dygraph_func = test_dic_pop

-    @test_and_compare_with_new_ir(True)
+    @compare_legacy_with_pir
     def _run_static(self):
         return self._run(to_static=True)

@@ -254,8 +249,7 @@ def test_ast_to_func(self):
         )


-@dy2static_unittest
-class TestDictCmpInFor(unittest.TestCase):
+class TestDictCmpInFor(Dy2StTestBase):
     def test_with_for(self):
         def func():
             pos = [1, 3]
diff --git a/test/dygraph_to_static/test_drop_path.py b/test/dygraph_to_static/test_drop_path.py
index aad752007ceb0c..d559ce7f55ac29 100644
--- a/test/dygraph_to_static/test_drop_path.py
+++ b/test/dygraph_to_static/test_drop_path.py
@@ -15,10 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

 import paddle

@@ -39,15 +36,14 @@ def forward(self, x):
         return drop_path(x, self.training)


-@dy2static_unittest
-class TestTrainEval(unittest.TestCase):
+class TestTrainEval(Dy2StTestBase):
     def setUp(self):
         self.model = DropPath()

     def tearDown(self):
         pass

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def test_train_and_eval(self):
         x = paddle.to_tensor([1, 2, 3]).astype("int64")
         eval_out = x.numpy()
diff --git a/test/dygraph_to_static/test_duplicate_output.py b/test/dygraph_to_static/test_duplicate_output.py
index c7f1e21b3552ab..70637729671f0b 100644
--- a/test/dygraph_to_static/test_duplicate_output.py
+++ b/test/dygraph_to_static/test_duplicate_output.py
@@ -15,10 +15,7 @@
 import unittest

 import numpy as np
-from dygraph_to_static_util import (
-    dy2static_unittest,
-    test_and_compare_with_new_ir,
-)
+from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

 import paddle

@@ -41,8 +38,7 @@ def forward(self, x):
         return x, x


-@dy2static_unittest
-class TestDuplicateOutput(unittest.TestCase):
+class TestDuplicateOutput(Dy2StTestBase):
     """
     TestCase for the transformation from control flow `if/else`
     dependent on tensor in Dygraph into Static `base.layers.cond`.
@@ -52,7 +48,7 @@ def setUp(self):
         self.net = paddle.jit.to_static(SimpleNet())
         self.x = paddle.to_tensor([1.0])

-    @test_and_compare_with_new_ir(False)
+    @test_legacy_and_pir
     def _run_static(self):
         param = self.net.parameters()
         param[0].clear_grad()
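
Every file above applies the same mechanical migration: inherit from `Dy2StTestBase` instead of decorating a `unittest.TestCase` with `@dy2static_unittest`, and swap the old decorators for the renamed ones. A minimal sketch of the pattern (the imported names come from this diff; `TestFoo` and its body are hypothetical placeholders, not code from the PR):

    # Old style (removed by this change):
    #
    #     import unittest
    #     from dygraph_to_static_util import ast_only_test, dy2static_unittest
    #
    #     @dy2static_unittest
    #     class TestFoo(unittest.TestCase):
    #         @ast_only_test
    #         def test_foo(self):
    #             ...

    # New style:
    from dygraph_to_static_utils_new import (
        Dy2StTestBase,            # base class replaces @dy2static_unittest + unittest.TestCase
        test_ast_only,            # renamed from ast_only_test
        test_pir_only,            # renamed from test_with_new_ir
        test_legacy_and_pir,      # replaces @test_and_compare_with_new_ir(False)
        compare_legacy_with_pir,  # replaces @test_and_compare_with_new_ir(True)
    )


    class TestFoo(Dy2StTestBase):  # hypothetical example class
        @test_ast_only
        def test_foo(self):
            ...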