-
Notifications
You must be signed in to change notification settings - Fork 102
/
Copy pathCompressedCellValues.jl
219 lines (173 loc) · 5.44 KB
/
CompressedCellValues.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
module CompressedCellValues
using Gridap
using Gridap.CellValuesGallery
using Gridap.CellNumberApply: IndexCellNumberFromKernel, CellNumberFromKernel
using Gridap.CellArrayApply: IndexCellArrayFromKernel, CellArrayFromKernel
using Gridap.CellMaps: IterCellMapValue, IndexCellMapValue
export IterCompressedCellValue
export IndexCompressedCellValue
export CompressedCellValue
export CompressedCellArray
export CompressedCellMap
import Base: iterate
import Base: length
import Base: size
import Base: getindex
import Gridap: apply
import Gridap: reindex
import Gridap: evaluate
import Base: ==, ≈
"""
    IterCompressedCellValue(values, ptrs)

Iterable cell value that stores each distinct value once in `values`. Iterating it
yields `values[p]` for every item `p` produced by iterating `ptrs`. `ptrs` may be any
object that supports `iterate` and `length`; this is checked on construction.
"""
struct IterCompressedCellValue{T,A} <: IterCellValue{T}
    values::Vector{T}
    ptrs::A
    function IterCompressedCellValue(values::Vector{T}, ptrs) where T
        P = typeof(ptrs)
        # the pointer container drives iteration and length, so it must provide both
        @assert hasmethod(iterate, Tuple{P})
        @assert hasmethod(length, Tuple{P})
        return new{T,P}(values, ptrs)
    end
end
# Iteration walks `cv.ptrs` and yields the stored value each pointer selects; the
# pointer iterator's own state is passed through as this iterator's state.
@inline iterate(cv::IterCompressedCellValue) = _iterate(cv, iterate(cv.ptrs))
@inline function iterate(cv::IterCompressedCellValue, state)
    return _iterate(cv, iterate(cv.ptrs, state))
end
# Translate one step of the pointer iterator into one step of the value iterator:
# `nothing` ends iteration; otherwise `(ptr, ptr_state)` becomes
# `(cv.values[ptr], ptr_state)`.
@inline function _iterate(cv, inext)
    inext === nothing && return nothing
    ptr, ptr_state = inext
    return (cv.values[ptr], ptr_state)
end
# One cell per pointer, so the cell count is the pointer count.
function length(cv::IterCompressedCellValue)
    return length(cv.ptrs)
end
"""
    IndexCompressedCellValue(values, ptrs::AbstractArray)

One-dimensional, indexable cell value that stores each distinct value once in
`values`; cell `i` evaluates to `values[ptrs[i]]`.
"""
struct IndexCompressedCellValue{T,A} <: IndexCellValue{T,1}
    values::Vector{T}
    ptrs::A
    function IndexCompressedCellValue(values::Vector{T}, ptrs::AbstractArray) where T
        return new{T,typeof(ptrs)}(values, ptrs)
    end
end
# Compress a constant cell value: a single stored value, and a constant pointer
# array of the same length whose every entry selects that value.
function IndexCompressedCellValue(cv::ConstantCellValue)
    return IndexCompressedCellValue([cv.value], ConstantCellValue(1, length(cv)))
end
# Cell `i` dereferences its pointer into the table of stored values.
getindex(cv::IndexCompressedCellValue{T,A}, i::Integer) where {T,A} = cv.values[cv.ptrs[i]]

# The value is one-dimensional with one entry per pointer.
size(cv::IndexCompressedCellValue) = (length(cv.ptrs),)
# Either flavour of compressed cell value.
const CompressedCellValue{T} = Union{
    IterCompressedCellValue{T},IndexCompressedCellValue{T}}

# Pick the flavour from the pointer container: array pointers give an indexable
# value, any other (iterable) pointer container gives an iterable one.
CompressedCellValue(values::Vector{T}, ptrs::AbstractArray) where T =
    IndexCompressedCellValue(values, ptrs)
CompressedCellValue(values::Vector{T}, ptrs) where T =
    IterCompressedCellValue(values, ptrs)
# Exact and approximate comparison of compressed cell values; both delegate to
# `_eq_kernel`, which compares the stored `values` and the `ptrs` layout.
# NOTE(review): `Base.hash` is not overloaded to agree with this `==`; confirm these
# objects are never used as `Dict`/`Set` keys.
(==)(a::CompressedCellValue, b::CompressedCellValue) = _eq_kernel(==, a, b)
(≈)(a::CompressedCellValue, b::CompressedCellValue) = _eq_kernel(≈, a, b)
"""
    _eq_kernel(op, a, b)

Structurally compare two compressed cell values. They match when they have the same
number of cells, the same number of stored values, equal `ptrs`, and stored `values`
that satisfy `op` (e.g. `==` or `≈`). Two values holding the same cells under a
different compression are reported as different.
"""
function _eq_kernel(op, a, b)
    # Cheap size checks come first. Besides short-circuiting, they keep `op` from
    # seeing `values` vectors of different lengths: array `isapprox` computes
    # `norm(x - y)`, which throws `DimensionMismatch` instead of returning `false`.
    length(a) != length(b) && return false
    length(a.values) != length(b.values) && return false
    a.ptrs == b.ptrs || return false
    op(a.values, b.values) || return false
    return true
end
# Apply a kernel to compressed cell values. When every argument shares the same
# compression layout, `_apply(Val(true), ...)` is selected; otherwise the
# `Val(false)` fallback is used.
function apply(k::NumberKernel, v::Vararg{<:CompressedCellValue})
    return _apply(Val(_is_optimizable(v...)), k, v...)
