from collections import namedtuple
import enum

+import torch.nn as nn
+import torch.nn.functional as F
+
from modules import sd_models, cache, errors, hashes, shared


NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ def __init__(self, net: Network, weights: NetworkWeights):
        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
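Side note, not part of the diff: the reason a functional op plus a small dict of captured kwargs is enough here is that the functional counterpart, given the module's own hyperparameters, reproduces the wrapped module's forward pass. A minimal illustrative sketch for the Conv2d case (the toy shapes are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 3, 16, 16)

# F.conv2d fed the module's own weight/bias plus the captured stride/padding
# gives the same result as calling the module directly
assert torch.allclose(conv(x), F.conv2d(x, conv.weight, conv.bias, stride=conv.stride, padding=conv.padding))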
@@ -137,7 +163,7 @@ def calc_scale(self):
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
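One plausible reading of the dtype change above (not stated in the PR): the stored bias should be cast to whatever precision the computed delta uses, rather than to the stored weight's precision, so a low-precision weight dtype does not truncate the bias before it is accumulated into a higher-precision updown. A small illustrative sketch with assumed dtypes:

import torch

updown = torch.zeros(3, dtype=torch.float32)        # delta computed in fp32
orig_weight = torch.zeros(3, dtype=torch.float16)   # model weight stored in fp16
bias = torch.full((3,), 1e-4, dtype=torch.float32)

old = updown + bias.to(dtype=orig_weight.dtype)  # bias rounded to fp16 first
new = updown + bias.to(dtype=updown.dtype)       # bias kept in fp32
print(old[0].item(), new[0].item())              # the fp16 round-trip shifts the value slightly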
@@ -155,5 +181,10 @@ def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
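For intuition, a standalone sketch (not code from the PR; the toy Linear and the synthetic delta are assumptions): the new forward returns the original output y plus the delta applied through the same functional op, which is equivalent to running the module with the delta merged into its weight.

import torch
import torch.nn as nn
import torch.nn.functional as F

linear = nn.Linear(16, 8)
x = torch.randn(2, 16)
y = linear(x)                        # output of the unpatched module

updown = torch.randn(8, 16) * 1e-3   # stand-in for calc_updown()'s weight delta
patched = y + F.linear(x, updown)    # what forward(x, y) computes for an nn.Linear

# same result as merging the delta into the module's weight directly
merged = F.linear(x, linear.weight + updown, linear.bias)
assert torch.allclose(patched, merged, atol=1e-5)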