Skip to content

Commit 3f9e82f

Browse files
author
The etils Authors
committed
Support LazyModule setattr and delattr
PiperOrigin-RevId: 737044617
1 parent caec352 commit 3f9e82f

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

CHANGELOG.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ Changelog follow https://keepachangelog.com/ format.
88

99
## [Unreleased]
1010

11-
* `ecolab.adhoc`:
11+
* `ecolab.adhoc`:
1212
* Better error message for adhoc invalidate with `epy.reraise`
13+
* `epy`:
14+
* Support `__setattr__` and `__delattr__` on LazyModules.
1315

1416
## [1.12.2] - 2025-03-10
1517

etils/epy/lazy_imports_utils.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import dataclasses
3030
import functools
3131
import importlib
32+
import inspect
3233
import sys
3334
import threading
3435
import types
@@ -40,6 +41,8 @@
4041

4142
_ErrorCallback = Callable[[Exception], None]
4243
_SuccessCallback = Callable[[str], None]
44+
# Eagerly resolve this import since it's needed in setattr.
45+
_getattr_static = inspect.getattr_static
4346

4447
# Store a lock per module to avoid problems with multiple threads trying to
4548
# import the same module at the same time.
@@ -55,6 +58,7 @@ class LazyModule:
5558
error_callback: str | _ErrorCallback | None
5659
success_callback: _SuccessCallback | None
5760
_submodules: dict[str, LazyModule] = dataclasses.field(default_factory=dict)
61+
_initialized: bool = dataclasses.field(default=False, init=False)
5862

5963
def __post_init__(self):
6064
if self.adhoc_kwargs is not None:
@@ -66,6 +70,7 @@ def __post_init__(self):
6670
self.adhoc_kwargs.pop("reload_workspace", None)
6771
self.adhoc_kwargs.pop("cell_autoreload", None)
6872
self.adhoc_kwargs.pop("restrict_reload", None)
73+
self._initialized = True
6974

7075
@functools.cached_property
7176
def _module(self) -> types.ModuleType:
@@ -117,7 +122,20 @@ def __getattr__(self, name: str) -> Any:
117122
else:
118123
return getattr(self._module, name)
119124

120-
# TODO(epot): Also support __setattr__
125+
def __setattr__(self, name: str, value: Any) -> None:
126+
# Avoid regular `__getattr__` path during initialization.
127+
if _getattr_static(self, "_initialized"):
128+
# Trigger import first to overwrite the old attribute if it exists.
129+
setattr(self._module, name, value)
130+
else:
131+
super().__setattr__(name, value)
132+
133+
def __delattr__(self, name: str):
134+
if name in self._submodules:
135+
del self._submodules[name]
136+
# Always delete from underlying module so that lazy import doesn't replace
137+
# a deleted attribute.
138+
delattr(self._module, name)
121139

122140

123141
def _register_submodule(module: LazyModule, name: str) -> LazyModule:

etils/epy/lazy_imports_utils_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,37 @@ def test_import_with_alias():
6464
assert 'tensorflow_datasets' in sys.modules
6565

6666

67+
def test_setattr():
68+
assert 'tensorflow_datasets' not in sys.modules
69+
with epy.lazy_imports():
70+
import tensorflow_datasets as tfds # pylint: disable=g-import-not-at-top
71+
assert 'tensorflow_datasets' not in sys.modules
72+
tfds.features = 'foo'
73+
assert 'tensorflow_datasets' in sys.modules
74+
assert tfds.features == 'foo'
75+
76+
77+
def test_delattr():
78+
assert 'tensorflow_datasets' not in sys.modules
79+
with epy.lazy_imports():
80+
import tensorflow_datasets as tfds # pylint: disable=g-import-not-at-top
81+
assert 'tensorflow_datasets' not in sys.modules
82+
del tfds.features
83+
assert 'tensorflow_datasets' in sys.modules
84+
with pytest.raises(AttributeError):
85+
_ = tfds.features
86+
87+
88+
def test_delattr_submodule():
89+
assert 'tensorflow_datasets' not in sys.modules
90+
with epy.lazy_imports():
91+
import tensorflow_datasets.core # pylint: disable=g-import-not-at-top
92+
assert 'tensorflow_datasets' not in sys.modules
93+
del tensorflow_datasets.core
94+
with pytest.raises(AttributeError):
95+
_ = tensorflow_datasets.core
96+
97+
6798
def test_error_callback():
6899
success_callback = mock.MagicMock()
69100
error_callback = mock.MagicMock()

0 commit comments

Comments
 (0)