@@ -255,7 +255,11 @@ def init(self, gradient=False, hessian=False, sym=False, δx=None, Δx=None):
255
255
δx = np .zeros (self .size ** 2 ).reshape (shape )
256
256
else :
257
257
δx = np .eye (self .size ).reshape (shape )
258
- Δx = δx .copy ()
258
+
259
+ else :
260
+ δx = δx .reshape (* self .shape , * self .trax )
261
+
262
+ Δx = δx .copy ()
259
263
260
264
elif hessian :
261
265
# add additional trailing axes for dual values
@@ -454,7 +458,7 @@ def ravel(self, order="C"):
454
458
455
459
def reshape (self , * shape , order = "C" ):
456
460
"Gives a new shape to an array without changing its data."
457
- return reshape (self , newshape = shape , order = order )
461
+ return reshape (self , shape , order = order )
458
462
459
463
def squeeze (self , axis = None ):
460
464
"Remove axes of length one."
@@ -572,21 +576,21 @@ def squeeze(A, axis=None):
572
576
return np .squeeze (A , axis = axis )
573
577
574
578
575
- def reshape (A , newshape , order = "C" ):
579
+ def reshape (A , shape , order = "C" ):
576
580
"Gives a new shape to an array without changing its data."
577
581
if isinstance (A , Tensor ):
578
582
δtrax = δ (A ).shape [len (A .shape ) :]
579
583
Δtrax = Δ (A ).shape [len (A .shape ) :]
580
584
Δδtrax = Δδ (A ).shape [len (A .shape ) :]
581
585
return Tensor (
582
- x = f (A ).reshape (* newshape , * A .trax , order = order ),
583
- δx = δ (A ).reshape (* newshape , * δtrax , order = order ),
584
- Δx = Δ (A ).reshape (* newshape , * Δtrax , order = order ),
585
- Δδx = Δδ (A ).reshape (* newshape , * Δδtrax , order = order ),
586
+ x = f (A ).reshape (* shape , * A .trax , order = order ),
587
+ δx = δ (A ).reshape (* shape , * δtrax , order = order ),
588
+ Δx = Δ (A ).reshape (* shape , * Δtrax , order = order ),
589
+ Δδx = Δδ (A ).reshape (* shape , * Δδtrax , order = order ),
586
590
ntrax = A .ntrax ,
587
591
)
588
592
else :
589
- return np .reshape (A , newshape = newshape , order = order )
593
+ return np .reshape (A , shape , order = order )
590
594
591
595
592
596
def einsum4 (subscripts , * operands ):
0 commit comments