diff --git a/oneflow/api/python/autograd/autograd_mode.cpp b/oneflow/api/python/autograd/autograd_mode.cpp index 975e3f87c9c..eaebb6f73a9 100644 --- a/oneflow/api/python/autograd/autograd_mode.cpp +++ b/oneflow/api/python/autograd/autograd_mode.cpp @@ -32,6 +32,7 @@ ONEFLOW_API_PYBIND11_MODULE("autograd", m) { .def("__exit__", [](const AutoGradMode& no_grad_obj, const py::object& type, const py::object& value, const py::object& traceback) {}); m.def("is_grad_enabled", &GradMode::is_enabled); + m.def("set_grad_enabled", &GradMode::set_enabled); } } // namespace autograd diff --git a/python/oneflow/autograd/autograd_mode.py b/python/oneflow/autograd/autograd_mode.py index 747134d38c4..69e1af2db5f 100644 --- a/python/oneflow/autograd/autograd_mode.py +++ b/python/oneflow/autograd/autograd_mode.py @@ -196,8 +196,13 @@ class set_grad_enabled: def __init__(self, is_train=True): self.is_train = is_train + self.prev_mode = is_grad_enabled() + oneflow._oneflow_internal.autograd.set_grad_enabled(is_train) def __call__(self, func): + # recover grad mode set in __init__ + oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode) + def wrapper(*args, **kwargs): with AutoGradMode(self.is_train): return func(*args, **kwargs) @@ -205,6 +210,8 @@ def wrapper(*args, **kwargs): return wrapper def __enter__(self): + # recover grad mode set in __init__ + oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode) self.grad_mode = AutoGradMode(self.is_train) return self diff --git a/python/oneflow/test/modules/test_autograd_mode.py b/python/oneflow/test/modules/test_autograd_mode.py index ee8abed9285..857d96bfdea 100644 --- a/python/oneflow/test/modules/test_autograd_mode.py +++ b/python/oneflow/test/modules/test_autograd_mode.py @@ -72,27 +72,55 @@ def func(): test_case.assertTrue(flow.is_grad_enabled()) def test_set_grad_enabled(test_case): - with flow.set_grad_enabled(True): - test_case.assertTrue(flow.is_grad_enabled()) - test_case.assertTrue(flow.is_grad_enabled()) + def assert_grad_mode(mode): + if mode: + test_case.assertTrue(flow.is_grad_enabled()) + else: + test_case.assertFalse(flow.is_grad_enabled()) - @flow.set_grad_enabled(True) - def func(): - test_case.assertTrue(flow.is_grad_enabled()) + def get_decorater_func_with_mode(mode): + @flow.set_grad_enabled(mode) + def func(): + assert_grad_mode(mode) - func() - test_case.assertTrue(flow.is_grad_enabled()) + return func - with flow.set_grad_enabled(False): - test_case.assertFalse(flow.is_grad_enabled()) - test_case.assertTrue(flow.is_grad_enabled()) + def get_decorater_context_func_with_mode(dec_mode, ctx_mode): + @flow.set_grad_enabled(dec_mode) + def func(): + assert_grad_mode(dec_mode) + with flow.set_grad_enabled(ctx_mode): + assert_grad_mode(ctx_mode) + assert_grad_mode(dec_mode) - @flow.set_grad_enabled(False) - def func(): - test_case.assertFalse(flow.is_grad_enabled()) + return func - func() - test_case.assertTrue(flow.is_grad_enabled()) + flow.set_grad_enabled(False) + assert_grad_mode(False) + + with flow.set_grad_enabled(True): + assert_grad_mode(True) + flow.set_grad_enabled(False) + assert_grad_mode(False) + func = get_decorater_func_with_mode(True) + func() + assert_grad_mode(False) + + flow.set_grad_enabled(True) + assert_grad_mode(True) + + with flow.set_grad_enabled(False): + assert_grad_mode(False) + flow.set_grad_enabled(True) + assert_grad_mode(True) + func = get_decorater_func_with_mode(False) + func() + assert_grad_mode(True) + + get_decorater_context_func_with_mode(True, True)() + get_decorater_context_func_with_mode(True, False)() + get_decorater_context_func_with_mode(False, True)() + get_decorater_context_func_with_mode(False, False)() if __name__ == "__main__":