diff --git a/lineax/_operator.py b/lineax/_operator.py index d8ab745..f5f3663 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1514,7 +1514,11 @@ def is_diagonal(operator: AbstractLinearOperator) -> bool: @is_diagonal.register(JacobianLinearOperator) @is_diagonal.register(FunctionLinearOperator) def _(operator): - return diagonal_tag in operator.tags + return ( + diagonal_tag in operator.tags + or operator.in_size() == 1 + and operator.out_size() == 1 + ) @is_diagonal.register(IdentityLinearOperator) @@ -1525,7 +1529,7 @@ def _(operator): @is_diagonal.register(TridiagonalLinearOperator) def _(operator): - return False + return operator.in_size() == 1 # is_tridiagonal