Skip to content

Commit

Permalink
unset lazy&verbose env var when disabling mock torch by command line (#…
Browse files Browse the repository at this point in the history
…9970)

1. 修复通过命令行开启 mock 之后环境变量残留导致 api 方式的 mock 参数错误的问题
2. 将 scaled_dot_product_attention 加入 lazy mock 的黑名单,让 hasattr(F,
"scaled_dot_product_attention") 返回 False

---------

Signed-off-by: daquexian <daquexian566@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
3 people authored Mar 12, 2023
1 parent 9c05667 commit 5ab28bb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
16 changes: 11 additions & 5 deletions python/oneflow/mock_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,17 @@ def __getattr__(self, name: str) -> Any:
return [attr for attr in dir(self.module) if not attr.startswith("_")]
new_name = self.module.__name__ + "." + name
if _importer.lazy:
if _importer.verbose:
print(
f"{new_name} is not found in oneflow, use dummy object as fallback."
)
return DummyModule(new_name)
blacklist = ["scaled_dot_product_attention"]
if name in blacklist:
if _importer.verbose:
print(f'"{new_name}" is in blacklist, raise AttributeError')
raise AttributeError(new_name + error_msg)
else:
if _importer.verbose:
print(
f'"{new_name}" is not found in oneflow, use dummy object as fallback.'
)
return DummyModule(new_name)
else:
raise AttributeError(new_name + error_msg)
attr = getattr(self.module, name)
Expand Down
4 changes: 3 additions & 1 deletion python/oneflow/mock_torch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def main():
paths = os.environ["PYTHONPATH"].rstrip(":").split(":")
paths = [x for x in paths if x != str(torch_env)]
path = ":".join(paths)
print("export PYTHONPATH=" + path)
print(
f"export PYTHONPATH={path}; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE"
)


if __name__ == "__main__":
Expand Down
10 changes: 10 additions & 0 deletions python/oneflow/test/misc/test_mock_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def test_mock_lazy_for_loop(test_case):
for _ in torch.not_exist:
pass

def test_blacklist(test_case):
with mock.enable(lazy=True):
import torch
import torch.nn.functional as F

test_case.assertFalse(hasattr(F, "scaled_dot_product_attention"))
test_case.assertFalse(
hasattr(torch.nn.functional, "scaled_dot_product_attention")
)


# MUST use pytest to run this test
def test_verbose(capsys):
Expand Down

0 comments on commit 5ab28bb

Please sign in to comment.