Skip to content

Commit

Permalink
Interpret scalar transformations as diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar authored and patrick-kidger committed Aug 20, 2024
1 parent 9676202 commit 50f8b81
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,9 @@ def is_diagonal(operator: AbstractLinearOperator) -> bool:
@is_diagonal.register(JacobianLinearOperator)
@is_diagonal.register(FunctionLinearOperator)
def _(operator):
return diagonal_tag in operator.tags
return diagonal_tag in operator.tags or (
operator.in_size() == 1 and operator.out_size() == 1
)


@is_diagonal.register(IdentityLinearOperator)
Expand All @@ -1525,7 +1527,7 @@ def _(operator):

@is_diagonal.register(TridiagonalLinearOperator)
def _(operator):
return False
return operator.in_size() == 1


# is_tridiagonal
Expand Down
16 changes: 16 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ def test_is_diagonal(dtype, getkey):
_assert_except_diag(lx.is_diagonal, not_diagonal_operators, flip_cond=True)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_diagonal_scalar(dtype, getkey):
matrix = jr.normal(getkey(), (1, 1), dtype=dtype)
diagonal_operators = _setup(getkey, matrix)
for operator in diagonal_operators:
assert lx.is_diagonal(operator)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_diagonal_tridiagonal(dtype, getkey):
diag1 = jr.normal(getkey(), (1,), dtype=dtype)
diag2 = jnp.zeros((0,), dtype=dtype)
op1 = lx.TridiagonalLinearOperator(diag1, diag2, diag2)
assert lx.is_diagonal(op1)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_has_unit_diagonal(dtype, getkey):
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
Expand Down

0 comments on commit 50f8b81

Please sign in to comment.