import sys
from collections import OrderedDict
import torch
from torch.autograd import Function
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import numpy


def conjugate_gradient_block(A, B, x0=None, tol=1e-2, maxit=None, eps=1e-6):
    """
    Solve the linear system A X = B using conjugate gradient, where
    - A is an abstract linear operator implementing an n*n matrix
    - B is a matrix right-hand side of size n*s
    - x0 is an initial guess of size n*s
    Essentially, this runs #s classical conjugate gradient algorithms in 'parallel',
    and terminates when the worst of the #s residuals is below tol.
    """
    X = x0
    R = B - A(X)
    P = R
    Rs_old = torch.norm(R, 2., dim=0) ** 2.
    tol_scale = torch.mean(Rs_old)
    k = 0
    while True:
        k += 1
        AP = A(P)
        alpha = Rs_old / (torch.sum(P * AP, dim=0) + eps)
        X += P * alpha
        if k == maxit:
            break
        R -= AP * alpha
        Rs_new = torch.norm(R, 2., dim=0) ** 2.
        res = torch.max(Rs_new)
        if res < (tol ** 2.) * tol_scale:
            break
        P = R + P * (Rs_new / (Rs_old + eps))
        Rs_old = Rs_new
    return X, k
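

# A minimal sketch (not part of the original file) of how the block CG solver
# above can be exercised on its own: a dense symmetric positive definite
# matrix is wrapped as the abstract linear operator and solved against a
# matrix right-hand side. All names and shapes below are illustrative only.
def _example_conjugate_gradient_block():
    torch.manual_seed(0)
    M = torch.randn(5, 5)
    A_dense = M.t().mm(M) + 5. * torch.eye(5)   # SPD system matrix
    B = torch.randn(5, 3)                       # three right-hand sides, solved 'in parallel'
    X, iters = conjugate_gradient_block(lambda v: A_dense.mm(v), B,
                                        x0=torch.zeros(5, 3), maxit=50)
    print('residual:', torch.norm(A_dense.mm(X) - B).item(), 'after', iters, 'iterations')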


def optimization_step(A, Y, Z, beta, apply_cg, mode='prox_cg1'):
    """
    Optimization step for several different modes:
    - 'prox_exact' takes an exact proximal step
    - 'prox_cgN' takes an approximate proximal step, generated by N conjugate gradient steps
    - 'gradient' performs a gradient step; this method recovers classical SGD
    Taking a proximal step amounts to solving:
        argmin_{X} 1/2 ||AX - Y||^2 + beta/2 ||X - Z||^2,
    i.e. solving a linear system (exactly or approximately).
    The method argument 'apply_cg' is the one taken from the respective ProxProp module.
    """
    apply_A = partial(apply_cg, A)
    if mode == 'prox_exact':
        AA = A.t().mm(A)
        I = torch.eye(A.size(1)).type_as(A)
        A_tilde = AA + beta * I
        b_tilde = A.t().mm(Y) + beta * Z
        X, _ = torch.gesv(b_tilde, A_tilde)
    elif mode[:7] == 'prox_cg':
        num_prox_steps = int(mode[7:])
        apply_A_tilde = lambda x: apply_A(x, 3) + beta * x
        b_tilde = apply_A(Y, 2) + beta * Z
        res = conjugate_gradient_block(apply_A_tilde, b_tilde, x0=Z, maxit=num_prox_steps)
        X = res[0]
    elif mode == 'gradient':
        X = Z - (apply_A(apply_A(Z, 1) - Y, 2))
    else:
        raise ValueError('The optimization mode "{}" you have specified is not valid.'.format(mode))
    return X
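

# A hedged sketch (not part of the original file) of the operator contract
# that 'apply_cg' has to fulfil, here satisfied by a plain dense matrix:
# mode 1 returns Ax, mode 2 returns A^Tx and mode 3 returns A^T(Ax). With
# such a callable, the CG-based proximal modes and the 'gradient' mode of
# optimization_step can be exercised directly. Names and shapes are
# illustrative only; Z is cloned because the CG solver updates x0 in place.
def _example_optimization_step():
    torch.manual_seed(0)
    A = torch.randn(8, 4)      # 'data' matrix: batch of 8 samples, 4 features
    Y = torch.randn(8, 3)      # updated activations of the layer above
    Z = torch.randn(4, 3)      # current parameters in CG shape

    def dense_apply_cg(A_mat, x, mode):
        if mode == 1:
            return A_mat.mm(x)
        elif mode == 2:
            return A_mat.t().mm(x)
        elif mode == 3:
            return A_mat.t().mm(A_mat.mm(x))

    X_prox = optimization_step(A, Y, Z.clone(), 2., dense_apply_cg, mode='prox_cg3')
    X_grad = optimization_step(A, Y, Z.clone(), 2., dense_apply_cg, mode='gradient')
    print(X_prox.size(), X_grad.size())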


class ForwardBackwardFunctional(Function):
    """
    Generic forward/backward functional that is used in any ProxProp module.
    """
    @staticmethod
    def forward(ctx, *args):
        ctx.optimization_layer = args[-1]
        output = ctx.optimization_layer.apply_forward(args[0])
        ctx.save_for_backward(*args[:-1], output)
        return output

    @staticmethod
    def backward(ctx, grad_z):
        input = ctx.saved_variables[0]
        params = list(ctx.saved_variables[1:-1])
        output = ctx.saved_variables[-1]
        grad_input = None
        grad_params = [None] * len(params)
        layer = ctx.optimization_layer
        # explicit gradient step on z
        z_updated = output - grad_z
        # prox step or gradient step on the network parameters
        if layer.optimization_mode == 'prox_exact':
            A, Y, Z = layer.to_exact_solve_shape(input, z_updated, *params)
        else:
            A = input.detach()
            Y = z_updated.detach()
            Z = layer.to_cg_shape(params).detach()
        X_tensor = optimization_step(A, Y, Z, 1. / layer.tau_prox, layer.apply_cg, mode=layer.optimization_mode)
        if 'prox_exact' == layer.optimization_mode:
            params_updated = list(layer.from_exact_solve_shape(X_tensor).values())
        else:
            params_updated = list(layer.from_cg_shape(X_tensor).values())
        # write difference in grad fields
        grad_params = [x[0] - x[1] for x in zip(params, params_updated)]
        # explicit gradient step on a
        input.requires_grad_()
        with torch.enable_grad():
            out_temp = ctx.optimization_layer.apply_forward(input)
            grad_temp = torch.autograd.grad(out_temp, input, grad_z)
        grad_input = grad_temp[0]
        return tuple([grad_input] + grad_params + [None])


def proxprop_module_generator(BaseModule, to_cg_shape=None, from_cg_shape=None, to_exact_solve_shape=None, from_exact_solve_shape=None):
    """
    Creates a ProxProp module for any given linear BaseModule.
    The methods to_cg_shape and from_cg_shape convert to the conjugate gradient format and back.
    By default, they flatten all variables to a vector. One may want to provide specific functions
    that exploit the special structure of the layer for better (block) conjugate gradient performance.
    The consistency of the implementation is automatically checked during initialization.
    The methods to_exact_solve_shape and from_exact_solve_shape cannot be provided by default, but are needed to take an exact proximal step.
    You can test your implementation for shape consistency by calling test_exact_solve_reshaping on a test input batch.
    The generated ProxProp module uses the BaseModule's forward method and can therefore leverage existing, efficient implementations.
    """
    class ProxPropModule(BaseModule):
        def __init__(self, *args, **kwargs):
            if 'tau_prox' in kwargs:
                tau_prox_arg = [kwargs['tau_prox']]
                del kwargs['tau_prox']
            else:
                tau_prox_arg = [1.]
            if 'optimization_mode' in kwargs:
                self.optimization_mode = kwargs['optimization_mode']
                del kwargs['optimization_mode']
            else:
                self.optimization_mode = 'prox_cg1'
            super().__init__(*args, **kwargs)
            self.register_buffer('tau_prox', torch.Tensor(tau_prox_arg))
            self.forward_backward_functional = ForwardBackwardFunctional
            self._test_cg_reshaping()

        def _compare_two_params_dicts(self, d1, d2):
            assert len(d1.items()) == len(d2.items())
            for name, p1 in d1.items():
                p2 = d2[name]
                p1_np = p1.detach().cpu().numpy()
                p2_np = p2.detach().cpu().numpy()
                assert numpy.allclose(p1_np, p2_np)

        def _test_cg_reshaping(self):
            named_params_check = self.from_cg_shape(self.to_cg_shape([p for p in self.parameters()]))
            module_named_params = dict(self.named_parameters())
            self._compare_two_params_dicts(module_named_params, named_params_check)

        def test_exact_solve_reshaping(self, x):
            """
            Checks an implementation of to/from_exact_solve_shape and can be called with an
            input variable of type torch.autograd.Variable of the layer's input shape (including batch dimension).
            """
            try:
                y = self.apply_forward(x)
                params = list(self.parameters())
                A, Y, Z = self.to_exact_solve_shape(x, y, *params)
                named_params_check = self.from_exact_solve_shape(Z)
                module_named_params = dict(self.named_parameters())
                self._compare_two_params_dicts(module_named_params, named_params_check)
                return True
            except NotImplementedError:
                print('At least one of the exact solve reshaping methods is not implemented.')
                return False
            except Exception as e:
                print(e)
                return False

        def forward(self, input):
            args = [input] + list(self.parameters()) + [self]
            return self.forward_backward_functional.apply(*args)

        def apply_forward(self, x):
            return super().forward(x)

        def apply_adjoint(self, forward_out, x):
            forward_out.backward(x)

        def to_cg_shape(self, params_list):
            """
            Default implementation. Flattens all parameters to a vector.
            Expects the module's parameter data tensors in a list as provided by [p.data for p in self.parameters()].
            Returns a tensor with the dimensions expected by the conjugate gradient solver.
            """
            return torch.cat([p.view(-1) for p in params_list])

        def from_cg_shape(self, x):
            """
            Default implementation. Assumes flattened parameters and reshapes to the module's parameter shapes.
            Expects a tensor with the shape used by the conjugate gradient solver.
            Returns an OrderedDict containing the module's parameters in their native shape as
            nn.Parameter() objects.
            """
            prev_var_counter = 0
            params_cg = OrderedDict()
            for name, p in self.named_parameters():
                n = p.numel()
                var_counter = prev_var_counter + n
                p_cg = x[prev_var_counter:var_counter].view(p.size())
                params_cg[name] = torch.nn.Parameter(p_cg)
                prev_var_counter = var_counter
            return params_cg

        def to_exact_solve_shape(self, x, y, *params):
            """
            Needs to be implemented for the exact solve of the proximal step.
            Prepares the tensors to solve the proximal step in the form argmin_{X} 1/2 ||AX - Y||^2 + beta/2 ||X - Z||^2,
            where X are the updated parameters to solve for.
            Expects the current input data x, the already updated non-linear activations y from the layer above, and an
            argument list of module parameters *params.
            Returns (A, Y, Z), where Y are the already updated non-linear activations from the layer above in the right shape.
            """
            raise NotImplementedError

        def from_exact_solve_shape(self, exact_solve_out):
            """
            Needs to be implemented for the exact solve of the proximal step.
            Reshapes the output of the exact solve method to the native parameter shape; the parameters are returned
            as nn.Parameter() objects stored in an OrderedDict.
            """
            raise NotImplementedError

        def apply_cg(self, A, x, mode):
            """
            Abstract linear operator used for the conjugate gradient solver.
            Returns Ax for mode=1, A^Tx for mode=2 and A^T(Ax) for mode=3.
            This method uses the efficient forward implementation of the BaseModule.
            The tradeoff for this generic implementation is that we have to assign temporary values to the module's parameters.
            """
            if mode == 1:
                params_backup = self._parameters
                self._parameters = self.from_cg_shape(x)
                with torch.enable_grad():
                    output = self.apply_forward(A)
                self._parameters = params_backup
                return output
            elif mode == 2:
                self.zero_grad()
                with torch.enable_grad():
                    self.apply_adjoint(self.apply_forward(A), x)
                result = self.to_cg_shape([p.grad for p in self.parameters()])
                self.zero_grad()
                return result
            elif mode == 3:
                params_backup = self._parameters
                self._parameters = self.from_cg_shape(x)
                self.zero_grad()
                with torch.enable_grad():
                    forward_out = self.apply_forward(A.requires_grad_())
                    self.apply_adjoint(forward_out, forward_out)
                result = self.to_cg_shape([p.grad for p in self.parameters()])
                self.zero_grad()
                self._parameters = params_backup
                return result
            else:
                raise ValueError('Mode {} is not valid. Provide 1 for Ax, 2 for A^Tx, or 3 for A^T(Ax).'.format(mode))

    ProxPropModule.__name__ = 'ProxProp_{}'.format(BaseModule.__name__)
    proxprop_module = ProxPropModule
    if to_cg_shape is not None:
        setattr(proxprop_module, 'to_cg_shape', to_cg_shape)
    if from_cg_shape is not None:
        setattr(proxprop_module, 'from_cg_shape', from_cg_shape)
    if to_exact_solve_shape is not None:
        setattr(proxprop_module, 'to_exact_solve_shape', to_exact_solve_shape)
    if from_exact_solve_shape is not None:
        setattr(proxprop_module, 'from_exact_solve_shape', from_exact_solve_shape)
    return proxprop_module

# generate ProxProp Conv2d module
ProxPropConv2d = proxprop_module_generator(nn.Conv2d)
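
# A hedged sketch (not part of the original file): the generated Conv2d module
# relies on the default flattened CG reshaping, whose consistency is verified
# automatically in __init__, while the exact-solve reshaping is left
# unimplemented, so the optional check below is expected to return False.
# All shapes are illustrative only.
def _example_conv2d_reshaping_checks():
    conv = ProxPropConv2d(3, 8, kernel_size=3)        # default CG reshaping is checked here
    ok = conv.test_exact_solve_reshaping(torch.randn(2, 3, 16, 16))
    print('exact solve reshaping implemented:', ok)   # False for the generic Conv2d module
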
# generate ProxProp Linear module
def linear_to_cg_shape(self, params_list):
    """
    Reshape to use the matrix version of the conjugate gradient solver.
    """
    W = params_list[0]
    b = params_list[1]
    return torch.cat((W.t(), torch.unsqueeze(b, 0).type_as(W)), 0)

def linear_from_cg_shape(self, x):
    """
    Reshape the output of the matrix version of the conjugate gradient solver.
    """
    params_cg = OrderedDict()
    if self.bias is None:
        params_cg['weight'] = torch.nn.Parameter(x)
    else:
        W, b = torch.split(x, self.weight.size(1), dim=0)
        params_cg['weight'] = torch.nn.Parameter(W.t())
        params_cg['bias'] = torch.nn.Parameter(torch.squeeze(b))
    return params_cg

def linear_to_exact_solve_shape(self, x, z_updated, W, b):
    """
    Suitable reshape for the exact solve method.
    """
    Z = torch.cat((W.t(), torch.unsqueeze(b, 0).type_as(W)), 0)
    A = torch.cat((x, torch.ones(x.size(0), 1).type_as(x)), 1)
    return A, z_updated, Z

def linear_from_exact_solve_shape(self, exact_solve_out):
    """
    Reshape from the exact solve method.
    """
    params_out = OrderedDict()
    if self.bias is None:
        params_out['weight'] = nn.Parameter(exact_solve_out.t())
    else:
        W, b = torch.split(exact_solve_out, self.weight.size(1), dim=0)
        params_out['weight'] = nn.Parameter(W.t())
        params_out['bias'] = nn.Parameter(b.squeeze())
    return params_out

ProxPropLinear = proxprop_module_generator(nn.Linear,
                                           to_cg_shape=linear_to_cg_shape,
                                           from_cg_shape=linear_from_cg_shape,
                                           to_exact_solve_shape=linear_to_exact_solve_shape,
                                           from_exact_solve_shape=linear_from_exact_solve_shape)
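

# A hedged usage sketch (not part of the original file), assuming the PyTorch
# version this module was written for (it relies on e.g. ctx.saved_variables
# and torch.gesv): ProxProp modules are drop-in replacements for their nn
# counterparts; 'tau_prox' and 'optimization_mode' are consumed by the
# ProxProp wrapper and all remaining arguments are forwarded to the base
# module. The layer sizes and batch shapes below are illustrative only.
if __name__ == '__main__':
    torch.manual_seed(0)
    layer = ProxPropLinear(10, 5, tau_prox=1., optimization_mode='prox_cg1')
    x = torch.randn(16, 10)
    target = torch.randn(16, 5)
    loss = F.mse_loss(layer(x), target)
    loss.backward()
    # After backward, the .grad fields hold the ProxProp update direction
    # (old parameters minus proximally updated parameters), so any standard
    # optimizer can be applied on top.
    print(layer.weight.grad.size(), layer.bias.grad.size())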