
Commit b01910f

[PIR] Fix test_gradname_parse and warprnnt op test bugs (PaddlePaddle#59215)

* tmp

* modify ci bug

* [PIR] Migrate maximum into pir

* Polish code

* add ir_grad of static_gradient

* add test

* modify backward

* modify

* modify segment

* modify warprnnt

* fix pir error 34

* modify test_gradname_parse

---------

Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: xiongkun <xiongkun03@baidu.com>
3 people authored and SecretXV committed Nov 28, 2023
1 parent f2baea5 commit b01910f
Showing 5 changed files with 40 additions and 28 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
@@ -986,7 +986,7 @@ pir::OpResult FakeOpResult() {
 
 bool IsFakeOpResult(const pir::OpResult &result) {
   // create a fake opresults to simplify `ForwardBackwardSplit`.
-  return result.Value::impl() == nullptr;
+  return result.Value::impl() == nullptr || !result.Value::type();
 }
 
 static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
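This C++ predicate is exposed to Python as `paddle.pir.is_fake_op_result`; the added `!result.Value::type()` clause makes a typeless result count as fake as well. A minimal sketch of how the helper pair behaves, assuming a PIR-enabled Paddle build:

import paddle

# A fake OpResult is a placeholder used by ForwardBackwardSplit for
# outputs that have no real value; it carries no impl and no type.
fake = paddle.pir.fake_op_result()
assert paddle.pir.is_fake_op_result(fake)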
37 changes: 21 additions & 16 deletions python/paddle/autograd/ir_backward.py
@@ -135,25 +135,30 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
             visited_output.add(opresult)
             continue
         else:
-            grad_value = paddle.full_like(
-                opresult,
-                0.0,
-                opresult.dtype,
-            )
-            full_likeop = grad_value.get_defining_op()
-            fullop = full_likeop.operand_source(1).get_defining_op()
+            if paddle.pir.is_fake_op_result(opresult):
+                state.value_to_valuegrad[opresult] = [
+                    [paddle.pir.fake_op_result()]
+                ]
+            else:
+                grad_value = paddle.full_like(
+                    opresult,
+                    0.0,
+                    opresult.dtype,
+                )
+                full_likeop = grad_value.get_defining_op()
+                fullop = full_likeop.operand_source(1).get_defining_op()
 
-            update_bwdop_structure(
-                backward_ops,
-                state.op_to_opgrad[opresult.get_defining_op()],
-                [full_likeop, fullop],
-            )
-            state.value_to_valuegrad[opresult] = [[grad_value]]
+                update_bwdop_structure(
+                    backward_ops,
+                    state.op_to_opgrad[opresult.get_defining_op()],
+                    [full_likeop, fullop],
+                )
+                state.value_to_valuegrad[opresult] = [[grad_value]]
 
-            visited_output.add(opresult)
+                visited_output.add(opresult)
 
-            complete_outputs.append(opresult)
-            complete_gradoutputs.append(grad_value)
+                complete_outputs.append(opresult)
+                complete_gradoutputs.append(grad_value)
 
     return complete_outputs, complete_gradoutputs, backward_ops
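In effect, `prepare_grad_outputs` now short-circuits fake forward outputs: instead of materializing a `full_like` zero gradient (which would dereference a typeless value), it records a fake gradient placeholder. A condensed, hypothetical rendering of the new branch — the helper name and signature are illustrative, not part of ir_backward.py:

import paddle

def default_grad_for(opresult, state):
    # Illustrative condensation of the branch added above.
    if paddle.pir.is_fake_op_result(opresult):
        # A fake forward output gets a fake gradient placeholder;
        # no full_like/full ops are inserted into the program.
        state.value_to_valuegrad[opresult] = [[paddle.pir.fake_op_result()]]
        return None
    # A real output gets a zero tensor of matching shape and dtype
    # as its seed gradient.
    grad_value = paddle.full_like(opresult, 0.0, opresult.dtype)
    state.value_to_valuegrad[opresult] = [[grad_value]]
    return grad_value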
21 changes: 15 additions & 6 deletions test/dygraph_to_static/test_gradname_parse.py
@@ -16,7 +16,11 @@
 import unittest
 
 import numpy as np
-from dygraph_to_static_utils_new import Dy2StTestBase
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_pir_api_only,
+)
 
 import paddle
 from paddle.nn import BatchNorm, Linear
@@ -82,18 +86,23 @@ def setUp(self):
         self.dy2st_input = (x2,)
         self.dy2st_grad_input = (x2,)
 
+    @test_ast_only
+    @test_pir_api_only
     def test_run(self):
         try:
             dy_out = self.func(*self.dy_input)
-            dy_grad = paddle.grad(dy_out, self.dy_grad_input)
+            dy_grad = paddle.grad(dy_out, self.dy_grad_input, allow_unused=True)
         except:
             dy_grad = [None for i in self.dy_grad_input]
         dy_grad = [
             t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
         ]
 
-        dy2st_out = paddle.jit.to_static(self.func)(*self.dy2st_input)
-        dy2st_grad = paddle.grad(dy2st_out, self.dy2st_grad_input)
+        tmp_func = paddle.jit.to_static(self.func, full_graph=True)
+        dy2st_out = tmp_func(*self.dy2st_input)
+        dy2st_grad = paddle.grad(
+            dy2st_out, self.dy2st_grad_input, allow_unused=True
+        )
         dy2st_grad = [
             t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
         ]
@@ -112,8 +121,8 @@ def test_run(self):
 
 def matmul_high_order_grad(x, y):
     z = paddle.matmul(x, y)
-    g = paddle.grad(z, [x, y], create_graph=True)
-    return g[0]
+    g = paddle.grad(z, [x], create_graph=True, allow_unused=True)
+    return g
 
 
 class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
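The recurring change in this file is `allow_unused=True`: `matmul_high_order_grad` now differentiates only with respect to `[x]`, and an input that does not influence the requested gradients would otherwise raise. A small self-contained example of the flag's semantics (shapes and names are illustrative):

import paddle

x = paddle.randn([2, 2])
x.stop_gradient = False
y = paddle.randn([2, 2])
y.stop_gradient = False

z = (x * 2.0).sum()  # y never contributes to z

# Without allow_unused=True, asking for y's gradient raises an error;
# with it, the unused input's gradient comes back as None.
gx, gy = paddle.grad(z, [x, y], allow_unused=True)
assert gy is None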
2 changes: 1 addition & 1 deletion test/legacy_test/op_test.py
@@ -3676,7 +3676,7 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
             )
             fetch_list = list(grad_inputs)
             # executor run
-            executor = paddle.static.Executor()
+            executor = paddle.static.Executor(place)
             outs = executor.run(
                 ir_program,
                 feed=feed,
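`paddle.static.Executor()` with no argument picks a default device, which can disagree with the `place` the op test compiled its program for; passing `place` keeps the PIR gradient check on the intended device. A sketch of the pattern, assuming the usual place-selection idiom:

import paddle

# Choose a device explicitly rather than relying on the executor default.
place = (
    paddle.CUDAPlace(0)
    if paddle.is_compiled_with_cuda()
    else paddle.CPUPlace()
)
executor = paddle.static.Executor(place)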
6 changes: 2 additions & 4 deletions test/legacy_test/test_warprnnt_op.py
@@ -227,7 +227,7 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         self.outputs["warprnntgrad"] = self.gradient
@@ -239,9 +239,7 @@ def test_check_grad(self):
             )
         else:
             self.check_grad(
-                ["input"],
-                "loss",
-                numeric_grad_delta=0.009,
+                ["input"], "loss", numeric_grad_delta=0.009, check_pir=True
             )
 
 
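`check_pir=True` asks the OpTest machinery to replay the output and gradient checks under the PIR executor in addition to the legacy one. A hypothetical minimal subclass showing the toggle — the class name is illustrative, and the usual setUp (op_type, inputs, outputs) is omitted for brevity:

from op_test import OpTest


class TestSomeOpPIR(OpTest):
    def test_check_output(self):
        # Validate outputs under both the legacy and PIR executors.
        self.check_output(check_pir=True)

    def test_check_grad(self):
        # Numeric gradient check, also replayed under PIR.
        self.check_grad(
            ["input"], "loss", numeric_grad_delta=0.009, check_pir=True
        )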
