Skip to content

Commit 5f5ba7b

Browse files
committed
Override linop instead of _evaluate_linop
In arithmetic fallback covariance functions, override `linop` instead of `_evaluate_linop`, because the latter has the problem that the input preprocessing destroys attributes of np.array subclasses. Keeping the option of np.array subclassing can be very useful for certain linop implementations, e.g. tensor product covariance functions.
1 parent 2ec621b commit 5f5ba7b

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

src/probnum/randprocs/covfuncs/_arithmetic_fallbacks.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def __init__(self, covfunc: CovarianceFunction, scalar: ScalarLike):
5555
def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray] = None) -> np.ndarray:
5656
return self._scalar * self._covfunc(x0, x1)
5757

58-
def _evaluate_linop(
59-
self, x0: np.ndarray, x1: Optional[np.ndarray]
58+
def linop(
59+
self, x0: utils.ArrayLike, x1: Optional[utils.ArrayLike] = None
6060
) -> linops.LinearOperator:
6161
return self._scalar * self._covfunc.linop(x0, x1)
6262

@@ -82,7 +82,6 @@ class SumCovarianceFunction(CovarianceFunction):
8282
"""
8383

8484
def __init__(self, *summands: CovarianceFunction):
85-
8685
if not all(
8786
(summand.input_shape == summands[0].input_shape)
8887
and (summand.output_shape_0 == summands[0].output_shape_0)
@@ -104,8 +103,8 @@ def _evaluate(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> np.ndarray:
104103
operator.add, (summand(x0, x1) for summand in self._summands)
105104
)
106105

107-
def _evaluate_linop(
108-
self, x0: np.ndarray, x1: Optional[np.ndarray]
106+
def linop(
107+
self, x0: utils.ArrayLike, x1: Optional[utils.ArrayLike] = None
109108
) -> linops.LinearOperator:
110109
return functools.reduce(
111110
operator.add, (summand.linop(x0, x1) for summand in self._summands)
@@ -151,7 +150,6 @@ class ProductCovarianceFunction(CovarianceFunction):
151150
"""
152151

153152
def __init__(self, *factors: CovarianceFunction):
154-
155153
if not all(
156154
(factor.input_shape == factors[0].input_shape)
157155
and (factor.output_shape_0 == factors[0].output_shape_0)

0 commit comments

Comments
 (0)