-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR API adaptor No.90,92】Migrate some ops into pir #59801
Changes from 3 commits
bfff5e2
45649b2
8b7efb4
97fed79
d17add9
70e63fb
41c6031
8c67fb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -350,7 +350,7 @@ def set_data(self): | |||||
|
||||||
def test_check_output(self): | ||||||
# NODE(yjjiang11): This op will be deprecated. | ||||||
self.check_output(check_dygraph=False) | ||||||
self.check_output(check_dygraph=False, check_pir=True) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
当前 |
||||||
|
||||||
def setUp(self): | ||||||
self.op_type = "generate_proposals" | ||||||
|
This comment was marked as resolved.
Sorry, something went wrong. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个应该如何进行定位呢?一直不知道如何进行定位 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这种继承自OpTest 的单测一般用 pdb 等工具在python 端进行单步调试即可。比如这里我是单步调试看运行时ret_tuple和outputs_sig分别是什么值 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 懂了,我目前都是CI Debug哈哈 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
import numpy as np | ||
|
||
import paddle | ||
from paddle.pir_utils import test_with_pir_api | ||
|
||
|
||
class TestGraphReindex(unittest.TestCase): | ||
|
@@ -128,6 +129,7 @@ def test_heter_reindex_result_v2(self): | |
np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05) | ||
np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05) | ||
|
||
@test_with_pir_api | ||
def test_reindex_result_static(self): | ||
paddle.enable_static() | ||
with paddle.static.program_guard(paddle.static.Program()): | ||
|
@@ -369,6 +371,7 @@ def test_heter_reindex_result_v3(self): | |
np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05) | ||
np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05) | ||
|
||
@test_with_pir_api | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 当前 pr 适配的是 paddle.incubate.graph_reindex 而非 paddle.geometric.reindex_graph。这个单测取消掉吧 |
||
def test_reindex_result_static(self): | ||
paddle.enable_static() | ||
with paddle.static.program_guard(paddle.static.Program()): | ||
|
@@ -448,6 +451,7 @@ def test_reindex_result_static(self): | |
) | ||
np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05) | ||
|
||
@test_with_pir_api | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
def test_heter_reindex_result_static(self): | ||
paddle.enable_static() | ||
np_x = np.arange(5).astype("int64") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里动态图和下面的静态图的逻辑并不统一。静态图还有设置 stop_gradient 的过程,建议新开一个 in_pir_mode 分支,添加设置 stop_gradient 的逻辑