Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix flow.set_grad_mode when directly calling #10059

Merged
merged 10 commits into from
Mar 30, 2023
1 change: 1 addition & 0 deletions oneflow/api/python/autograd/autograd_mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ ONEFLOW_API_PYBIND11_MODULE("autograd", m) {
.def("__exit__", [](const AutoGradMode& no_grad_obj, const py::object& type,
const py::object& value, const py::object& traceback) {});
m.def("is_grad_enabled", &GradMode::is_enabled);
m.def("set_grad_enabled", &GradMode::set_enabled);
}

} // namespace autograd
Expand Down
7 changes: 7 additions & 0 deletions python/oneflow/autograd/autograd_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,22 @@ class set_grad_enabled:

def __init__(self, is_train=True):
self.is_train = is_train
self.prev_mode = is_grad_enabled()
oneflow._oneflow_internal.autograd.set_grad_enabled(is_train)

def __call__(self, func):
# recover grad mode set in __init__
oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode)

def wrapper(*args, **kwargs):
with AutoGradMode(self.is_train):
return func(*args, **kwargs)

return wrapper

def __enter__(self):
# recover grad mode set in __init__
oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode)
self.grad_mode = AutoGradMode(self.is_train)
return self

Expand Down
60 changes: 44 additions & 16 deletions python/oneflow/test/modules/test_autograd_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,55 @@ def func():
test_case.assertTrue(flow.is_grad_enabled())

def test_set_grad_enabled(test_case):
with flow.set_grad_enabled(True):
test_case.assertTrue(flow.is_grad_enabled())
test_case.assertTrue(flow.is_grad_enabled())
def assert_grad_mode(mode):
if mode:
test_case.assertTrue(flow.is_grad_enabled())
else:
test_case.assertFalse(flow.is_grad_enabled())

@flow.set_grad_enabled(True)
def func():
test_case.assertTrue(flow.is_grad_enabled())
def get_decorater_func_with_mode(mode):
@flow.set_grad_enabled(mode)
def func():
assert_grad_mode(mode)

func()
test_case.assertTrue(flow.is_grad_enabled())
return func

with flow.set_grad_enabled(False):
test_case.assertFalse(flow.is_grad_enabled())
test_case.assertTrue(flow.is_grad_enabled())
def get_decorater_context_func_with_mode(dec_mode, ctx_mode):
@flow.set_grad_enabled(dec_mode)
def func():
assert_grad_mode(dec_mode)
with flow.set_grad_enabled(ctx_mode):
assert_grad_mode(ctx_mode)
assert_grad_mode(dec_mode)

@flow.set_grad_enabled(False)
def func():
test_case.assertFalse(flow.is_grad_enabled())
return func

func()
test_case.assertTrue(flow.is_grad_enabled())
flow.set_grad_enabled(False)
assert_grad_mode(False)

with flow.set_grad_enabled(True):
assert_grad_mode(True)
flow.set_grad_enabled(False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个类的用法比较多,可以当装饰器、with域、也可以直接调用,可以考虑在里面测试一下装饰器的用法,保证 __call__ 方法调用的时候恢复的 prev_mode 数据是正确的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个类的用法比较多,可以当装饰器、with域、也可以直接调用,可以考虑在里面测试一下装饰器的用法,保证 __call__ 方法调用的时候恢复的 prev_mode 数据是正确的

加上了

assert_grad_mode(False)
func = get_decorater_func_with_mode(True)
func()
assert_grad_mode(False)

flow.set_grad_enabled(True)
assert_grad_mode(True)

with flow.set_grad_enabled(False):
assert_grad_mode(False)
flow.set_grad_enabled(True)
assert_grad_mode(True)
func = get_decorater_func_with_mode(False)
func()
assert_grad_mode(True)

get_decorater_context_func_with_mode(True, True)()
get_decorater_context_func_with_mode(True, False)()
get_decorater_context_func_with_mode(False, True)()
get_decorater_context_func_with_mode(False, False)()


if __name__ == "__main__":
Expand Down