Skip to content

Commit 60b046e

Browse files
committed
Alternative init method of Tensor for gvp's
1 parent 7158045 commit 60b046e

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

src/tensortrax/_tensor.py

+30-12
Original file line numberDiff line numberDiff line change
@@ -242,20 +242,38 @@ def init(self, gradient=False, hessian=False, sym=False, δx=None, Δx=None):
242242
"""Re-Initialize tensor with dual values to keep track of the
243243
hessian and/or the gradient."""
244244

245-
if gradient and not hessian:
246-
# add additional element-wise acting axes for dual values
245+
# if gradient and not hessian:
246+
# # add additional element-wise acting axes for dual values
247+
# self._add(ndual=len(self.shape))
248+
249+
# # create the dual values
250+
# if δx is None or isinstance(δx, bool):
251+
# shape = (*self.shape, *self.shape)
252+
# if len(shape) == 0:
253+
# shape = (1,)
254+
# if δx is False:
255+
# δx = np.zeros(self.size**2).reshape(shape)
256+
# else:
257+
# δx = np.eye(self.size).reshape(shape)
258+
# Δx = δx.copy()
259+
260+
if gradient:
261+
# add additional trailing axes for dual values
247262
self._add(ndual=len(self.shape))
248263

249264
# create the dual values
265+
ones = np.ones(len(self.shape), dtype=int)
266+
250267
if δx is None or isinstance(δx, bool):
251-
shape = (*self.shape, *self.shape)
268+
shape = (*self.shape, *ones)
252269
if len(shape) == 0:
253270
shape = (1,)
254271
if δx is False:
255-
δx = np.zeros(self.size**2).reshape(shape)
272+
δx = np.zeros(self.size).reshape(shape)
256273
else:
257274
δx = np.eye(self.size).reshape(shape)
258-
Δx = δx.copy()
275+
else:
276+
δx = δx.reshape(*self.shape, *self.trax)
259277

260278
elif hessian:
261279
# add additional trailing axes for dual values
@@ -454,7 +472,7 @@ def ravel(self, order="C"):
454472

455473
def reshape(self, *shape, order="C"):
456474
"Gives a new shape to an array without changing its data."
457-
return reshape(self, newshape=shape, order=order)
475+
return reshape(self, shape, order=order)
458476

459477
def squeeze(self, axis=None):
460478
"Remove axes of length one."
@@ -572,21 +590,21 @@ def squeeze(A, axis=None):
572590
return np.squeeze(A, axis=axis)
573591

574592

575-
def reshape(A, newshape, order="C"):
593+
def reshape(A, shape, order="C"):
576594
"Gives a new shape to an array without changing its data."
577595
if isinstance(A, Tensor):
578596
δtrax = δ(A).shape[len(A.shape) :]
579597
Δtrax = Δ(A).shape[len(A.shape) :]
580598
Δδtrax = Δδ(A).shape[len(A.shape) :]
581599
return Tensor(
582-
x=f(A).reshape(*newshape, *A.trax, order=order),
583-
δx=δ(A).reshape(*newshape, *δtrax, order=order),
584-
Δx=Δ(A).reshape(*newshape, *Δtrax, order=order),
585-
Δδx=Δδ(A).reshape(*newshape, *Δδtrax, order=order),
600+
x=f(A).reshape(*shape, *A.trax, order=order),
601+
δx=δ(A).reshape(*shape, *δtrax, order=order),
602+
Δx=Δ(A).reshape(*shape, *Δtrax, order=order),
603+
Δδx=Δδ(A).reshape(*shape, *Δδtrax, order=order),
586604
ntrax=A.ntrax,
587605
)
588606
else:
589-
return np.reshape(A, newshape=newshape, order=order)
607+
return np.reshape(A, shape, order=order)
590608

591609

592610
def einsum4(subscripts, *operands):

tests/test_hessian_vector_product.py

+12
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,17 @@ def test_hvp():
5050
assert not np.any(np.isnan(hvp))
5151

5252

53+
def test_gvp():
54+
F = np.diag([2, 1.15, 1.15])
55+
F[:, 0] += np.array([0, -0.15, -0.15])
56+
δF = np.vstack([np.zeros(3), np.zeros(3), np.array([-0.04, 0.04, 0.04])])
57+
58+
δψ = tr.gradient_vector_product(neo_hooke)(F, δx=δF)
59+
δψ_reference = tm.special.ddot(tr.gradient(neo_hooke)(F), δF)
60+
61+
assert np.isclose(δψ, δψ_reference)
62+
63+
5364
if __name__ == "__main__":
5465
test_hvp()
66+
test_gvp()

0 commit comments

Comments
 (0)