
Commit a70dfb6: change import statements for AUTOMATIC1111#14478
Parent: be5f1ac

7 files changed, +14 -17 lines
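
Every hunk below applies the same mechanical substitution: instead of importing get_param directly from modules.torch_utils, each caller now imports the torch_utils module and calls torch_utils.get_param(...) through it. A minimal before/after sketch of the pattern (the model variable here is hypothetical, used only to show the call site):

    # before: the function object is bound at import time
    from modules.torch_utils import get_param
    dtype = get_param(model).dtype

    # after: the module is imported; the attribute is looked up at call time
    from modules import torch_utils
    dtype = torch_utils.get_param(model).dtype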

modules/devices.py (+2 -2)

@@ -4,7 +4,7 @@
 
 import torch
 from modules import errors, shared
-from modules.torch_utils import get_param
+from modules import torch_utils
 
 if sys.platform == "darwin":
     from modules import mac_specific
@@ -132,7 +132,7 @@ def cond_cast_float(input):
 
 
 def manual_cast_forward(self, *args, **kwargs):
-    org_dtype = get_param(self).dtype
+    org_dtype = torch_utils.get_param(self).dtype
     self.to(dtype)
     args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
     kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

modules/interrogate.py (+2 -3)

@@ -10,8 +10,7 @@
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
-from modules import devices, paths, shared, lowvram, modelloader, errors
-from modules.torch_utils import get_param
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -132,7 +131,7 @@ def load(self):
 
         self.clip_model = self.clip_model.to(devices.device_interrogate)
 
-        self.dtype = get_param(self.clip_model).dtype
+        self.dtype = torch_utils.get_param(self.clip_model).dtype
 
     def send_clip_to_ram(self):
         if not shared.opts.interrogate_keep_models_in_memory:

modules/sd_models_xl.py (+2 -2)

@@ -6,7 +6,7 @@
 import sgm.modules.diffusionmodules.denoiser_scaling
 import sgm.modules.diffusionmodules.discretizer
 from modules import devices, shared, prompt_parser
-from modules.torch_utils import get_param
+from modules import torch_utils
 
 
 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -91,7 +91,7 @@ def get_target_prompt_token_count(self, token_count):
 def extend_sdxl(model):
     """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
 
-    dtype = get_param(model.model.diffusion_model).dtype
+    dtype = torch_utils.get_param(model.model.diffusion_model).dtype
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
     model.cond_stage_key = 'txt'

modules/upscaler_utils.py (+2 -3)

@@ -6,8 +6,7 @@
 import tqdm
 from PIL import Image
 
-from modules import images, shared
-from modules.torch_utils import get_param
+from modules import images, shared, torch_utils
 
 logger = logging.getLogger(__name__)
 
@@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image):
     img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
     img = torch.from_numpy(img).float()
 
-    param = get_param(model)
+    param = torch_utils.get_param(model)
     img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
 
     with torch.no_grad():

modules/xlmr.py (+2 -2)

@@ -5,7 +5,7 @@
 from transformers import XLMRobertaModel,XLMRobertaTokenizer
 from typing import Optional
 
-from modules.torch_utils import get_param
+from modules import torch_utils
 
 
 class BertSeriesConfig(BertConfig):
@@ -65,7 +65,7 @@ def __init__(self, config=None, **kargs):
         self.post_init()
 
     def encode(self,c):
-        device = get_param(self).device
+        device = torch_utils.get_param(self).device
         text = self.tokenizer(c,
                         truncation=True,
                         max_length=77,

modules/xlmr_m18.py (+2 -3)

@@ -4,8 +4,7 @@
 from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
 from transformers import XLMRobertaModel,XLMRobertaTokenizer
 from typing import Optional
-
-from modules.torch_utils import get_param
+from modules import torch_utils
 
 
 class BertSeriesConfig(BertConfig):
@@ -71,7 +70,7 @@ def __init__(self, config=None, **kargs):
         self.post_init()
 
     def encode(self,c):
-        device = get_param(self).device
+        device = torch_utils.get_param(self).device
         text = self.tokenizer(c,
                         truncation=True,
                         max_length=77,

test/test_torch_utils.py (+2 -2)

@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from modules.torch_utils import get_param
+from modules import torch_utils
 
 
 @pytest.mark.parametrize("wrapped", [True, False])
@@ -14,6 +14,6 @@ def test_get_param(wrapped):
     if wrapped:
         # more or less how spandrel wraps a thing
        mod = types.SimpleNamespace(model=mod)
-    p = get_param(mod)
+    p = torch_utils.get_param(mod)
     assert p.dtype == torch.float16
     assert p.device == cpu
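
The test above only exercises get_param through its observable behaviour: it returns a parameter whose dtype and device the callers read, and it copes with a spandrel-style wrapper that exposes the torch module as a .model attribute. For orientation, a minimal sketch consistent with that test; this is an assumption inferred from the call sites in this commit, not the actual contents of modules/torch_utils.py:

    import torch

    def get_param(model) -> torch.nn.Parameter:
        # Hypothetical sketch: unwrap a wrapper object that exposes the torch
        # module as `.model` (the SimpleNamespace(model=mod) case in the test).
        if hasattr(model, "model") and hasattr(model.model, "parameters"):
            model = model.model
        # Return the first parameter; callers in this commit only read
        # .dtype and .device from the result.
        return next(model.parameters())

Callers such as devices.manual_cast_forward and upscaler_utils.upscale_without_tiling use it exactly this way, reading only the parameter's dtype and device.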
