Skip to content

Commit 3f80422

Browse files
authored
Fix Tensor.init(gradient=True, δx=δx) (#123)
* Alternative `init` method of Tensor for gvp's * Update _tensor.py * Update _tensor.py
1 parent 7158045 commit 3f80422

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

src/tensortrax/_tensor.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,11 @@ def init(self, gradient=False, hessian=False, sym=False, δx=None, Δx=None):
255255
δx = np.zeros(self.size**2).reshape(shape)
256256
else:
257257
δx = np.eye(self.size).reshape(shape)
258-
Δx = δx.copy()
258+
259+
else:
260+
δx = δx.reshape(*self.shape, *self.trax)
261+
262+
Δx = δx.copy()
259263

260264
elif hessian:
261265
# add additional trailing axes for dual values
@@ -454,7 +458,7 @@ def ravel(self, order="C"):
454458

455459
def reshape(self, *shape, order="C"):
456460
"Gives a new shape to an array without changing its data."
457-
return reshape(self, newshape=shape, order=order)
461+
return reshape(self, shape, order=order)
458462

459463
def squeeze(self, axis=None):
460464
"Remove axes of length one."
@@ -572,21 +576,21 @@ def squeeze(A, axis=None):
572576
return np.squeeze(A, axis=axis)
573577

574578

575-
def reshape(A, newshape, order="C"):
579+
def reshape(A, shape, order="C"):
576580
"Gives a new shape to an array without changing its data."
577581
if isinstance(A, Tensor):
578582
δtrax = δ(A).shape[len(A.shape) :]
579583
Δtrax = Δ(A).shape[len(A.shape) :]
580584
Δδtrax = Δδ(A).shape[len(A.shape) :]
581585
return Tensor(
582-
x=f(A).reshape(*newshape, *A.trax, order=order),
583-
δx=δ(A).reshape(*newshape, *δtrax, order=order),
584-
Δx=Δ(A).reshape(*newshape, *Δtrax, order=order),
585-
Δδx=Δδ(A).reshape(*newshape, *Δδtrax, order=order),
586+
x=f(A).reshape(*shape, *A.trax, order=order),
587+
δx=δ(A).reshape(*shape, *δtrax, order=order),
588+
Δx=Δ(A).reshape(*shape, *Δtrax, order=order),
589+
Δδx=Δδ(A).reshape(*shape, *Δδtrax, order=order),
586590
ntrax=A.ntrax,
587591
)
588592
else:
589-
return np.reshape(A, newshape=newshape, order=order)
593+
return np.reshape(A, shape, order=order)
590594

591595

592596
def einsum4(subscripts, *operands):

tests/test_hessian_vector_product.py

+12
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,17 @@ def test_hvp():
5050
assert not np.any(np.isnan(hvp))
5151

5252

53+
def test_gvp():
54+
F = np.diag([2, 1.15, 1.15])
55+
F[:, 0] += np.array([0, -0.15, -0.15])
56+
δF = np.vstack([np.zeros(3), np.zeros(3), np.array([-0.04, 0.04, 0.04])])
57+
58+
δψ = tr.gradient_vector_product(neo_hooke)(F, δx=δF)
59+
δψ_reference = tm.special.ddot(tr.gradient(neo_hooke)(F), δF)
60+
61+
assert np.isclose(δψ, δψ_reference)
62+
63+
5364
if __name__ == "__main__":
5465
test_hvp()
66+
test_gvp()

0 commit comments

Comments
 (0)