Skip to content

Commit cc3f604

Browse files
committed
Update
1 parent 74ff85a commit cc3f604

8 files changed

+22
-20
lines changed

modules/devices.py

+7
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,16 @@ def torch_gc():
8888
xpu_specific.torch_xpu_gc()
8989

9090
if npu_specific.has_npu:
91+
torch_npu_set_device()
9192
npu_specific.torch_npu_gc()
9293

9394

95+
def torch_npu_set_device():
96+
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
97+
if npu_specific.has_npu:
98+
torch.npu.set_device(0)
99+
100+
94101
def enable_tf32():
95102
if torch.cuda.is_available():
96103

modules/initialize.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,7 @@ def load_model():
143143
by that time, so we apply optimization again.
144144
"""
145145
from modules import devices
146-
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
147-
if devices.npu_specific.has_npu:
148-
import torch
149-
torch.npu.set_device(0)
146+
devices.torch_npu_set_device()
150147

151148
shared.sd_model # noqa: B018
152149

modules/launch_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def prepare_environment():
338338
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
339339
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
340340
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
341+
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
341342

342343
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
343344
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
@@ -421,6 +422,13 @@ def prepare_environment():
421422
run_pip(f"install -r \"{requirements_file}\"", "requirements")
422423
startup_timer.record("install requirements")
423424

425+
if not os.path.isfile(requirements_file_for_npu):
426+
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
427+
428+
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
429+
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
430+
startup_timer.record("install requirements_for_npu")
431+
424432
if not args.skip_install:
425433
run_extensions_installers(settings_file=args.ui_settings_file)
426434

modules/npu_specific.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ def check_for_npu():
88
if importlib.util.find_spec("torch_npu") is None:
99
return False
1010
import torch_npu
11-
torch_npu.npu.set_device(0)
1211

1312
try:
1413
# Will raise a RuntimeError if no NPU is found
15-
_ = torch.npu.device_count()
14+
_ = torch_npu.npu.device_count()
1615
return torch.npu.is_available()
1716
except RuntimeError:
1817
return False
@@ -25,8 +24,6 @@ def get_npu_device_string():
2524

2625

2726
def torch_npu_gc():
28-
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
29-
torch.npu.set_device(0)
3027
with torch.npu.device(get_npu_device_string()):
3128
torch.npu.empty_cache()
3229

modules/textual_inversion/textual_inversion.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,7 @@ def register_embedding_by_name(self, embedding, model, name):
150150
return embedding
151151

152152
def get_expected_shape(self):
153-
# workaround
154-
if devices.npu_specific.has_npu:
155-
import torch
156-
torch.npu.set_device(0)
153+
devices.torch_npu_set_device()
157154
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
158155
return vec.shape[1]
159156

requirements.txt

-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ accelerate
44

55
blendmodes
66
clean-fid
7-
cloudpickle
8-
decorator
97
einops
108
facexlib
119
fastapi>=0.90.1
@@ -26,10 +24,8 @@ resize-right
2624

2725
safetensors
2826
scikit-image>=0.19
29-
synr==0.5.0
3027
tomesd
3128
torch
3229
torchdiffeq
3330
torchsde
34-
tornado
3531
transformers==4.30.2

requirements_npu.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
cloudpickle
2+
decorator
3+
synr==0.5.0
4+
tornado

requirements_versions.txt

-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ Pillow==9.5.0
33
accelerate==0.21.0
44
blendmodes==2022
55
clean-fid==0.1.35
6-
cloudpickle==3.0.0
7-
decorator==5.1.1
86
einops==0.4.1
97
facexlib==0.3.0
108
fastapi==0.94.0
@@ -23,12 +21,10 @@ pytorch_lightning==1.9.4
2321
resize-right==0.0.2
2422
safetensors==0.4.2
2523
scikit-image==0.21.0
26-
synr==0.5.0
2724
spandrel==0.1.6
2825
tomesd==0.1.3
2926
torch
3027
torchdiffeq==0.2.3
3128
torchsde==0.2.6
32-
tornado==6.4
3329
transformers==4.30.2
3430
httpx==0.24.1

0 commit comments

Comments
 (0)