end
function apply(k::ArrayKernel, v::Vararg{<:CompressedCellValue})
    return _apply(Val(_is_optimizable(v...)), k, v...)
end
"""
    _is_optimizable(v...)

Return `true` when every argument is layout-compatible with the first one (see
`_is_compatible_data`). Stored values can then be combined index by index.
At least one argument is required.
"""
function _is_optimizable(v...)
    @assert length(v) > 0
    v1 = first(v)
    # A generator (rather than an eagerly built Bool array) lets `all` stop at the
    # first incompatible argument and skip the remaining `ptrs` comparisons.
    return all(_is_compatible_data(vi, v1) for vi in v)
end

"""
    _is_compatible_data(a, b)

Return `true` when `a` and `b` store the same number of values and use identical
(`===`) or equal (`==`) pointer containers.
"""
function _is_compatible_data(a, b)
    length(a.values) == length(b.values) || return false
    # identity is a cheap fast path before the element-wise comparison
    return a.ptrs === b.ptrs || a.ptrs == b.ptrs
end
"""
    _apply(::Val{true}, k, v...)

Kernel application for arguments that share one compression layout: evaluate `k`
once per stored value (not once per cell) and reuse the first argument's `ptrs`.
"""
function _apply(::Val{true}, k, v...)
    v1, = v
    # `map` over the argument tuple yields a tuple of the i-th stored values, which
    # keeps each argument's concrete value type. A `Vector` comprehension would have
    # an abstract eltype when arguments store values of different types.
    values = [
        compute_value(k, map(vi -> vi.values[i], v)...) for i in eachindex(v1.values)
    ]
    return CompressedCellValue(values, v1.ptrs)
end
# Fallbacks for arguments whose compression layouts differ: wrap the kernel in the
# generic kernel-application cell values, picking the indexable variant when all
# arguments are `IndexCompressedCellValue`s.
_apply(::Val{false}, k::NumberKernel, v::Vararg{<:IndexCompressedCellValue}) =
    IndexCellNumberFromKernel(k, v...)
_apply(::Val{false}, k::NumberKernel, v::Vararg{<:CompressedCellValue}) =
    CellNumberFromKernel(k, v...)
_apply(::Val{false}, k::ArrayKernel, v::Vararg{<:IndexCompressedCellValue}) =
    IndexCellArrayFromKernel(k, v...)
_apply(::Val{false}, k::ArrayKernel, v::Vararg{<:CompressedCellValue}) =
    CellArrayFromKernel(k, v...)
# Compressed cell values whose stored values are arrays.
const CompressedCellArray{T,N} = CompressedCellValue{<:AbstractArray{T,N}}
CompressedCellArray(values::Vector{<:AbstractArray}, ptrs) =
    CompressedCellValue(values, ptrs)

# Compressed cell values whose stored values are maps.
const CompressedCellMap{S,M,T,N} = CompressedCellValue{<:Map{S,M,T,N}}
CompressedCellMap(values::Vector{<:Map}, ptrs) =
    CompressedCellValue(values, ptrs)
# Evaluating a constant map on a compressed array only needs one evaluation per
# stored array; the result reuses the array's pointer layout.
function evaluate(cm::ConstantCellMap{S,M}, ca::CompressedCellArray{<:S,M}) where {S,M}
    @assert length(cm) == length(ca)
    the_map = cm.value
    results = map(arr -> evaluate(the_map, arr), ca.values)
    return CompressedCellValue(results, ca.ptrs)
end
# Evaluate a compressed map on a compressed array, choosing the value-wise path
# when both share one compression layout.
function evaluate(cm::CompressedCellMap{S,M}, ca::CompressedCellArray{<:S,M}) where {S,M}
    @assert length(cm) == length(ca)
    return _evaluate(Val(_is_optimizable(cm, ca)), cm, ca)
end
# Shared layout: pair the i-th stored map with the i-th stored array and keep the
# map's pointers.
function _evaluate(::Val{true}, cm::CompressedCellMap, ca::CompressedCellValue)
    results = [evaluate(cm.values[i], ca.values[i]) for i in 1:length(cm.values)]
    return CompressedCellValue(results, cm.ptrs)
end
# Layouts differ: fall back to the generic map-evaluation cell values, using the
# indexable variant when both operands are indexable.
_evaluate(::Val{false}, cm::CellMap, ca::CellArray) = IterCellMapValue(cm, ca)
_evaluate(::Val{false}, cm::IndexCellMap, ca::IndexCellArray) = IndexCellMapValue(cm, ca)
# Reindexing a compressed value only reindexes its pointers; the stored values are
# shared unchanged. Both signatures are kept with identical behaviour.
# NOTE(review): the narrower `IndexCellValue` method presumably exists to resolve
# dispatch ambiguity with other `reindex` methods; confirm before removing it.
function reindex(values::CompressedCellValue, indices::CellValue{<:IndexLike})
    return _reindex_compressed(values, indices)
end
function reindex(values::CompressedCellValue, indices::IndexCellValue{<:IndexLike})
    return _reindex_compressed(values, indices)
end

# Shared body of the `reindex` methods above.
function _reindex_compressed(values, indices)
    new_ptrs = reindex(_ptrs(values.ptrs), indices)
    return CompressedCellValue(values.values, new_ptrs)
end

# View the pointer container as a cell value: pass cell values through and wrap
# plain arrays with `CellValueFromArray`.
_ptrs(p::CellValue) = p
_ptrs(p::AbstractArray) = CellValueFromArray(p)
end # module