Skip to content

Commit e6b18d7

Browse files
Adding support for lu+SparseMatrixCSR{0}
1 parent fdc4e90 commit e6b18d7

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/SparseMatrixCSR.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,12 @@ function Base.copy(a::SparseMatrixCSR{Bi}) where Bi
118118
SparseMatrixCSR{Bi}(a.m,a.n,copy(a.rowptr),copy(a.colval),copy(a.nzval))
119119
end
120120

121+
_copy_and_increment(x) = copy(x) .+ 1
122+
121123
function LinearAlgebra.lu(a::SparseMatrixCSR{0})
122-
@assert false "Base.lu(a::SparseMatrixCSR{0}) not yet implemented"
124+
rowptr = _copy_and_increment(a.rowptr)
125+
colval = _copy_and_increment(a.colval)
126+
Transpose(lu(SparseMatrixCSC(a.m,a.n,rowptr,colval,a.nzval)))
123127
end
124128

125129
function LinearAlgebra.lu(a::SparseMatrixCSR{1})
@@ -132,6 +136,14 @@ function LinearAlgebra.lu!(
132136
Transpose(lu!(translu.parent,SparseMatrixCSC(a.m,a.n,a.rowptr,a.colval,a.nzval)))
133137
end
134138

139+
function LinearAlgebra.lu!(
140+
translu::Transpose{T,<:SuiteSparse.UMFPACK.UmfpackLU{T}},
141+
a::SparseMatrixCSR{0}) where {T}
142+
rowptr = _copy_and_increment(a.rowptr)
143+
colval = _copy_and_increment(a.colval)
144+
Transpose(lu!(translu.parent,SparseMatrixCSC(a.m,a.n,rowptr,colval,a.nzval)))
145+
end
146+
135147
size(S::SparseMatrixCSR) = (S.m, S.n)
136148
IndexStyle(::Type{<:SparseMatrixCSR}) = IndexCartesian()
137149
function getindex(A::SparseMatrixCSR{Bi,T}, i0::Integer, i1::Integer) where {Bi,T}

test/SparseMatrixCSR.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ function test_csr(Bi,Tv,Ti)
8484
@test CSR*x CSC*x
8585
end
8686

87+
function test_lu(Bi,I,J,V)
88+
CSR=sparsecsr(Val(Bi),I,J,V)
89+
CSC=sparse(I,J,V)
90+
x=rand(3)
91+
@test norm(CSR\x-CSC\x) < 1.0e-14
92+
fact=lu(CSR)
93+
lu!(fact,CSR)
94+
y=similar(x)
95+
ldiv!(y,fact,x)
96+
@test norm(y-CSC\x) < 1.0e-14
97+
end
98+
99+
87100
for Bi in (0,1)
88101
for Tv in (Float32,Float64)
89102
for Ti in (Int32,Int64)
@@ -95,14 +108,7 @@ end
95108
I = [1,1,2,2,2,3,3]
96109
J = [1,2,1,2,3,2,3]
97110
V = [4.0,1.0,-1.0,4.0,1.0,-1.0,4.0]
98-
CSR=sparsecsr(I,J,V)
99-
CSC=sparse(I,J,V)
100-
x=rand(3)
101-
@test norm(CSR\x-CSC\x) < 1.0e-14
102-
fact=lu(CSR)
103-
lu!(fact,CSR)
104-
y=similar(x)
105-
ldiv!(y,fact,x)
106-
@test norm(y-CSC\x) < 1.0e-14
111+
test_lu(0,I,J,V)
112+
test_lu(1,I,J,V)
107113

108114
end # module

0 commit comments

Comments
 (0)