Skip to content

Commit a6c384b

Browse files
Merge pull request AUTOMATIC1111#16144 from akx/bump-spandrel
Bump spandrel to 0.3.4
2 parents b282b47 + f8fb74b commit a6c384b

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

modules/gfpgan_model.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,11 @@ def load_net(self) -> torch.Module:
3636
ext_filter=['.pth'],
3737
):
3838
if 'GFPGAN' in os.path.basename(model_path):
39-
model = modelloader.load_spandrel_model(
39+
return modelloader.load_spandrel_model(
4040
model_path,
4141
device=self.get_device(),
4242
expected_architecture='GFPGAN',
4343
).model
44-
model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
45-
return model
4644
raise ValueError("No GFPGAN model found")
4745

4846
def restore(self, np_image):

modules/modelloader.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,27 @@ def load_upscalers():
139139
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
140140
)
141141

142+
# None: not loaded, False: failed to load, True: loaded
143+
_spandrel_extra_init_state = None
144+
145+
146+
def _init_spandrel_extra_archs() -> None:
147+
"""
148+
Try to initialize `spandrel_extra_archs` (exactly once).
149+
"""
150+
global _spandrel_extra_init_state
151+
if _spandrel_extra_init_state is not None:
152+
return
153+
154+
try:
155+
import spandrel
156+
import spandrel_extra_arches
157+
spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)
158+
_spandrel_extra_init_state = True
159+
except Exception:
160+
logger.warning("Failed to load spandrel_extra_arches", exc_info=True)
161+
_spandrel_extra_init_state = False
162+
142163

143164
def load_spandrel_model(
144165
path: str | os.PathLike,
@@ -148,11 +169,16 @@ def load_spandrel_model(
148169
dtype: str | torch.dtype | None = None,
149170
expected_architecture: str | None = None,
150171
) -> spandrel.ModelDescriptor:
172+
global _spandrel_extra_init_state
173+
151174
import spandrel
175+
_init_spandrel_extra_archs()
176+
152177
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
153-
if expected_architecture and model_descriptor.architecture != expected_architecture:
178+
arch = model_descriptor.architecture
179+
if expected_architecture and arch.name != expected_architecture:
154180
logger.warning(
155-
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
181+
f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})",
156182
)
157183
half = False
158184
if prefer_half:
@@ -166,6 +192,6 @@ def load_spandrel_model(
166192
model_descriptor.model.eval()
167193
logger.debug(
168194
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
169-
model_descriptor, path, device, half, dtype,
195+
arch, path, device, half, dtype,
170196
)
171197
return model_descriptor

requirements_versions.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ pytorch_lightning==1.9.4
2424
resize-right==0.0.2
2525
safetensors==0.4.2
2626
scikit-image==0.21.0
27-
spandrel==0.1.6
27+
spandrel==0.3.4
28+
spandrel-extra-arches==0.1.1
2829
tomesd==0.1.3
2930
torch
3031
torchdiffeq==0.2.3

0 commit comments

Comments
 (0)