Skip to content

Commit c3b9089

Browse files
committed
solveJ dense
1 parent 944c5ae commit c3b9089

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

nngeometry/object/pspace.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,14 @@ def solveJ(self, J, regul=1e-8):
164164
"""
165165
solves J = AX in X
166166
"""
167+
J_torch = J.to_torch()
168+
sJ = J_torch.size()
167169
inv_v = torch.linalg.solve(
168170
self.data + regul * torch.eye(self.size(0), device=self.data.device),
169-
J.to_torch()[0].t(),
171+
J_torch.view(-1, sJ[-1]),
172+
left=False,
170173
)
171-
return inv_v
174+
return PFMapDense(generator=self.generator, data=inv_v.reshape(*sJ))
172175

173176
def inverse(self, regul=1e-8):
174177
inv_tensor = torch.inverse(
@@ -667,7 +670,9 @@ def to_torch(self, split_weight_bias=True):
667670
"""
668671
_, diags = self.data
669672
s = self.generator.layer_collection.numel()
670-
M = torch.zeros((s, s), device=self.generator.get_device(), dtype=self.generator.get_dtype())
673+
M = torch.zeros(
674+
(s, s), device=self.generator.get_device(), dtype=self.generator.get_dtype()
675+
)
671676
KFE_layers = self.get_KFE(split_weight_bias=split_weight_bias)
672677
for layer_id, _ in self.generator.layer_collection.layers.items():
673678
diag = diags[layer_id]

tests/test_torch_hooks.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -341,19 +341,33 @@ def test_jacobian_pdense():
341341
# Test solve
342342
# NB: regul is high since the matrix is not full rank
343343
regul = 1e-3
344-
Mv_regul = torch.mv(
345-
PMat_dense.to_torch()
346-
+ regul * torch.eye(PMat_dense.size(0), device=device),
347-
dw.to_torch(),
344+
Mv_torch = torch.mv(PMat_dense.to_torch(), dw.to_torch())
345+
Mv_regul = PVector(
346+
layer_collection=lc, vector_repr=Mv_torch + regul * dw.to_torch()
348347
)
349-
Mv_regul = PVector(layer_collection=lc, vector_repr=Mv_regul)
350-
dw_using_inv = PMat_dense.solve(Mv_regul, regul=regul)
348+
dw_solve = PMat_dense.solve(Mv_regul, regul=regul)
351349
check_tensors(
352350
dw.to_torch(),
353-
dw_using_inv.to_torch(),
351+
dw_solve.to_torch(),
354352
eps=5e-3,
355353
)
356354

355+
# Test solve with jacobian
356+
# TODO improve
357+
c = 1.678
358+
stacked_mv = torch.stack((Mv_torch, c * Mv_torch)).unsqueeze(0)
359+
stacked_v = torch.stack((dw.to_torch(), c * dw.to_torch())).unsqueeze(0)
360+
jaco = PFMapDense(
361+
generator=generator,
362+
data=stacked_mv + regul * stacked_v,
363+
)
364+
J_back = PMat_dense.solveJ(jaco, regul=regul)
365+
366+
check_tensors(
367+
stacked_v,
368+
J_back.to_torch(),
369+
)
370+
357371
# Test inv
358372
PMat_inv = PMat_dense.inverse(regul=regul)
359373
check_tensors(

tests/test_torch_hooks_ekfac.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def test_pspace_ekfac_vs_direct():
108108
check_tensors(v.to_torch(), v_back.to_torch())
109109

110110
# Test solve with jacobian
111+
# TODO improve
111112
c = 1.678
112113
stacked_mv = torch.stack(
113114
(mv_ekfac.to_torch(), c * mv_ekfac.to_torch())

0 commit comments

Comments
 (0)