
Commit ac9b19d

optimize

* replace rearrange with view (AUTOMATIC1111#15804)
* see also lllyasviel/stable-diffusion-webui-forge@79adfa8
* conditionally use torch.rms_norm on torch 2.4
* fix RMSNorm() initialization for clarity: use torch.ones()
1 parent bda8779 commit ac9b19d

2 files changed, +38 -12 lines changed

modules/models/flux/math.py (+9 -3)
@@ -7,7 +7,9 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    x = rearrange(x, "B H L D -> B L (H D)")
+    #x = rearrange(x, "B H L D -> B L (H D)")
+    B, H, L, D = x.shape
+    x = x.permute(0, 2, 1, 3).contiguous().view(B, L, H * D)
 
     return x
 
@@ -17,9 +19,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos, omega)
+    #out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    #out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    b, n, d, _ = out.shape
+    out = out.view(b, n, d, 2, 2)
+    return out.to(dtype=torch.float32, device=pos.device)
 
 
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
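
As a quick equivalence check (not part of the commit; dummy shapes, einops imported only to compare against the replaced calls), the permute/view path above produces the same tensors as the rearrange calls it comments out:

import torch
from einops import rearrange

x = torch.randn(2, 4, 8, 16)  # B H L D
B, H, L, D = x.shape
assert torch.equal(rearrange(x, "B H L D -> B L (H D)"),
                   x.permute(0, 2, 1, 3).contiguous().view(B, L, H * D))

out = torch.randn(2, 8, 32, 4)  # b n d (i j), with i = j = 2
b, n, d, _ = out.shape
assert torch.equal(rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2),
                   out.view(b, n, d, 2, 2))

The .contiguous() call is needed because view() requires contiguous memory after the permute; reshape() would work without it but may copy anyway.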

modules/models/flux/modules/layers.py (+29 -9)
@@ -58,16 +58,28 @@ def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
 
+def rms_norm(x, normalized_shape, weight, eps):
+    if hasattr(torch, 'rms_norm'): # torch 2.4
+        return torch.rms_norm(x, normalized_shape, weight, eps)
+
+    if x.dtype in [torch.bfloat16, torch.float32]:
+        n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
+    else:
+        n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
+    return x * n
+
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None):
         super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
+        self.scale = nn.Parameter(torch.ones((dim), dtype=dtype, device=device))
+        self.normalized_shape = [dim]
 
     def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * self.scale
+        if self.scale.dtype != x.dtype:
+            self.scale = nn.Parameter(self.scale.to(dtype=x.dtype), requires_grad=x.requires_grad)
+
+        return rms_norm(x, self.normalized_shape, self.scale, 1e-6)
 
 
 class QKNorm(torch.nn.Module):
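
A small sanity check (not part of the commit): the manual branch of rms_norm() matches the arithmetic of the RMSNorm.forward() it replaces, up to float rounding. rms_norm_fallback below is a hypothetical stand-in that mirrors only the float32 path:

import torch

def rms_norm_fallback(x, weight, eps=1e-6):
    # same formula as the non-torch.rms_norm branch above
    n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
    return x * n

dim = 64
x = torch.randn(2, 10, dim)
weight = torch.ones(dim)  # matches the new torch.ones() initialization
old = (x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)) * weight
assert torch.allclose(rms_norm_fallback(x, weight), old, atol=1e-6)

The switch from torch.empty() to torch.ones() only changes the value a freshly constructed module holds before a checkpoint's state dict overwrites self.scale.
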
@@ -98,7 +110,9 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=N
 
     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
         qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        q, k, v = qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
         x = attention(q, k, v, pe=pe)
         x = self.proj(x)
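
The same view/permute idiom recurs in the remaining hunks below. A shape check (not part of the commit; made-up sizes) confirming it unpacks qkv exactly like the commented-out rearrange pattern:

import torch
from einops import rearrange

B, L, num_heads, head_dim = 2, 16, 8, 32
qkv = torch.randn(B, L, 3 * num_heads * head_dim)

ref_q, ref_k, ref_v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads)
q, k, v = qkv.view(B, L, 3, num_heads, -1).permute(2, 0, 3, 1, 4)

assert q.shape == (B, num_heads, L, head_dim)
assert torch.equal(q, ref_q) and torch.equal(k, ref_k) and torch.equal(v, ref_v)
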
@@ -165,14 +179,18 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        img_q, img_k, img_v = img_qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = txt_qkv.shape
+        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
@@ -238,7 +256,9 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        q, k, v = qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
 
         # compute attention
