locon.py
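# LoCon / LoRA module for LyCORIS: low-rank adapters for linear and conv1d/2d/3d
# layers, with optional Tucker decomposition for spatial kernels, DoRA-style
# weight decomposition (weight_decompose), and rsLoRA scaling (rs_lora).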
import math
from functools import cache
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import LycorisBaseModule
from ..functional.general import rebuild_tucker
from ..logging import logger
@cache
def log_wd():
    # Cached so the warning below is only emitted once per process.
    return logger.warning(
        "Using weight_decompose=True with LoRA (DoRA) will ignore network_dropout. "
        "Only rank dropout and module dropout will be applied."
    )
class LoConModule(LycorisBaseModule):
name = "locon"
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
weight_list = [
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
"alpha",
"dora_scale",
]
weight_list_det = ["lora_up.weight"]
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
use_tucker=False,
use_scalar=False,
rank_dropout_scale=False,
weight_decompose=False,
wd_on_out=False,
bypass_mode=None,
rs_lora=False,
**kwargs,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__(
lora_name,
org_module,
multiplier,
dropout,
rank_dropout,
module_dropout,
rank_dropout_scale,
bypass_mode,
)
if self.module_type not in self.support_module:
raise ValueError(f"{self.module_type} is not supported in LoRA/LoCon algo.")
self.lora_dim = lora_dim
self.tucker = False
self.rs_lora = rs_lora
if self.module_type.startswith("conv"):
self.isconv = True
# For general LoCon
in_dim = org_module.in_channels
k_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
out_dim = org_module.out_channels
use_tucker = use_tucker and any(i != 1 for i in k_size)
self.down_op = self.op
self.up_op = self.op
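            # With Tucker decomposition, the spatial k x k kernel lives in a small
            # rank x rank "mid" core while lora_down / lora_up become 1x1 convs;
            # otherwise lora_down carries the full kernel and lora_up stays 1x1.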
            if use_tucker:  # the kernel-size check is already folded into use_tucker above
self.lora_down = self.module(in_dim, lora_dim, 1, bias=False)
self.lora_mid = self.module(
lora_dim, lora_dim, k_size, stride, padding, bias=False
)
self.tucker = True
else:
self.lora_down = self.module(
in_dim, lora_dim, k_size, stride, padding, bias=False
)
self.lora_up = self.module(lora_dim, out_dim, 1, bias=False)
elif isinstance(org_module, nn.Linear):
self.isconv = False
self.down_op = F.linear
self.up_op = F.linear
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
else:
raise NotImplementedError
self.wd = weight_decompose
self.wd_on_out = wd_on_out
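        # DoRA (weight decomposition): keep a trainable magnitude equal to the norm
        # of the original weight, taken per output channel when wd_on_out is set,
        # per input channel otherwise.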
if self.wd:
org_weight = org_module.weight.cpu().clone().float()
self.dora_norm_dims = org_weight.dim() - 1
if self.wd_on_out:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.reshape(org_weight.shape[0], -1),
dim=1,
keepdim=True,
).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
).float()
else:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
dim=1,
keepdim=True,
)
.reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(1, 0)
).float()
if dropout:
self.dropout = nn.Dropout(dropout)
if self.wd:
log_wd()
else:
self.dropout = nn.Identity()
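        # Effective LoRA scale is alpha / rank, or alpha / sqrt(rank) when rs_lora
        # (rank-stabilized LoRA) is enabled. The alpha buffer is stored pre-adjusted
        # so that alpha / lora_dim reproduces the same scale when reloaded.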
        if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
r_factor = lora_dim
if self.rs_lora:
r_factor = math.sqrt(r_factor)
self.scale = alpha / r_factor
self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))
if use_scalar:
self.scalar = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
        # Same initialization as Microsoft's LoRA reference implementation.
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
if use_scalar:
torch.nn.init.kaiming_uniform_(self.lora_up.weight, a=math.sqrt(5))
else:
torch.nn.init.constant_(self.lora_up.weight, 0)
if self.tucker:
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
@classmethod
def make_module_from_state_dict(
cls, lora_name, orig_module, up, down, mid, alpha, dora_scale
):
module = cls(
lora_name,
orig_module,
1,
down.size(0),
float(alpha),
use_tucker=mid is not None,
weight_decompose=dora_scale is not None,
)
module.lora_up.weight.data.copy_(up)
module.lora_down.weight.data.copy_(down)
if mid is not None:
module.lora_mid.weight.data.copy_(mid)
if dora_scale is not None:
module.dora_scale.copy_(dora_scale)
return module
def load_weight_hook(self, module: nn.Module, incompatible_keys):
missing_keys = incompatible_keys.missing_keys
        # Drop the optional "scalar" entry from the missing-keys report; iterate
        # over a copy so removal does not skip elements.
        for key in list(missing_keys):
            if "scalar" in key:
                missing_keys.remove(key)
if isinstance(self.scalar, nn.Parameter):
self.scalar.data.copy_(torch.ones_like(self.scalar))
elif getattr(self, "scalar", None) is not None:
self.scalar.copy_(torch.ones_like(self.scalar))
else:
self.register_buffer(
"scalar", torch.ones_like(self.scalar), persistent=False
)
def make_weight(self, device=None):
wa = self.lora_up.weight.to(device)
wb = self.lora_down.weight.to(device)
if self.tucker:
t = self.lora_mid.weight
wa = wa.view(wa.size(0), -1).transpose(0, 1)
wb = wb.view(wb.size(0), -1)
weight = rebuild_tucker(t, wa, wb)
else:
weight = wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1)
weight = weight.view(self.shape)
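        # Rank dropout (training only): randomly zero whole output channels of the
        # delta weight, optionally rescaling the survivors by 1 / keep-rate to
        # preserve the expected magnitude.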
if self.training and self.rank_dropout:
drop = (torch.rand(weight.size(0), device=device) > self.rank_dropout).to(
weight.dtype
)
drop = drop.view(-1, *[1] * len(weight.shape[1:]))
if self.rank_dropout_scale:
drop /= drop.mean()
weight *= drop
return weight * self.scalar.to(device)
def get_diff_weight(self, multiplier=1, shape=None, device=None):
scale = self.scale * multiplier
diff = self.make_weight(device=device) * scale
if shape is not None:
diff = diff.view(shape)
if device is not None:
diff = diff.to(device)
return diff, None
def get_merged_weight(self, multiplier=1, shape=None, device=None):
diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
weight = self.org_weight
if self.wd:
merged = self.apply_weight_decompose(weight + diff, multiplier)
else:
merged = weight + diff * multiplier
return merged, None
def apply_weight_decompose(self, weight, multiplier=1):
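        # DoRA: rescale the merged weight so each input-channel column (or output
        # row when wd_on_out) has norm dora_scale, i.e. W' = W * dora_scale / ||W||,
        # blended toward the identity scaling by multiplier.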
weight = weight.to(self.dora_scale.dtype)
if self.wd_on_out:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1)
.reshape(weight.shape[0], *[1] * self.dora_norm_dims)
) + torch.finfo(weight.dtype).eps
else:
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(0, 1)
) + torch.finfo(weight.dtype).eps
scale = self.dora_scale.to(weight.device) / weight_norm
if multiplier != 1:
scale = multiplier * (scale - 1) + 1
return weight * scale
def custom_state_dict(self):
destination = {}
if self.wd:
destination["dora_scale"] = self.dora_scale
destination["alpha"] = self.alpha
destination["lora_up.weight"] = self.lora_up.weight * self.scalar
destination["lora_down.weight"] = self.lora_down.weight
if self.tucker:
destination["lora_mid.weight"] = self.lora_mid.weight
return destination
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
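        # Shrink the trainable scalar so the norm of the scaled delta weight does
        # not exceed max_norm; returns whether scaling happened and the new norm.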
orig_norm = self.make_weight(device).norm() * self.scale
norm = torch.clamp(orig_norm, max_norm / 2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu() / norm.cpu()
scaled = norm != desired
if scaled:
self.scalar *= ratio
return scaled, orig_norm * ratio
def bypass_forward_diff(self, x, scale=1):
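        # Bypass mode: compute only the low-rank delta activation
        # up(mid(down(x))) * scalar * scale without materializing the full
        # delta weight matrix.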
if self.tucker:
mid = self.lora_mid(self.lora_down(x))
else:
mid = self.lora_down(x)
if self.rank_dropout and self.training:
drop = (
torch.rand(self.lora_dim, device=mid.device) > self.rank_dropout
).to(mid.dtype)
if self.rank_dropout_scale:
drop /= drop.mean()
if (dims := len(x.shape)) == 4:
drop = drop.view(1, -1, 1, 1)
else:
drop = drop.view(*[1] * (dims - 1), -1)
mid = mid * drop
return self.dropout(self.lora_up(mid) * self.scalar * self.scale * scale)
def bypass_forward(self, x, scale=1):
return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)
def forward(self, x):
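        # Non-bypass path: materialize org_weight + multiplier * scale * delta and
        # run the original op; with module_dropout, training occasionally falls
        # back to the unmodified original forward.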
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.org_forward(x)
scale = self.scale
dtype = self.dtype
if not self.bypass_mode:
diff_weight = self.make_weight(x.device).to(dtype) * scale
weight = self.org_module[0].weight.data.to(dtype)
if self.wd:
weight = self.apply_weight_decompose(
weight + diff_weight, self.multiplier
)
else:
weight = weight + diff_weight * self.multiplier
bias = (
None
if self.org_module[0].bias is None
else self.org_module[0].bias.data
)
return self.op(x, weight, bias, **self.kw_dict)
else:
return self.bypass_forward(x, scale=self.multiplier)
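

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes this
# file lives at lycoris/modules/locon.py (as the relative imports suggest) and
# that LycorisBaseModule wires up org_module / shape / op in its constructor,
# as in the LyCORIS repository. If those assumptions hold, it can be run with
# `python -m lycoris.modules.locon`; treat it as a rough sketch, not the
# canonical API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    base = nn.Linear(64, 32)
    lora = LoConModule("example_linear", base, multiplier=1.0, lora_dim=8, alpha=4)

    x = torch.randn(2, 64)

    # Low-rank delta weight that would be added on top of the frozen base weight.
    delta, _ = lora.get_diff_weight(multiplier=1.0, device=x.device)
    print(delta.shape)  # expected: torch.Size([32, 64])

    # Forward through base weight + scaled LoRA delta (non-bypass path).
    y = lora(x)
    print(y.shape)  # expected: torch.Size([2, 32])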