
Commit 5b2a60b

initial SD3 support
1 parent a7116aa commit 5b2a60b

14 files changed: +333 -44 lines changed

README.md

+1 -1

@@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 
-- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - Spandrel - https://github.com/chaiNNer-org/spandrel implementing
 - GFPGAN - https://github.com/TencentARC/GFPGAN.git

configs/sd3-inference.yaml

+5
@@ -0,0 +1,5 @@
+model:
+  target: modules.models.sd3.sd3_model.SD3Inferencer
+  params:
+    shift: 3
+    state_dict: null
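The config points the loader at the new SD3Inferencer class via a dotted target path, with shift and state_dict as constructor params (state_dict is null here and is presumably replaced with the actual checkpoint weights at load time). How the target string gets resolved is not part of this diff; a minimal sketch of the usual ldm-style pattern, with a stand-in helper name, not the actual loader wiring:

import importlib

def instantiate_from_target(target: str, **params):
    # Resolve a dotted "package.module.ClassName" string and construct the class.
    module_name, class_name = target.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**params)

# e.g. instantiate_from_target("modules.models.sd3.sd3_model.SD3Inferencer",
#                              state_dict=state_dict, shift=3)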

extensions-builtin/Lora/networks.py

+3 -1

@@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
     else:
-        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+        for name, module in cond_stage_model.named_modules():
             network_name = name.replace(".", "_")
             network_layer_mapping[network_name] = module
             module.network_layer_name = network_name
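With the getattr fallback, models whose text encoder is hijack-wrapped still go through `.wrapped`, while the new SD3Cond module (which has no such wrapper) is enumerated directly. A toy sketch of what the loop produces, using a small stand-in module instead of the real text encoder:

import torch

class TinyTextModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

cond_stage_model = TinyTextModel()  # no .wrapped attribute, like SD3Cond
network_layer_mapping = {}
for name, module in cond_stage_model.named_modules():
    # dots in module paths become underscores; these are the keys LoRA layers match against
    network_name = name.replace(".", "_")
    network_layer_mapping[network_name] = module

print(list(network_layer_mapping))  # ['', 'encoder', 'encoder_0', 'encoder_1']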

modules/models/sd3/mmdit.py

+2 -1

@@ -6,7 +6,8 @@
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
-from other_impls import attention, Mlp
+from modules.models.sd3.other_impls import attention, Mlp
+
 
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding"""

modules/models/sd3/sd3_impls.py

+7 -7

@@ -1,7 +1,7 @@
 ### Impls of the SD3 core diffusion model and VAE
 
 import torch, math, einops
-from mmdit import MMDiT
+from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
 
 
@@ -46,16 +46,16 @@ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
 
 class BaseModel(torch.nn.Module):
     """Wrapper around the core MM-DiT model"""
-    def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
+    def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
         super().__init__()
         # Important configuration values can be quickly determined by checking shapes in the source file
         # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
-        patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
-        depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
-        num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
+        patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
+        depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
+        num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
         pos_embed_max_size = round(math.sqrt(num_patches))
-        adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
-        context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
+        adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
+        context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
         context_embedder_config = {
             "target": "torch.nn.Linear",
             "params": {

modules/models/sd3/sd3_model.py

+166
@@ -0,0 +1,166 @@
+import contextlib
+import os
+from typing import Mapping
+
+import safetensors
+import torch
+
+import k_diffusion
+from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
+from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
+
+from modules import shared, modelloader, devices
+
+CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
+CLIPG_CONFIG = {
+    "hidden_act": "gelu",
+    "hidden_size": 1280,
+    "intermediate_size": 5120,
+    "num_attention_heads": 20,
+    "num_hidden_layers": 32,
+}
+
+CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
+CLIPL_CONFIG = {
+    "hidden_act": "quick_gelu",
+    "hidden_size": 768,
+    "intermediate_size": 3072,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+}
+
+T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
+T5_CONFIG = {
+    "d_ff": 10240,
+    "d_model": 4096,
+    "num_heads": 64,
+    "num_layers": 24,
+    "vocab_size": 32128,
+}
+
+
+class SafetensorsMapping(Mapping):
+    def __init__(self, file):
+        self.file = file
+
+    def __len__(self):
+        return len(self.file.keys())
+
+    def __iter__(self):
+        for key in self.file.keys():
+            yield key
+
+    def __getitem__(self, key):
+        return self.file.get_tensor(key)
+
+
+class SD3Cond(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.tokenizer = SD3Tokenizer()
+
+        with torch.no_grad():
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
+
+        self.weights_loaded = False
+
+    def forward(self, prompts: list[str]):
+        res = []
+
+        for prompt in prompts:
+            tokens = self.tokenizer.tokenize_with_weights(prompt)
+            l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
+            g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+            lg_out = torch.cat([l_out, g_out], dim=-1)
+            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
+            vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
+
+            res.append({
+                'crossattn': lgt_out[0].to(devices.device),
+                'vector': vector_out[0].to(devices.device),
+            })
+
+        return res
+
+    def load_weights(self):
+        if self.weights_loaded:
+            return
+
+        clip_path = os.path.join(shared.models_path, "CLIP")
+
+        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+        with safetensors.safe_open(clip_g_file, framework="pt") as file:
+            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+
+        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+        with safetensors.safe_open(clip_l_file, framework="pt") as file:
+            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        with safetensors.safe_open(t5_file, framework="pt") as file:
+            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        self.weights_loaded = True
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        return torch.tensor([[0]], device=devices.device)  # XXX
+
+
+class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
+    def __init__(self, inner_model, sigmas):
+        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
+        self.inner_model = inner_model
+
+    def forward(self, input, sigma, **kwargs):
+        return self.inner_model.apply_model(input, sigma, **kwargs)
+
+
+class SD3Inferencer(torch.nn.Module):
+    def __init__(self, state_dict, shift=3, use_ema=False):
+        super().__init__()
+
+        self.shift = shift
+
+        with torch.no_grad():
+            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
+            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
+            self.first_stage_model.dtype = self.model.diffusion_model.dtype
+
+        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
+
+        self.cond_stage_model = SD3Cond()
+        self.cond_stage_key = 'txt'
+
+        self.parameterization = "eps"
+        self.model.conditioning_key = "crossattn"
+
+        self.latent_format = SD3LatentFormat()
+        self.latent_channels = 16
+
+    def after_load_weights(self):
+        self.cond_stage_model.load_weights()
+
+    def ema_scope(self):
+        return contextlib.nullcontext()
+
+    def get_learned_conditioning(self, batch: list[str]):
+        return self.cond_stage_model(batch)
+
+    def apply_model(self, x, t, cond):
+        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+
+    def decode_first_stage(self, latent):
+        latent = self.latent_format.process_out(latent)
+        return self.first_stage_model.decode(latent)
+
+    def encode_first_stage(self, image):
+        latent = self.first_stage_model.encode(image)
+        return self.latent_format.process_in(latent)
+
+    def create_denoiser(self):
+        return SD3Denoiser(self, self.model.model_sampling.sigmas)
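Taken together, SD3Inferencer mirrors the interface the rest of webui expects from an LDM-style checkpoint: cond_stage_model, first_stage_model, apply_model, and a k-diffusion denoiser. A hedged sketch, not part of this commit, of how the pieces are meant to fit together; the checkpoint path and prompt below are made up:

import safetensors

from modules import devices
from modules.models.sd3.sd3_model import SD3Inferencer, SafetensorsMapping

checkpoint = "models/Stable-diffusion/sd3_medium.safetensors"  # hypothetical path

with safetensors.safe_open(checkpoint, framework="pt") as file:
    state_dict = SafetensorsMapping(file)
    model = SD3Inferencer(state_dict, shift=3)   # configures MMDiT from tensor shapes

model.after_load_weights()                       # downloads/loads the three text encoders
cond = model.get_learned_conditioning(["a photo of a cat"])[0]
denoiser = model.create_denoiser()               # k-diffusion DiscreteSchedule wrapper
sigmas = denoiser.get_sigmas(20).to(devices.device)

The constructor above only probes tensor shapes; the diffusion and VAE weights are presumably filled in afterwards by webui's usual checkpoint-loading path, which is not shown in this excerpt.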

modules/processing.py

+2 -1

@@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
-            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
 
             if p.scripts is not None:
                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
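The RNG shape now follows the loaded model: SD3Inferencer sets latent_channels = 16 in sd3_model.py above, while older models without that attribute fall back to the default opt_C. A toy check of the fallback behaviour, with stand-in objects in place of real models:

opt_C = 4  # default channel count used by processing.py for SD1/SDXL latents

class SD3Like:
    latent_channels = 16

class SD1Like:
    pass  # no latent_channels attribute

assert getattr(SD3Like(), 'latent_channels', opt_C) == 16
assert getattr(SD1Like(), 'latent_channels', opt_C) == 4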
