Skip to content

Commit 4246761

Browse files
Merge pull request #372 from gridap/filtered_array
added filtered arrays
2 parents 7a14914 + 6331ad4 commit 4246761

File tree

4 files changed

+132
-0
lines changed

4 files changed

+132
-0
lines changed

src/Arrays/Arrays.jl

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ export setsize!
4141
export CompressedArray
4242
export LocalToGlobalArray
4343
export LocalToGlobalPosNegArray
44+
export FilteredCellArray
4445

4546
export kernel_cache
4647
export kernel_caches
@@ -118,6 +119,8 @@ include("LocalToGlobalArrays.jl")
118119

119120
include("LocalToGlobalPosNegArrays.jl")
120121

122+
include("FilteredArrays.jl")
123+
121124
include("Reindex.jl")
122125

123126
include("IdentityVectors.jl")

src/Arrays/FilteredArrays.jl

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
2+
"""
3+
Array of vectors that combines an `AbstractArray{Array{T,M},N}` and another
4+
array of masks with the same structure but `Bool` entries, and returns a
5+
vector with the entries of the original array that are `true` in the mask array,
6+
i.e., the filtered entries
7+
"""
8+
struct FilteredCellArray{T,M,N,L,V} <: AbstractArray{Array{T,M},N}
9+
cell_values::L
10+
cell_filters::V
11+
12+
@doc """
13+
"""
14+
function FilteredCellArray(
15+
cell_values::AbstractArray{<:AbstractArray},
16+
cell_filters::AbstractArray{<:AbstractArray})
17+
18+
@assert size(cell_values) == size(cell_filters) "Global arrays mismatch"
19+
20+
T = eltype(eltype(cell_values))
21+
22+
L = typeof(cell_values)
23+
V = typeof(cell_filters)
24+
25+
# M = ndims(eltype(L))
26+
M = 1 # The result is always a vector with the entries
27+
N = ndims(L)
28+
29+
O = ndims(V)
30+
31+
@assert N == O "Local arrays dim mismatch"
32+
33+
@assert (eltype(eltype(cell_filters))<: Bool) "Filters are not Booleans"
34+
35+
new{T,M,N,L,V}(cell_values,cell_filters)
36+
37+
end
38+
end
39+
40+
size(a::FilteredCellArray) = size(a.cell_values)
41+
42+
function IndexStyle(::Type{<:FilteredCellArray{T,M,N,L}}) where {T,M,N,L}
43+
IndexStyle(L)
44+
end
45+
46+
function array_cache(a::FilteredCellArray)
47+
vals = testitem(a.cell_values)
48+
T = eltype(eltype(a.cell_values))
49+
r = zeros(T,length(vals))
50+
c = CachedArray(r)
51+
cv = array_cache(a.cell_values)
52+
cb = array_cache(a.cell_filters)
53+
(cv,cb,c)
54+
end
55+
56+
function getindex!(cache,a::FilteredCellArray,i::Integer...)
57+
(cv,cb,c) = cache
58+
vals = getindex!(cv,a.cell_values,i...)
59+
filters = getindex!(cb,a.cell_filters,i...)
60+
@assert size(vals) == size(filters) "Local arrays mismatch"
61+
setsize!(c,(sum(filters),))
62+
r = c.array
63+
i = 0
64+
for (val,filter) in zip(vals,filters)
65+
if filter
66+
i += 1
67+
r[i] = val
68+
end
69+
end
70+
r
71+
end
72+
73+
function getindex(a::FilteredCellArray,i::Integer...)
74+
cache = array_cache(a)
75+
getindex!(cache,a,i...)
76+
end
+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
module FilteredArrays
2+
3+
using Test
4+
using Gridap.Arrays
5+
using Gridap
6+
using FillArrays
7+
8+
v1 = collect(1:4)
9+
v2 = collect(5:8)
10+
v3 = collect(9:12)
11+
values = [v1,v2,v3]
12+
ptrs = [1,2,3,3,2,2]
13+
a = CompressedArray(values,ptrs)
14+
r = values[ptrs]
15+
test_array(a,r)
16+
17+
filter = [false, true, true, false]
18+
filters = Fill(filter,6)
19+
20+
fa = FilteredCellArray(a,filters)
21+
22+
test_array(fa,[a[i][2:3] for i in 1:6])
23+
24+
#
25+
26+
v1 = [ [2 4 3 3]; [2 4 3 3]]
27+
v2 = [ [5 8 4 3]; [2 4 4 4]]
28+
v3 = [ [2 2 8 8]; [2 9 8 6]]
29+
30+
values = [v1,v2,v3]
31+
ptrs = [1,2,3,3,2,2]
32+
a = CompressedArray(values,ptrs)
33+
r = values[ptrs]
34+
test_array(a,r)
35+
36+
b = Array{Bool,2}(undef,2,4)
37+
b .= true
38+
b[2,3] = b[1,4] = false
39+
40+
filters = Fill(b,6)
41+
42+
fa = FilteredCellArray(a,filters)
43+
44+
r1 = [2, 2, 4, 4, 3, 3]
45+
r2 = [5, 2, 8, 4, 4, 4]
46+
r3 = [2, 2, 2, 9, 8, 6]
47+
res = [r1,r2,r3]
48+
r = CompressedArray(res,ptrs)
49+
test_array(fa,r)
50+
51+
end #module

test/ArraysTests/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ using Test
1616

1717
@testset "LocalToGlobalPosNegArrays" begin include("LocalToGlobalPosNegArraysTests.jl") end
1818

19+
@testset "FilteredArraysTests" begin include("FilteredArraysTests.jl") end
20+
1921
@testset "Tables" begin include("TablesTests.jl") end
2022

2123
@testset "Reindex" begin include("ReindexTests.jl") end

0 commit comments

Comments
 (0)