Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unset lazy&verbose env var when disabling mock torch by command line #9970

Merged
merged 10 commits into from
Mar 12, 2023
16 changes: 11 additions & 5 deletions python/oneflow/mock_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,17 @@ def __getattr__(self, name: str) -> Any:
return [attr for attr in dir(self.module) if not attr.startswith("_")]
new_name = self.module.__name__ + "." + name
if _importer.lazy:
if _importer.verbose:
print(
f"{new_name} is not found in oneflow, use dummy object as fallback."
)
return DummyModule(new_name)
blacklist = ["scaled_dot_product_attention"]
if name in blacklist:
if _importer.verbose:
print(f'"{new_name}" is in blacklist, raise AttributeError')
raise AttributeError(new_name + error_msg)
else:
if _importer.verbose:
print(
f'"{new_name}" is not found in oneflow, use dummy object as fallback.'
)
return DummyModule(new_name)
else:
raise AttributeError(new_name + error_msg)
attr = getattr(self.module, name)
Expand Down
4 changes: 3 additions & 1 deletion python/oneflow/mock_torch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def main():
paths = os.environ["PYTHONPATH"].rstrip(":").split(":")
paths = [x for x in paths if x != str(torch_env)]
path = ":".join(paths)
print("export PYTHONPATH=" + path)
print(
f"export PYTHONPATH={path}; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE"
)


if __name__ == "__main__":
Expand Down
10 changes: 10 additions & 0 deletions python/oneflow/test/misc/test_mock_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def test_mock_lazy_for_loop(test_case):
for _ in torch.not_exist:
pass

def test_blacklist(test_case):
with mock.enable(lazy=True):
import torch
import torch.nn.functional as F

test_case.assertFalse(hasattr(F, "scaled_dot_product_attention"))
test_case.assertFalse(
hasattr(torch.nn.functional, "scaled_dot_product_attention")
)


# MUST use pytest to run this test
def test_verbose(capsys):
Expand Down