Skip to content

Commit 4d58afc

Browse files
committed
Default diagonal impl via unit vector mults
1 parent b4a6195 commit 4d58afc

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/probnum/linops/_linear_operator.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -738,19 +738,7 @@ def _trace(self) -> np.number:
738738
trace : float
739739
Trace of the linear operator.
740740
"""
741-
742-
vec = np.zeros(self.shape[1], dtype=self.dtype)
743-
744-
vec[0] = 1
745-
trace = (self @ vec)[0]
746-
vec[0] = 0
747-
748-
for i in range(1, self.shape[0]):
749-
vec[i] = 1
750-
trace += (self @ vec)[i]
751-
vec[i] = 0
752-
753-
return trace
741+
return np.sum(self.diagonal())
754742

755743
def trace(self) -> np.number:
756744
r"""Trace of the linear operator.
@@ -783,7 +771,16 @@ def _diagonal(self) -> np.ndarray:
783771
784772
You may implement this method in a subclass.
785773
"""
786-
return np.diagonal(self.todense(cache=False))
774+
D = np.min(self.shape)
775+
diag = np.zeros(D, dtype=self.dtype)
776+
vec = np.zeros(self.shape[1], dtype=self.dtype)
777+
778+
for i in range(D):
779+
vec[i] = 1
780+
diag[i] = (self @ vec)[i]
781+
vec[i] = 0
782+
783+
return diag
787784

788785
def diagonal(self) -> np.ndarray:
789786
"""Diagonal of the linear operator."""
@@ -1618,20 +1615,23 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]):
16181615
matmul = LinearOperator.broadcast_matmat(lambda x: self.A @ x)
16191616
todense = self.A.toarray
16201617
trace = lambda: self.A.diagonal().sum()
1618+
diagonal = lambda: self.A.diagonal()
16211619
else:
16221620
self.A = np.asarray(A)
16231621
self.A.setflags(write=False)
16241622

16251623
matmul = lambda x: self.A @ x
16261624
todense = lambda: self.A
16271625
trace = lambda: np.trace(self.A)
1626+
diagonal = lambda: np.diagonal(self.A)
16281627

16291628
super().__init__(
16301629
self.A.shape,
16311630
self.A.dtype,
16321631
matmul=matmul,
16331632
todense=todense,
16341633
trace=trace,
1634+
diagonal=diagonal,
16351635
)
16361636

16371637
def _transpose(self) -> "Matrix":

0 commit comments

Comments
 (0)