Skip to content

Commit

Permalink
Interpret scalar transformations as diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Aug 18, 2024
1 parent 4a7b108 commit ec238dd
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,11 @@ def is_diagonal(operator: AbstractLinearOperator) -> bool:
@is_diagonal.register(JacobianLinearOperator)
@is_diagonal.register(FunctionLinearOperator)
def _(operator):
return diagonal_tag in operator.tags
return (
diagonal_tag in operator.tags
or operator.in_size() == 1
and operator.out_size() == 1
)


@is_diagonal.register(IdentityLinearOperator)
Expand All @@ -1525,7 +1529,7 @@ def _(operator):

@is_diagonal.register(TridiagonalLinearOperator)
def _(operator):
return False
return operator.in_size() == 1


# is_tridiagonal
Expand Down

0 comments on commit ec238dd

Please sign in to comment.