@@ -58,16 +58,28 @@ def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
 
+def rms_norm(x, normalized_shape, weight, eps):
+    if hasattr(torch, 'rms_norm'):  # torch 2.4
+        return torch.rms_norm(x, normalized_shape, weight, eps)
+
+    if x.dtype in [torch.bfloat16, torch.float32]:
+        n = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) * weight
+    else:
+        n = torch.rsqrt(torch.mean(x.float()**2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
+    return x * n
+
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None):
         super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
+        self.scale = nn.Parameter(torch.ones((dim), dtype=dtype, device=device))
+        self.normalized_shape = [dim]
 
     def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * self.scale
+        if self.scale.dtype != x.dtype:
+            self.scale = nn.Parameter(self.scale.to(dtype=x.dtype), requires_grad=x.requires_grad)
+
+        return rms_norm(x, self.normalized_shape, self.scale, 1e-6)
 
 
 class QKNorm(torch.nn.Module):
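Not part of the diff: a minimal sketch of how the fallback branch of `rms_norm` above can be sanity-checked against the built-in `torch.rms_norm` on PyTorch builds (2.4+) that expose it. The tensor shapes and tolerance are arbitrary assumptions.

```python
import torch

# Manual RMSNorm, matching the bf16/fp32 branch of the fallback above.
def rms_norm_manual(x, weight, eps=1e-6):
    n = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) * weight
    return x * n

x = torch.randn(2, 4, 64)  # (batch, tokens, dim) -- arbitrary sizes
w = torch.ones(64)         # scale initialized to ones, as in the new RMSNorm
manual = rms_norm_manual(x, w)

if hasattr(torch, "rms_norm"):  # only available on torch >= 2.4
    builtin = torch.rms_norm(x, [64], w, 1e-6)
    print(torch.allclose(manual, builtin, atol=1e-5))  # expected: True
```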
@@ -98,7 +110,9 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=N
 
     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
         qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        q, k, v = qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
         x = attention(q, k, v, pe=pe)
         x = self.proj(x)
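A small sketch (not from the commit) checking that the `view`/`permute` reshape used above produces the same tensor as the `einops.rearrange` call it replaces; the batch, sequence, head, and head-dim sizes are made up for illustration.

```python
import torch
from einops import rearrange

B, L, H, D = 2, 7, 8, 16
qkv = torch.randn(B, L, 3 * H * D)

# Old path: einops factors the last axis as (K, H, D) and reorders to K B H L D.
ref = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
# New path: explicit view to (B, L, K, H, D), then permute to (K, B, H, L, D).
alt = qkv.view(B, L, 3, H, -1).permute(2, 0, 3, 1, 4)

print(torch.equal(ref, alt))  # True: same values and shape
```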
@@ -165,14 +179,18 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        img_q, img_k, img_v = img_qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = txt_qkv.shape
+        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
@@ -238,7 +256,9 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        #q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        q, k, v = qkv.view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
 
         # compute attention
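For context, a shape sketch (assumed toy sizes, not from the diff) of the fused `linear1` output in this block: it splits into a QKV chunk and an MLP chunk, and the QKV chunk is reshaped the same way as above.

```python
import torch

B, L = 2, 5
hidden_size, num_heads, mlp_hidden_dim = 64, 4, 256  # assumed toy sizes

fused = torch.randn(B, L, 3 * hidden_size + mlp_hidden_dim)  # stand-in for linear1 output
qkv, mlp = torch.split(fused, [3 * hidden_size, mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(B, L, 3, num_heads, -1).permute(2, 0, 3, 1, 4)

print(qkv.shape)  # torch.Size([2, 5, 192])
print(mlp.shape)  # torch.Size([2, 5, 256])
print(q.shape)    # torch.Size([2, 4, 5, 16])
```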