@@ -242,20 +242,38 @@ def init(self, gradient=False, hessian=False, sym=False, δx=None, Δx=None):
242
242
"""Re-Initialize tensor with dual values to keep track of the
243
243
hessian and/or the gradient."""
244
244
245
- if gradient and not hessian :
246
- # add additional element-wise acting axes for dual values
245
+ # if gradient and not hessian:
246
+ # # add additional element-wise acting axes for dual values
247
+ # self._add(ndual=len(self.shape))
248
+
249
+ # # create the dual values
250
+ # if δx is None or isinstance(δx, bool):
251
+ # shape = (*self.shape, *self.shape)
252
+ # if len(shape) == 0:
253
+ # shape = (1,)
254
+ # if δx is False:
255
+ # δx = np.zeros(self.size**2).reshape(shape)
256
+ # else:
257
+ # δx = np.eye(self.size).reshape(shape)
258
+ # Δx = δx.copy()
259
+
260
+ if gradient :
261
+ # add additional trailing axes for dual values
247
262
self ._add (ndual = len (self .shape ))
248
263
249
264
# create the dual values
265
+ ones = np .ones (len (self .shape ), dtype = int )
266
+
250
267
if δx is None or isinstance (δx , bool ):
251
- shape = (* self .shape , * self . shape )
268
+ shape = (* self .shape , * ones )
252
269
if len (shape ) == 0 :
253
270
shape = (1 ,)
254
271
if δx is False :
255
- δx = np .zeros (self .size ** 2 ).reshape (shape )
272
+ δx = np .zeros (self .size ).reshape (shape )
256
273
else :
257
274
δx = np .eye (self .size ).reshape (shape )
258
- Δx = δx .copy ()
275
+ else :
276
+ δx = δx .reshape (* self .shape , * self .trax )
259
277
260
278
elif hessian :
261
279
# add additional trailing axes for dual values
@@ -454,7 +472,7 @@ def ravel(self, order="C"):
454
472
455
473
def reshape (self , * shape , order = "C" ):
456
474
"Gives a new shape to an array without changing its data."
457
- return reshape (self , newshape = shape , order = order )
475
+ return reshape (self , shape , order = order )
458
476
459
477
def squeeze (self , axis = None ):
460
478
"Remove axes of length one."
@@ -572,21 +590,21 @@ def squeeze(A, axis=None):
572
590
return np .squeeze (A , axis = axis )
573
591
574
592
575
- def reshape (A , newshape , order = "C" ):
593
+ def reshape (A , shape , order = "C" ):
576
594
"Gives a new shape to an array without changing its data."
577
595
if isinstance (A , Tensor ):
578
596
δtrax = δ (A ).shape [len (A .shape ) :]
579
597
Δtrax = Δ (A ).shape [len (A .shape ) :]
580
598
Δδtrax = Δδ (A ).shape [len (A .shape ) :]
581
599
return Tensor (
582
- x = f (A ).reshape (* newshape , * A .trax , order = order ),
583
- δx = δ (A ).reshape (* newshape , * δtrax , order = order ),
584
- Δx = Δ (A ).reshape (* newshape , * Δtrax , order = order ),
585
- Δδx = Δδ (A ).reshape (* newshape , * Δδtrax , order = order ),
600
+ x = f (A ).reshape (* shape , * A .trax , order = order ),
601
+ δx = δ (A ).reshape (* shape , * δtrax , order = order ),
602
+ Δx = Δ (A ).reshape (* shape , * Δtrax , order = order ),
603
+ Δδx = Δδ (A ).reshape (* shape , * Δδtrax , order = order ),
586
604
ntrax = A .ntrax ,
587
605
)
588
606
else :
589
- return np .reshape (A , newshape = newshape , order = order )
607
+ return np .reshape (A , shape , order = order )
590
608
591
609
592
610
def einsum4 (subscripts , * operands ):
0 commit comments