Skip to content

Commit

Permalink
fix flow.set_grad_mode when directly calling (#10059)
Browse files Browse the repository at this point in the history
Closes Oneflow-Inc/OneCloud#203 (comment)

Previously our implementation here differed from torch: we changed the grad
mode via the RAII lifetime of the C++ `AutoGradMode` object, whereas torch
changes it by explicitly calling `set_grad_mode`. As a result, OneFlow could
not modify the thread-local grad mode globally — it could only be changed
inside a decorator or a context-manager statement.

---------

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 30, 2023
1 parent 4e98f8c commit 37ec204
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
1 change: 1 addition & 0 deletions oneflow/api/python/autograd/autograd_mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ ONEFLOW_API_PYBIND11_MODULE("autograd", m) {
.def("__exit__", [](const AutoGradMode& no_grad_obj, const py::object& type,
const py::object& value, const py::object& traceback) {});
m.def("is_grad_enabled", &GradMode::is_enabled);
m.def("set_grad_enabled", &GradMode::set_enabled);
}

} // namespace autograd
Expand Down
7 changes: 7 additions & 0 deletions python/oneflow/autograd/autograd_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,22 @@ class set_grad_enabled:

def __init__(self, is_train=True):
    """Record the requested grad mode and apply it immediately.

    Args:
        is_train (bool): True enables gradient computation, False disables it.

    Mirrors torch semantics: constructing ``set_grad_enabled(mode)`` directly
    (not as a decorator/context manager) changes the thread's grad mode
    globally until changed again.
    """
    self.is_train = is_train
    # Capture the current mode BEFORE flipping it, so __call__/__enter__ can
    # undo this eager change when the object is used as a decorator or a
    # context manager instead of a plain call.
    self.prev_mode = is_grad_enabled()
    oneflow._oneflow_internal.autograd.set_grad_enabled(is_train)

def __call__(self, func):
    """Use this object as a decorator: run ``func`` under grad mode ``is_train``.

    Args:
        func (callable): The function to wrap.

    Returns:
        callable: A wrapper that enters ``AutoGradMode(self.is_train)`` around
        each invocation of ``func``.
    """
    import functools

    # __init__ already flipped the global grad mode eagerly; for decorator
    # usage that change is unwanted, so restore the previous mode here and
    # scope the mode change to each call via the RAII AutoGradMode instead.
    oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode)

    @functools.wraps(func)  # preserve func's __name__/__doc__ for introspection
    def wrapper(*args, **kwargs):
        with AutoGradMode(self.is_train):
            return func(*args, **kwargs)

    return wrapper

def __enter__(self):
    """Enter context-manager usage: apply grad mode ``is_train`` for the block."""
    # __init__ already flipped the global grad mode eagerly; for
    # with-statement usage that change is unwanted, so restore the previous
    # mode first and let the RAII AutoGradMode object scope the change.
    oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode)
    # Keep the AutoGradMode alive for the duration of the with-block.
    # NOTE(review): presumably __exit__ (not visible in this hunk) drops
    # self.grad_mode to restore the mode — confirm against the full file.
    self.grad_mode = AutoGradMode(self.is_train)
    return self

Expand Down
60 changes: 44 additions & 16 deletions python/oneflow/test/modules/test_autograd_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,55 @@ def func():
test_case.assertTrue(flow.is_grad_enabled())

def test_set_grad_enabled(test_case):
with flow.set_grad_enabled(True):
test_case.assertTrue(flow.is_grad_enabled())
test_case.assertTrue(flow.is_grad_enabled())
def assert_grad_mode(mode):
if mode:
test_case.assertTrue(flow.is_grad_enabled())
else:
test_case.assertFalse(flow.is_grad_enabled())

@flow.set_grad_enabled(True)
def func():
test_case.assertTrue(flow.is_grad_enabled())
def get_decorater_func_with_mode(mode):
@flow.set_grad_enabled(mode)
def func():
assert_grad_mode(mode)

func()
test_case.assertTrue(flow.is_grad_enabled())
return func

with flow.set_grad_enabled(False):
test_case.assertFalse(flow.is_grad_enabled())
test_case.assertTrue(flow.is_grad_enabled())
def get_decorater_context_func_with_mode(dec_mode, ctx_mode):
@flow.set_grad_enabled(dec_mode)
def func():
assert_grad_mode(dec_mode)
with flow.set_grad_enabled(ctx_mode):
assert_grad_mode(ctx_mode)
assert_grad_mode(dec_mode)

@flow.set_grad_enabled(False)
def func():
test_case.assertFalse(flow.is_grad_enabled())
return func

func()
test_case.assertTrue(flow.is_grad_enabled())
flow.set_grad_enabled(False)
assert_grad_mode(False)

with flow.set_grad_enabled(True):
assert_grad_mode(True)
flow.set_grad_enabled(False)
assert_grad_mode(False)
func = get_decorater_func_with_mode(True)
func()
assert_grad_mode(False)

flow.set_grad_enabled(True)
assert_grad_mode(True)

with flow.set_grad_enabled(False):
assert_grad_mode(False)
flow.set_grad_enabled(True)
assert_grad_mode(True)
func = get_decorater_func_with_mode(False)
func()
assert_grad_mode(True)

get_decorater_context_func_with_mode(True, True)()
get_decorater_context_func_with_mode(True, False)()
get_decorater_context_func_with_mode(False, True)()
get_decorater_context_func_with_mode(False, False)()


if __name__ == "__main__":
Expand Down

0 comments on commit 37ec204

Please sign in to comment.