[PIR] Fix test_gradname_parse and warprnnt op test bugs #59215

Merged Nov 22, 2023 (28 commits). Changes from all commits.

Commits (28):
bce9b3b  tmp (xiaoguoguo626807, Aug 30, 2023)
c2341a5  fix conflict (xiaoguoguo626807, Aug 30, 2023)
4d30fdd  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Aug 31, 2023)
cae7604  modify ci bug (xiaoguoguo626807, Sep 19, 2023)
c94252d  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 19, 2023)
305ed20  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 20, 2023)
3aa6686  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 22, 2023)
6c553e6  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 22, 2023)
3b3b5ea  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 25, 2023)
7e8e095  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Sep 25, 2023)
9c09a56  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 7, 2023)
cae57c1  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 7, 2023)
d52fe87  [PIR] Migrate maximum into pir (Oct 8, 2023)
9e5a0b1  Polish code (Oct 9, 2023)
2218be2  add ir_grad of static_gradient (xiaoguoguo626807, Oct 9, 2023)
b190b2f  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Oct 9, 2023)
2ce9d92  Merge commit 'refs/pull/57929/head' of https://github.com/PaddlePaddl… (xiaoguoguo626807, Oct 9, 2023)
02040b1  add test (xiaoguoguo626807, Oct 9, 2023)
ae9b38a  Merge branch 'develop', commit 'refs/pull/57956/head' of https://gith… (xiaoguoguo626807, Oct 9, 2023)
464106f  tmp (xiaoguoguo626807, Nov 16, 2023)
e8421b1  modify backward (xiaoguoguo626807, Nov 17, 2023)
ff2bcf2  modify (xiaoguoguo626807, Nov 17, 2023)
30521e5  modify segment (xiaoguoguo626807, Nov 17, 2023)
5d5645a  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807, Nov 20, 2023)
a39ccff  modify warprnnt (xiaoguoguo626807, Nov 21, 2023)
6b54baa  fix pir error 34 (2742195759, Nov 21, 2023)
78e2172  Merge branch 'develop', commit 'refs/pull/59204/head' of https://gith… (xiaoguoguo626807, Nov 21, 2023)
742d1bd  modofy test_gradname_parse (xiaoguoguo626807, Nov 21, 2023)
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc

@@ -986,7 +986,7 @@ pir::OpResult FakeOpResult() {

 bool IsFakeOpResult(const pir::OpResult &result) {
   // create a fake opresults to simplify `ForwardBackwardSplit`.
-  return result.Value::impl() == nullptr;
+  return result.Value::impl() == nullptr || !result.Value::type();
 }

 static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
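The change widens the fake-result test: an OpResult whose underlying Value has a null impl or no type is now treated as fake. A minimal sketch of the Python-visible behavior, assuming a PIR-enabled build (both helpers are the pybind exports defined in this file):

import paddle

# A fake OpResult is a placeholder used to simplify ForwardBackwardSplit;
# it must never be mistaken for a real SSA value.
fake = paddle.pir.fake_op_result()
assert paddle.pir.is_fake_op_result(fake)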
37 changes: 21 additions & 16 deletions python/paddle/autograd/ir_backward.py

@@ -135,25 +135,30 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
                 visited_output.add(opresult)
                 continue
             else:
-                grad_value = paddle.full_like(
-                    opresult,
-                    0.0,
-                    opresult.dtype,
-                )
-                full_likeop = grad_value.get_defining_op()
-                fullop = full_likeop.operand_source(1).get_defining_op()
-
-                update_bwdop_structure(
-                    backward_ops,
-                    state.op_to_opgrad[opresult.get_defining_op()],
-                    [full_likeop, fullop],
-                )
-                state.value_to_valuegrad[opresult] = [[grad_value]]
-
-                visited_output.add(opresult)
-
-                complete_outputs.append(opresult)
-                complete_gradoutputs.append(grad_value)
+                if paddle.pir.is_fake_op_result(opresult):
+                    state.value_to_valuegrad[opresult] = [
+                        [paddle.pir.fake_op_result()]
+                    ]
+                else:
+                    grad_value = paddle.full_like(
+                        opresult,
+                        0.0,
+                        opresult.dtype,
+                    )
+                    full_likeop = grad_value.get_defining_op()
+                    fullop = full_likeop.operand_source(1).get_defining_op()
+
+                    update_bwdop_structure(
+                        backward_ops,
+                        state.op_to_opgrad[opresult.get_defining_op()],
+                        [full_likeop, fullop],
+                    )
+                    state.value_to_valuegrad[opresult] = [[grad_value]]
+
+                visited_output.add(opresult)
+
+                complete_outputs.append(opresult)
+                complete_gradoutputs.append(grad_value)

     return complete_outputs, complete_gradoutputs, backward_ops
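The intent of the new branch: placeholder ("fake") forward outputs created during forward/backward splitting must not have real zero-filled gradients materialized for them, since that would insert full_like/full ops for values that do not exist. A minimal sketch of the invariant, assuming a PIR-enabled build (value_to_valuegrad stands in for the internal bookkeeping state):

import paddle

def seed_output_grads(outputs, value_to_valuegrad):
    # Pair each forward output with a gradient seed, mirroring the new branch:
    # fake outputs get fake gradients (no ops inserted), real outputs get a
    # zero tensor of matching shape and dtype.
    for opresult in outputs:
        if paddle.pir.is_fake_op_result(opresult):
            value_to_valuegrad[opresult] = [[paddle.pir.fake_op_result()]]
        else:
            value_to_valuegrad[opresult] = [
                [paddle.full_like(opresult, 0.0, opresult.dtype)]
            ]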
21 changes: 15 additions & 6 deletions test/dygraph_to_static/test_gradname_parse.py

@@ -16,7 +16,11 @@
 import unittest

 import numpy as np
-from dygraph_to_static_utils_new import Dy2StTestBase
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_ast_only,
+    test_pir_api_only,
+)

 import paddle
 from paddle.nn import BatchNorm, Linear

@@ -82,18 +86,23 @@ def setUp(self):
         self.dy2st_input = (x2,)
         self.dy2st_grad_input = (x2,)

+    @test_ast_only
+    @test_pir_api_only
     def test_run(self):
         try:
             dy_out = self.func(*self.dy_input)
-            dy_grad = paddle.grad(dy_out, self.dy_grad_input)
+            dy_grad = paddle.grad(dy_out, self.dy_grad_input, allow_unused=True)
         except:
             dy_grad = [None for i in self.dy_grad_input]
         dy_grad = [
             t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
         ]

-        dy2st_out = paddle.jit.to_static(self.func)(*self.dy2st_input)
-        dy2st_grad = paddle.grad(dy2st_out, self.dy2st_grad_input)
+        tmp_func = paddle.jit.to_static(self.func, full_graph=True)
+        dy2st_out = tmp_func(*self.dy2st_input)
+        dy2st_grad = paddle.grad(
+            dy2st_out, self.dy2st_grad_input, allow_unused=True
+        )
         dy2st_grad = [
             t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
         ]

@@ -112,8 +121,8 @@ def test_run(self):

 def matmul_high_order_grad(x, y):
     z = paddle.matmul(x, y)
-    g = paddle.grad(z, [x, y], create_graph=True)
-    return g[0]
+    g = paddle.grad(z, [x], create_graph=True, allow_unused=True)
+    return g


 class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
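The recurring change here is allow_unused=True: by default paddle.grad raises when an input does not contribute to the differentiated output, and with allow_unused=True it returns None for that input instead. A minimal dygraph sketch (tensor names are illustrative):

import paddle

x = paddle.randn([3, 3])
x.stop_gradient = False
w = paddle.randn([3, 3])  # deliberately unused below
w.stop_gradient = False

z = paddle.matmul(x, x)
gx, gw = paddle.grad(z, [x, w], allow_unused=True)
assert gx is not None
assert gw is None  # unused input yields None instead of an error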
2 changes: 1 addition & 1 deletion test/legacy_test/op_test.py

@@ -3676,7 +3676,7 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
         )
         fetch_list = list(grad_inputs)
         # executor run
-        executor = paddle.static.Executor()
+        executor = paddle.static.Executor(place)
        outs = executor.run(
             ir_program,
             feed=feed,
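paddle.static.Executor() with no argument picks a default device from the build, so a test parameterized for one place could silently run on another; passing place pins execution to the device under test. A minimal sketch of the pattern (the program and place here are illustrative):

import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", [2, 2], "float32")
    y = x * 2.0

place = paddle.CPUPlace()  # an op test would pass in its configured place
exe = paddle.static.Executor(place)
exe.run(startup)
(out,) = exe.run(main, feed={"x": np.ones([2, 2], "float32")}, fetch_list=[y])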
6 changes: 2 additions & 4 deletions test/legacy_test/test_warprnnt_op.py

@@ -227,7 +227,7 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)

     def test_check_grad(self):
         self.outputs["warprnntgrad"] = self.gradient

@@ -239,9 +239,7 @@ def test_check_grad(self):
             )
         else:
             self.check_grad(
-                ["input"],
-                "loss",
-                numeric_grad_delta=0.009,
+                ["input"], "loss", numeric_grad_delta=0.009, check_pir=True
             )
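check_pir=True asks the OpTest machinery to repeat the output and gradient checks under the new PIR executor in addition to the legacy path. A hedged sketch of opting another kernel test in (the exp op and shapes are illustrative, not part of this PR):

import numpy as np
from op_test import OpTest  # helper module used by tests under test/legacy_test

class TestExpOp(OpTest):
    def setUp(self):
        self.op_type = "exp"
        x = np.random.rand(3, 4).astype("float64")
        self.inputs = {"X": x}
        self.outputs = {"Out": np.exp(x)}

    def test_check_output(self):
        # also verify the kernel under the PIR executor
        self.check_output(check_pir=True)

    def test_check_grad(self):
        self.check_grad(["X"], "Out", check_pir=True)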