-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtent.py
352 lines (282 loc) · 11.6 KB
/
tent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
from copy import deepcopy
import torch
import torch.nn as nn
import torch.jit
from pdb import set_trace as st
class Tent(nn.Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.

    Besides plain entropy minimization, flags select Gram-matrix matching
    objectives (``use_gram`` / ``classwise``) or a meta-training mode that
    updates per-layer Gram loss weights (``meta_train``).
    """
    def __init__(self, model, optimizer, steps=1, episodic=False, params=None, use_gram=False, g_train=None, classwise=False, meta_train=False):
        """Wrap ``model`` for test-time adaptation.

        Args:
            model: module to adapt. The Gram-based modes call it as
                ``model(x, feature_maps=True)`` and unpack a pair
                ``(outputs, feature_maps)``.
            optimizer: optimizer over the parameters that are adapted.
            steps: adaptation steps per forward call; must be >= 1.
            episodic: if True, restore the initial model/optimizer state
                before every forward call.
            params: parameters differentiated w.r.t. in meta-train mode.
            use_gram: use Gram-matrix matching instead of entropy.
            g_train: reference Gram matrices for the Gram-based modes.
            classwise: with ``use_gram``, match per predicted class.
            meta_train: run the meta-train update and return loss weights
                instead of predictions.
        """
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic
        self.params = params
        self.use_gram = use_gram
        self.g_train = g_train
        self.classwise = classwise
        self.meta_train = meta_train
        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.model_state, self.optimizer_state = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def forward(self, x, y=None, loss_weight=None):
        """Adapt on ``x`` and return predictions (or loss weights in meta-train mode).

        Args:
            x: input batch.
            y: targets; only used in meta-train mode.
            loss_weight: per-layer Gram loss weights; only used (and
                returned updated) in meta-train mode.

        Returns:
            In meta-train mode, the updated ``loss_weight`` after a single
            step: the loop returns on its first iteration, so ``self.steps``
            is not honoured there. Otherwise, the outputs of one extra
            forward pass made in eval mode after all adaptation steps.
        """
        if self.episodic:
            self.reset()
        for _ in range(self.steps):
            if self.meta_train:
                # [0.001, 10, 100] are the hard-coded per-layer step sizes
                # (the ``lr`` argument) for the loss-weight update.
                loss_weight = forward_and_adapt_gram_meta_train(x, y, loss_weight, [0.001, 10, 100], self.model, self.optimizer, self.params, self.g_train)
                return loss_weight
            elif not self.use_gram:
                outputs = forward_and_adapt(x, self.model, self.optimizer)
            elif not self.classwise:
                outputs = forward_and_adapt_gram(x, self.model, self.optimizer, self.g_train)
            else:
                outputs = forward_and_adapt_gram_classwise(x, self.model, self.optimizer, self.g_train)
        # NOTE(review): the outputs from the adaptation steps are discarded
        # in favour of this extra pass. eval() is never undone, so the model
        # stays in eval mode for later calls, and check_model() would then
        # fail. Confirm this is intended.
        self.model.eval()
        outputs = self.model(x)
        return outputs

    def reset(self):
        """Restore the model and optimizer to the state saved in ``__init__``.

        Raises:
            Exception: if no state snapshot is available.
        """
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
@torch.jit.script
def softmax_entropy(x: torch.Tensor, T1: float=1, T2: float=1) -> torch.Tensor:
    """Return, for every row (dim 1) of logits *x*, -sum(p * log q).

    ``p`` is ``softmax(x / T1)`` and ``log q`` is ``log_softmax(x / T2)``.
    With the default temperatures (both 1) this is the Shannon entropy of
    the softmax distribution.
    """
    probs = torch.softmax(x / T1, dim=1)
    log_probs = torch.log_softmax(x / T2, dim=1)
    return torch.sum(probs * log_probs, dim=1).neg()
@torch.enable_grad()
def compute_gram_matrix(input):
    """Return the (unnormalized) Gram matrix ``F @ F.T`` of one feature map.

    The first dimension of *input* indexes the rows of ``F``. All remaining
    dimensions are flattened into its columns. For a single sample's
    convolutional feature map, e.g. of shape ``(C, H, W)``, this is the
    ``C x C`` channel-correlation matrix.

    Args:
        input: tensor with at least two dimensions.

    Returns:
        Tensor of shape ``(input.size(0), input.size(0))``.
    """
    # flatten(1) copies only when the memory layout requires it, whereas
    # .view() raises on non-contiguous inputs (e.g. channels_last tensors).
    features = input.flatten(1)
    return torch.mm(features, features.t())
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt_gram_classwise(x, model, optimizer, g_train, loss_weight=(0, 1, 0)):
    """Forward and adapt the model on a batch by class-conditional Gram matching.

    Each sample is assigned the class ``outputs.argmax(1)``. For each of the
    first three feature maps returned by ``model(x, feature_maps=True)``, the
    sample's Gram matrix is compared with ``g_train[layer][class_id]`` using
    MSE. The weighted losses are averaged over the batch and back-propagated,
    and one optimizer step is taken.

    Args:
        x: input batch.
        model: module returning ``(outputs, feature_maps)`` when called with
            ``feature_maps=True``.
        optimizer: optimizer over the parameters being adapted.
        g_train: reference Gram matrices indexed ``g_train[layer][class_id]``.
        loss_weight: weight of each of the three layer losses. The default
            ``(0, 1, 0)`` is the previously hard-coded value (only the second
            feature map contributes).

    Returns:
        The model outputs computed before the update. An empty batch is
        returned as-is without an update.
    """
    # forward
    outputs, outputs_fp = model(x, feature_maps=True)
    batch_size = outputs.shape[0]
    if batch_size == 0:
        # Nothing to adapt on; averaging below would divide by zero.
        return outputs
    class_preds = outputs.argmax(1)
    # adapt
    loss_fn = nn.MSELoss()
    loss = 0
    for i in range(batch_size):
        class_id = class_preds[i].item()
        for layer in range(3):
            loss += loss_weight[layer] * loss_fn(
                compute_gram_matrix(outputs_fp[layer][i]), g_train[layer][class_id]
            )
    loss /= batch_size
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs
def loss_wishart(g_test, g_train):
    """Sum a symmetric trace discrepancy between test and train Gram matrices.

    For level ``k`` (0, 1, 2) with ``q = g_test[k].shape[0]`` and constant ``d``
    taken from ``(1024, 256, 64)``, the term is::

        tr(G_test @ inv(G_train) + G_train @ inv(G_test)) * q / 4 - q * d / 2

    The level terms are scaled by ``(1, 4, 0)``, so the third level currently
    contributes nothing, although its inverses are still computed. Every
    Gram matrix must be square and invertible (``torch.inverse`` raises
    otherwise).
    """
    level_consts = (1024, 256, 64)
    level_scales = (1, 4, 0)
    total = 0
    for level, (d, scale) in enumerate(zip(level_consts, level_scales)):
        test_g = g_test[level]
        train_g = g_train[level]
        q = test_g.shape[0]
        cross_trace = torch.trace(
            torch.mm(test_g, torch.inverse(train_g))
            + torch.mm(train_g, torch.inverse(test_g))
        )
        total += (cross_trace * q / 4 - q * d / 2) * scale
    return total
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt_gram(x, model, optimizer, g_train, loss_weight=(-2, 10, 1), conf_threshold=0.0):
    """Forward and adapt the model on a batch by matching batch-mean Gram matrices.

    A sample is kept when its top softmax probability is ``>= conf_threshold``.
    For each of the first three feature maps returned by
    ``model(x, feature_maps=True)``, the kept samples' Gram matrices are
    averaged and compared with ``g_train[layer]`` using MSE. The weighted sum
    of the three losses is back-propagated and one optimizer step is taken.

    Args:
        x: input batch.
        model: module returning ``(outputs, feature_maps)`` when called with
            ``feature_maps=True``.
        optimizer: optimizer over the parameters being adapted.
        g_train: three reference Gram matrices, one per feature map.
        loss_weight: per-layer loss weights. The default ``(-2, 10, 1)`` is
            the previously hard-coded value.
        conf_threshold: minimum top-class probability for a sample to
            contribute. The default ``0.0`` keeps every sample.

    Returns:
        The model outputs computed before the update. If no sample is kept
        (e.g. an empty batch), they are returned without an update.
    """
    # forward
    outputs, outputs_fp = model(x, feature_maps=True)
    class_prob = nn.Softmax(dim=1)(outputs).max(1)[0]
    # Accumulators follow the feature maps' channel count and device instead
    # of a hard-coded 64/128/256 on "cuda".
    g_test = []
    for layer in range(3):
        channels = outputs_fp[layer].shape[1]
        g_test.append(torch.zeros(channels, channels, device=outputs_fp[layer].device))
    count = 0
    for i in range(outputs.shape[0]):
        if class_prob[i] >= conf_threshold:
            count += 1
            for layer in range(3):
                g_test[layer] += compute_gram_matrix(outputs_fp[layer][i])
    if count == 0:
        # Averaging would divide by zero and the NaN loss would corrupt every
        # adapted parameter; skip the update instead.
        return outputs
    for layer in range(3):
        g_test[layer] /= count
    # adapt
    loss_fn = nn.MSELoss()
    loss = 0
    for layer in range(3):
        loss += loss_weight[layer] * loss_fn(g_test[layer], g_train[layer])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt_gram_meta_train(x, y, loss_weight, lr, model, optimizer, params, g_train, conf_threshold=0.0):
    """Take one Gram-matching adaptation step, then update the per-layer loss weights.

    Steps:
      1. Forward ``x`` with ``feature_maps=True``. Average the Gram matrices
         of the samples whose top softmax probability is ``>= conf_threshold``,
         for each of the first three feature maps.
      2. Compute the per-layer MSE losses against ``g_train`` and their
         gradients w.r.t. ``params``. Step the optimizer on
         ``sum(loss_weight[k] * loss_k)``.
      3. Compute the cross-entropy of the updated model on ``(x, y)`` and its
         gradient w.r.t. ``params``.
      4. Add ``lr[k]`` times the dot product of the layer-``k`` gradient and
         the cross-entropy gradient to ``loss_weight[k]``.

    Args:
        x: input batch.
        y: targets for ``nn.CrossEntropyLoss``.
        loss_weight: mutable sequence of three weights, updated in place.
        lr: three step sizes for the weight update.
        model: module returning ``(outputs, feature_maps)`` when called with
            ``feature_maps=True``.
        optimizer: optimizer over the parameters being adapted.
        params: parameters to differentiate w.r.t. Multi-dimensional tensors
            (e.g. conv weights) are supported.
        g_train: three reference Gram matrices, one per feature map.
        conf_threshold: minimum top-class probability for a sample to
            contribute. The default ``0.0`` keeps every sample.

    Returns:
        ``loss_weight`` (the same object). It is returned unchanged, with no
        model update, if no sample passes the threshold.
    """
    # forward
    outputs, outputs_fp = model(x, feature_maps=True)
    class_prob = nn.Softmax(dim=1)(outputs).max(1)[0]
    # Accumulators follow the feature maps' channel count and device instead
    # of a hard-coded 64/128/256 on "cuda".
    g_test = []
    for k in range(3):
        channels = outputs_fp[k].shape[1]
        g_test.append(torch.zeros(channels, channels, device=outputs_fp[k].device))
    count = 0
    for i in range(outputs.shape[0]):
        if class_prob[i] >= conf_threshold:
            count += 1
            for k in range(3):
                g_test[k] += compute_gram_matrix(outputs_fp[k][i])
    if count == 0:
        # Averaging would divide by zero and push NaNs into both the model
        # parameters and the loss weights.
        return loss_weight
    for k in range(3):
        g_test[k] /= count
    # adapt
    loss_fn = nn.MSELoss()
    loss_list = [loss_fn(g_test[k], g_train[k]) for k in range(3)]
    loss = 0
    for k in range(3):
        loss += loss_weight[k] * loss_list[k]
    # Per-layer gradients must be taken before loss.backward() frees the graph.
    grads = [
        torch.autograd.grad(loss_list[k], params, retain_graph=True, allow_unused=True)
        for k in range(3)
    ]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # meta objective: supervised loss of the adapted model
    outputs_new = model(x)
    meta_loss = nn.CrossEntropyLoss()(outputs_new, y)
    grads_new = torch.autograd.grad(meta_loss, params, allow_unused=True)
    values = [0, 0, 0]
    for k in range(3):
        for g_layer, g_meta in zip(grads[k], grads_new):
            # allow_unused=True yields None for params outside either graph.
            if g_layer is not None and g_meta is not None:
                # torch.dot needs 1-D inputs; flatten so that conv weights
                # and scalar parameters work too.
                values[k] += torch.dot(g_layer.flatten(), g_meta.flatten())
    for k in range(3):
        loss_weight[k] += lr[k] * values[k]
    return loss_weight
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt(x, model, optimizer):
    """Run one Tent step on a batch: forward, minimize mean entropy, update.

    Args:
        x: input batch.
        model: module whose outputs are scored with ``softmax_entropy``.
        optimizer: optimizer over the parameters being adapted.

    Returns:
        The model outputs computed before the optimizer step.
    """
    logits = model(x)
    mean_entropy = softmax_entropy(logits).mean(0)
    mean_entropy.backward()
    optimizer.step()
    optimizer.zero_grad()
    return logits
def collect_params(model):
    """Gather the affine parameters of every ``nn.BatchNorm2d`` in *model*.

    Only parameters named ``weight`` (scale) or ``bias`` (shift) are taken.
    Other parameterizations are possible.

    Returns:
        ``(params, names)``: two parallel lists holding the parameters and
        their qualified names, in ``model.named_modules()`` order.
    """
    pairs = [
        (param, f"{module_name}.{param_name}")
        for module_name, module in model.named_modules()
        if isinstance(module, nn.BatchNorm2d)
        for param_name, param in module.named_parameters()
        if param_name in ['weight', 'bias']
    ]
    params = [param for param, _ in pairs]
    names = [name for _, name in pairs]
    return params, names
def collect_params_full(model):
    """Collect every parameter named ``weight`` or ``bias`` anywhere in *model*.

    Unlike ``collect_params``, this is not restricted to batch norms. Each
    module's own ``weight``/``bias`` is included (conv, linear, norm, ...),
    and parameters with any other name are skipped.

    Returns:
        ``(params, names)``: the parameters and their qualified names, in
        ``model.named_modules()`` order.
    """
    params = []
    names = []
    for nm, m in model.named_modules():
        # recurse=False: each module reports only its own parameters. The
        # recursive listing's extra entries all have dotted names and were
        # filtered out anyway, so this only removes redundant traversal.
        for np, p in m.named_parameters(recurse=False):
            if np in ('weight', 'bias'):  # weight is scale, bias is shift
                params.append(p)
                names.append(f"{nm}.{np}")
    return params, names
def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer state so that adaptation can be undone.

    Returns:
        ``(model_state, optimizer_state)``: deep copies of the two
        ``state_dict()`` results, independent of the live tensors.
    """
    snapshots = tuple(deepcopy(obj.state_dict()) for obj in (model, optimizer))
    return snapshots
def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Put ``model`` and ``optimizer`` back into previously saved states.

    Both objects are modified in place and nothing is returned. The model is
    loaded first with ``strict=True``, so a key mismatch raises before the
    optimizer is touched.
    """
    restore_model = model.load_state_dict
    restore_optimizer = optimizer.load_state_dict
    restore_model(model_state, strict=True)
    restore_optimizer(optimizer_state)
def configure_model(model):
    """Prepare *model* for Tent adaptation and return it.

    The model is put in train mode and all of its parameters are frozen,
    except those of ``nn.BatchNorm2d`` layers. Those layers are made
    trainable, stop tracking running statistics, and have their
    ``running_mean``/``running_var`` buffers set to ``None``. As a result
    they normalize with the current batch's statistics in both train and
    eval mode.
    """
    model.train()
    model.requires_grad_(False)
    batch_norms = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    for bn in batch_norms:
        bn.requires_grad_(True)
        bn.track_running_stats = False
        bn.running_mean = None
        bn.running_var = None
    return model
def configure_model_eval(model):
    """Prepare *model* for adaptation while keeping it in eval mode; return it.

    All parameters are frozen except those of ``nn.BatchNorm2d`` layers,
    which are made trainable. Unlike ``configure_model``, the running
    statistics are left untouched. Batch norms that track them therefore
    keep using them in eval mode, rather than batch statistics.
    """
    model.eval()
    model.requires_grad_(False)
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        module.requires_grad_(True)
    return model
def check_model(model):
    """Check that *model* is configured for Tent adaptation.

    Raises:
        AssertionError: if the model is not in train mode, if no parameter
            or every parameter requires grad, or if it contains no
            ``nn.BatchNorm2d`` layer.
    """
    # Explicit raises instead of ``assert`` so the checks still run under
    # ``python -O``. AssertionError is kept for compatibility with callers.
    if not model.training:
        raise AssertionError("tent needs train mode: call model.train()")
    param_grads = [p.requires_grad for p in model.parameters()]
    if not any(param_grads):
        raise AssertionError("tent needs params to update: check which require grad")
    if all(param_grads):
        raise AssertionError("tent should not update all params: check which require grad")
    if not any(isinstance(m, nn.BatchNorm2d) for m in model.modules()):
        raise AssertionError("tent needs normalization for its optimization")