|
27 | 27 | push!(V,v)
|
28 | 28 | nothing
|
29 | 29 | end
|
| 30 | + |
| 31 | +#@inline function add_entries!( |
| 32 | +# combine::Function, |
| 33 | +# A::SparseMatrixCSC, |
| 34 | +# vs::AbstractMatrix{<:Number}, |
| 35 | +# is,js) |
| 36 | +# |
| 37 | +# if issorted(is) |
| 38 | +# nz = A.nzval |
| 39 | +# ptrs = A.colptr |
| 40 | +# rows = A.rowval |
| 41 | +# for (lj,j) in enumerate(js) |
| 42 | +# if j>0 |
| 43 | +# pini = ptrs[j] |
| 44 | +# pend = ptrs[j+1]-1 |
| 45 | +# li = 1 |
| 46 | +# for p in pini:pend |
| 47 | +# _i = rows[p] |
| 48 | +# if _i == is[li] |
| 49 | +# vij = vs[li,lj] |
| 50 | +# Aij = nz[p] |
| 51 | +# nz[p] = combine(Aij,vij) |
| 52 | +# li += 1 |
| 53 | +# end |
| 54 | +# end |
| 55 | +# end |
| 56 | +# end |
| 57 | +# else |
| 58 | +# for (lj,j) in enumerate(js) |
| 59 | +# if j>0 |
| 60 | +# for (li,i) in enumerate(is) |
| 61 | +# if i>0 |
| 62 | +# vij = vs[li,lj] |
| 63 | +# add_entry!(combine,A,vij,i,j) |
| 64 | +# end |
| 65 | +# end |
| 66 | +# end |
| 67 | +# end |
| 68 | +# end |
| 69 | +# A |
| 70 | +#end |
| 71 | + |
| 72 | +struct CounterCSRR{Tv,Ti} |
| 73 | + tv::Type{Tv} |
| 74 | + nrows::Int |
| 75 | + ncols::Int |
| 76 | + rowptrs::Vector{Ti} |
| 77 | +end |
| 78 | + |
| 79 | +LoopStyle(::Type{<:CounterCSRR}) = Loop() |
| 80 | + |
| 81 | +@inline function add_entry!(::typeof(+),a::CounterCSRR{Tv,Ti},v,i,j) where {Tv,Ti} |
| 82 | + a.rowptrs[i+1] += Ti(1) |
| 83 | + nothing |
| 84 | +end |
| 85 | + |
| 86 | +struct CSRR{Tv,Ti} |
| 87 | + nrows::Int |
| 88 | + ncols::Int |
| 89 | + rowptrs::Vector{Ti} |
| 90 | + colvals::Vector{Ti} |
| 91 | + nzvals::Vector{Tv} |
| 92 | +end |
| 93 | + |
| 94 | +LoopStyle(::Type{<:CSRR}) = Loop() |
| 95 | + |
| 96 | +@inline function add_entry!(::typeof(+),a::CSRR{Tv,Ti},v::Nothing,i,j) where {Tv,Ti} |
| 97 | + p = a.rowptrs[i] |
| 98 | + a.colvals[p] = j |
| 99 | + a.rowptrs[i] = p+Ti(1) |
| 100 | + nothing |
| 101 | +end |
| 102 | + |
| 103 | +@inline function add_entry!(::typeof(+),a::CSRR{Tv,Ti},v,i,j) where {Tv,Ti} |
| 104 | + p = a.rowptrs[i] |
| 105 | + a.colvals[p] = j |
| 106 | + a.nzvals[p] = v |
| 107 | + a.rowptrs[i] = p+Ti(1) |
| 108 | + nothing |
| 109 | +end |
| 110 | + |
| 111 | +function nz_counter(::Type{SparseMatrixCSC{Tv,Ti}},axes) where {Tv,Ti} |
| 112 | + nrows = length(axes[1]) |
| 113 | + ncols = length(axes[2]) |
| 114 | + rowptrs = zeros(Ti,nrows+1) |
| 115 | + CounterCSRR(Tv,nrows,ncols,rowptrs) |
| 116 | +end |
| 117 | + |
| 118 | +function nz_allocation(a::CounterCSRR{Tv,Ti}) where {Tv,Ti} |
| 119 | + rowptrs = a.rowptrs |
| 120 | + length_to_ptrs!(rowptrs) |
| 121 | + ndata = rowptrs[end]-1 |
| 122 | + colvals = zeros(Ti,ndata) |
| 123 | + nzvals = zeros(Tv,ndata) |
| 124 | + CSRR(a.nrows,a.ncols,rowptrs,colvals,nzvals) |
| 125 | +end |
| 126 | + |
| 127 | +function create_from_nz(a::CSRR{Tv,Ti}) where {Tv,Ti} |
| 128 | + rewind_ptrs!(a.rowptrs) |
| 129 | + A = _csrr_to_csc!(a) |
| 130 | + A |
| 131 | +end |
| 132 | + |
| 133 | +function _csrr_to_csc!(csrr::CSRR{Tv,Ti}) where {Tv,Ti} |
| 134 | + nrows = csrr.nrows |
| 135 | + ncols = csrr.ncols |
| 136 | + rowptrs = csrr.rowptrs |
| 137 | + colvals = csrr.colvals |
| 138 | + nzvalscsr = csrr.nzvals |
| 139 | + |
| 140 | + @assert nrows == length(rowptrs)-1 |
| 141 | + colptrs = Vector{Ti}(undef,ncols+1) |
| 142 | + work = Vector{Ti}(undef,ncols) |
| 143 | + cscnnz = _csrr_to_csc_count!(colptrs,rowptrs,colvals,nzvalscsr,work) |
| 144 | + rowvals = Vector{Ti}(undef,cscnnz) |
| 145 | + nzvalscsc = Vector{Tv}(undef,cscnnz) |
| 146 | + _csrr_to_csc_fill!(colptrs,rowvals,nzvalscsc,rowptrs,colvals,nzvalscsr) |
| 147 | + SparseMatrixCSC(nrows,ncols,colptrs,rowvals,nzvalscsc) |
| 148 | +end |
| 149 | + |
| 150 | +# Notation |
| 151 | +# csrr: csr with repeated and unsorted columns |
| 152 | +# csru: csr witu unsorted columns |
| 153 | +# csc: csc with sorted columns |
| 154 | + |
| 155 | +# Adapted form SparseArrays |
| 156 | +function _csrr_to_csc_count!( |
| 157 | + colptrs::Vector{Ti}, |
| 158 | + rowptrs::Vector{Tj}, |
| 159 | + colvals::Vector{Tj}, |
| 160 | + nzvalscsr::Vector{Tv}, |
| 161 | + work::Vector{Tj}) where {Ti,Tj,Tv} |
| 162 | + |
| 163 | + nrows = length(rowptrs)-1 |
| 164 | + ncols = length(colptrs)-1 |
| 165 | + if nrows == 0 || ncols == 0 |
| 166 | + fill!(colptrs, Ti(1)) |
| 167 | + return Tj(0) |
| 168 | + end |
| 169 | + |
| 170 | + # Convert csrr to csru by identifying repeated cols with array work. |
| 171 | + # At the same time, count number of unique rows in colptrs shifted by one. |
| 172 | + fill!(colptrs, Ti(0)) |
| 173 | + fill!(work, Tj(0)) |
| 174 | + writek = Tj(1) |
| 175 | + newcsrrowptri = Ti(1) |
| 176 | + origcsrrowptri = Tj(1) |
| 177 | + origcsrrowptrip1 = rowptrs[2] |
| 178 | + @inbounds for i in 1:nrows |
| 179 | + for readk in origcsrrowptri:(origcsrrowptrip1-Tj(1)) |
| 180 | + j = colvals[readk] |
| 181 | + if work[j] < newcsrrowptri |
| 182 | + work[j] = writek |
| 183 | + if writek != readk |
| 184 | + colvals[writek] = j |
| 185 | + nzvalscsr[writek] = nzvalscsr[readk] |
| 186 | + end |
| 187 | + writek += Tj(1) |
| 188 | + colptrs[j+1] += Ti(1) |
| 189 | + else |
| 190 | + klt = work[j] |
| 191 | + nzvalscsr[klt] = +(nzvalscsr[klt], nzvalscsr[readk]) |
| 192 | + end |
| 193 | + end |
| 194 | + newcsrrowptri = writek |
| 195 | + origcsrrowptri = origcsrrowptrip1 |
| 196 | + origcsrrowptrip1 != writek && (rowptrs[i+1] = writek) |
| 197 | + i < nrows && (origcsrrowptrip1 = rowptrs[i+2]) |
| 198 | + end |
| 199 | + |
| 200 | + # Convert colptrs from counts to ptrs shifted by one |
| 201 | + # (ptrs will be corrected below) |
| 202 | + countsum = Tj(1) |
| 203 | + colptrs[1] = Ti(1) |
| 204 | + @inbounds for j in 2:(ncols+1) |
| 205 | + overwritten = colptrs[j] |
| 206 | + colptrs[j] = countsum |
| 207 | + countsum += overwritten |
| 208 | + @check Base.hastypemax(Ti) && (countsum <= typemax(Ti)) |
| 209 | + end |
| 210 | + |
| 211 | + cscnnz = countsum - Tj(1) |
| 212 | + cscnnz |
| 213 | +end |
| 214 | + |
| 215 | +function _csrr_to_csc_fill!( |
| 216 | + colptrs::Vector{Ti},rowvals::Vector{Ti},nzvalscsc::Vector{Tv}, |
| 217 | + rowptrs::Vector{Tj},colvals::Vector{Tj},nzvalscsr::Vector{Tv}) where {Ti,Tj,Tv} |
| 218 | + |
| 219 | + nrows = length(rowptrs)-1 |
| 220 | + ncols = length(colptrs)-1 |
| 221 | + if nrows == 0 || ncols == 0 |
| 222 | + return nothing |
| 223 | + end |
| 224 | + |
| 225 | + # From csru to csc |
| 226 | + # Tracking write positions in colptrs corrects |
| 227 | + # the column pointers to the final value. |
| 228 | + @inbounds for i in 1:nrows |
| 229 | + for csrk in rowptrs[i]:(rowptrs[i+1]-Tj(1)) |
| 230 | + j = colvals[csrk] |
| 231 | + x = nzvalscsr[csrk] |
| 232 | + csck = colptrs[j+1] |
| 233 | + colptrs[j+1] = csck + Ti(1) |
| 234 | + rowvals[csck] = i |
| 235 | + nzvalscsc[csck] = x |
| 236 | + end |
| 237 | + end |
| 238 | + |
| 239 | + nothing |
| 240 | +end |
| 241 | + |
| 242 | + |
0 commit comments