Skip to content

Commit

Permalink
fix exponential op (PaddlePaddle#58029)
Browse files Browse the repository at this point in the history
* fix exponential op

* 更新 op_compat.yaml

* fix parser

* 更新 utils.cc

* fix_exponential_op

* add flag FLAGS_FLAGS_NEW_IR_NO_CHECK

* add flag FLAGS_FLAGS_NEW_IR_NO_CHECK

* add flag FLAGS_FLAGS_NEW_IR_NO_CHECK

* add flag FLAGS_FLAGS_NEW_IR_NO_CHECK

* add flag FLAGS_FLAGS_NEW_IR_NO_CHECK

* Update new_ir_op_test_no_check_list

* Update op_test.py

* Update op_test.py

* Update CMakeLists.txt

* fix getenv
  • Loading branch information
xingmingyyj authored Oct 26, 2023
1 parent 9b3ac67 commit 87b2e24
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -987,8 +987,8 @@
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]

- op : exponential_
backward : exponential__grad
- op : exponential_ (exponential)
backward : exponential__grad (exponential_grad)
inputs :
x : X
outputs :
Expand Down
11 changes: 11 additions & 0 deletions test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,17 @@ foreach(IR_OP_TEST ${NEW_IR_OP_TESTS})
endif()
endforeach()

file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/new_ir_op_test_no_check_list"
NEW_IR_OP_NO_CHECK_TESTS)
foreach(IR_OP_TEST ${NEW_IR_OP_NO_CHECK_TESTS})
if(TEST ${IR_OP_TEST})
set_tests_properties(${IR_OP_TEST} PROPERTIES ENVIRONMENT
"FLAGS_NEW_IR_NO_CHECK=True")
else()
message(STATUS "NewIR OpTest: not found ${IR_OP_TEST} in legacy_test")
endif()
endforeach()

file(STRINGS
"${CMAKE_SOURCE_DIR}/test/white_list/new_ir_op_test_precision_white_list"
NEW_IR_OP_RELAXED_TESTS)
Expand Down
9 changes: 7 additions & 2 deletions test/legacy_test/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,10 +1435,12 @@ def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
), "Fetch result should have same length when executed in pir"

check_method = np.testing.assert_array_equal
if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None):
if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None) == "True":
check_method = lambda x, y, z: np.testing.assert_allclose(
x, y, err_msg=z, atol=1e-6, rtol=1e-6
)
if os.getenv("FLAGS_NEW_IR_NO_CHECK", None) == "True":
check_method = lambda x, y, err_msg: None

for i in range(len(outs)):
check_method(
Expand Down Expand Up @@ -3368,11 +3370,14 @@ def _check_ir_grad_output(
)

check_method = np.testing.assert_array_equal
if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None):
if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None) == "True":
check_method = lambda x, y, z: np.testing.assert_allclose(
x, y, err_msg=z, atol=1e-6, rtol=1e-6
)

if os.getenv("FLAGS_NEW_IR_NO_CHECK", None) == "True":
check_method = lambda x, y, err_msg: None

for i in range(len(new_gradients)):
check_method(
gradients[i],
Expand Down
1 change: 1 addition & 0 deletions test/white_list/new_ir_op_test_no_check_list
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test_exponential_op
1 change: 1 addition & 0 deletions test/white_list/new_ir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ test_elementwise_mul_op
test_elementwise_pow_op
test_erfinv_op
test_expand_v2_op
test_exponential_op
test_eye_op
test_fill_any_op
test_fill_constant_batch_size_like
Expand Down

0 comments on commit 87b2e24

Please sign in to comment.