From b01910fee4a2c0ba62b85b6e2ba8bea54d9e5874 Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Wed, 22 Nov 2023 14:02:55 +0800
Subject: [PATCH] 【pir】 modify test_Gradname_parse and warprnnt optest bug (#59215)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* tmp

* modify ci bug

* [PIR]Migrate maximum into pir

* Polish code

* add ir_grad of static_gradient

* add test

* modify backward

* modify

* modify segment

* modify warprnnt

* fix pir error 34

* modify test_gradname_parse

---------

Co-authored-by: 0x45f
Co-authored-by: xiongkun
---
 paddle/fluid/pybind/pir.cc                    |  2 +-
 python/paddle/autograd/ir_backward.py         | 37 +++++++++++--------
 test/dygraph_to_static/test_gradname_parse.py | 21 ++++++++---
 test/legacy_test/op_test.py                   |  2 +-
 test/legacy_test/test_warprnnt_op.py          |  6 +--
 5 files changed, 40 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 818db35e55e41f..52aa94c1de0386 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -986,7 +986,7 @@ pir::OpResult FakeOpResult() {
 
 bool IsFakeOpResult(const pir::OpResult &result) {
   // create a fake opresults to simplify `ForwardBackwardSplit`.
-  return result.Value::impl() == nullptr;
+  return result.Value::impl() == nullptr || !result.Value::type();
 }
 
 static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 8e112012599b81..f2ad53b0254d88 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -135,25 +135,30 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
                 visited_output.add(opresult)
                 continue
             else:
-                grad_value = paddle.full_like(
-                    opresult,
-                    0.0,
-                    opresult.dtype,
-                )
-                full_likeop = grad_value.get_defining_op()
-                fullop = full_likeop.operand_source(1).get_defining_op()
+                if paddle.pir.is_fake_op_result(opresult):
+                    state.value_to_valuegrad[opresult] = [
+                        [paddle.pir.fake_op_result()]
+                    ]
+                else:
+                    grad_value = paddle.full_like(
+                        opresult,
+                        0.0,
+                        opresult.dtype,
+                    )
+                    full_likeop = grad_value.get_defining_op()
+                    fullop = full_likeop.operand_source(1).get_defining_op()
 
-                update_bwdop_structure(
-                    backward_ops,
-                    state.op_to_opgrad[opresult.get_defining_op()],
-                    [full_likeop, fullop],
-                )
-                state.value_to_valuegrad[opresult] = [[grad_value]]
+                    update_bwdop_structure(
+                        backward_ops,
+                        state.op_to_opgrad[opresult.get_defining_op()],
+                        [full_likeop, fullop],
+                    )
+                    state.value_to_valuegrad[opresult] = [[grad_value]]
 
-                visited_output.add(opresult)
+                    visited_output.add(opresult)
 
-                complete_outputs.append(opresult)
-                complete_gradoutputs.append(grad_value)
+                    complete_outputs.append(opresult)
+                    complete_gradoutputs.append(grad_value)
 
     return complete_outputs, complete_gradoutputs, backward_ops
 
diff --git a/test/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py
index 7b46961207af42..e15320fdc84880 100644
--- a/test/dygraph_to_static/test_gradname_parse.py
+++ b/test/dygraph_to_static/test_gradname_parse.py
@@ -16,7 +16,11 @@
 import unittest
 
 import numpy as np
-from dygraph_to_static_utils_new import Dy2StTestBase
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_pir_api_only,
+)
 
 import paddle
 from paddle.nn import BatchNorm, Linear
@@ -82,18 +86,23 @@ def setUp(self):
         self.dy2st_input = (x2,)
         self.dy2st_grad_input = (x2,)
 
+    @test_ast_only
+    @test_pir_api_only
     def test_run(self):
         try:
             dy_out = self.func(*self.dy_input)
-            dy_grad = paddle.grad(dy_out, self.dy_grad_input)
+            dy_grad = paddle.grad(dy_out, self.dy_grad_input, allow_unused=True)
         except:
             dy_grad = [None for i in self.dy_grad_input]
         dy_grad = [
             t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
         ]
 
-        dy2st_out = paddle.jit.to_static(self.func)(*self.dy2st_input)
-        dy2st_grad = paddle.grad(dy2st_out, self.dy2st_grad_input)
+        tmp_func = paddle.jit.to_static(self.func, full_graph=True)
+        dy2st_out = tmp_func(*self.dy2st_input)
+        dy2st_grad = paddle.grad(
+            dy2st_out, self.dy2st_grad_input, allow_unused=True
+        )
         dy2st_grad = [
             t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
         ]
@@ -112,8 +121,8 @@ def test_run(self):
 
 def matmul_high_order_grad(x, y):
     z = paddle.matmul(x, y)
-    g = paddle.grad(z, [x, y], create_graph=True)
-    return g[0]
+    g = paddle.grad(z, [x], create_graph=True, allow_unused=True)
+    return g
 
 
 class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py
index 7862c0efef9843..150e286a5276e6 100644
--- a/test/legacy_test/op_test.py
+++ b/test/legacy_test/op_test.py
@@ -3676,7 +3676,7 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
             )
             fetch_list = list(grad_inputs)
             # executor run
-            executor = paddle.static.Executor()
+            executor = paddle.static.Executor(place)
             outs = executor.run(
                 ir_program,
                 feed=feed,
diff --git a/test/legacy_test/test_warprnnt_op.py b/test/legacy_test/test_warprnnt_op.py
index ced735b4310aba..9a1d8b56d7ea11 100644
--- a/test/legacy_test/test_warprnnt_op.py
+++ b/test/legacy_test/test_warprnnt_op.py
@@ -227,7 +227,7 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         self.outputs["warprnntgrad"] = self.gradient
@@ -239,9 +239,7 @@ def test_check_grad(self):
             )
         else:
             self.check_grad(
-                ["input"],
-                "loss",
-                numeric_grad_delta=0.009,
+                ["input"], "loss", numeric_grad_delta=0.009, check_pir=True
             )
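
Note: the test changes above rely on `paddle.grad(..., allow_unused=True)`, which returns None for inputs that do not participate in computing the outputs instead of raising an error. A minimal dygraph sketch of that behavior (the tensor shapes and variable names below are illustrative, not taken from the patch):

    import paddle

    x = paddle.randn([2, 2])
    x.stop_gradient = False
    y = paddle.randn([2, 2])
    y.stop_gradient = False

    # y never enters the graph, mirroring matmul_high_order_grad above,
    # which now differentiates only with respect to x.
    loss = paddle.matmul(x, x).sum()

    # allow_unused=True yields None for the unused input instead of an error.
    gx, gy = paddle.grad(loss, [x, y], allow_unused=True)
    print(gx.shape)  # [2, 2]
    print(gy)        # None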