Skip to content

Commit bbb97b5

Browse files
committed
Bugfix in setaxes!
1 parent c71015f commit bbb97b5

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

src/Arrays/CachedArrays.jl

+45-2
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,57 @@ function similar(::Type{CachedArray{T,N,A}},s::Tuple{Vararg{Int}}) where {T,N,A}
132132
end
133133

134134
function setaxes!(a::CachedArray,ax)
135-
s = map(length,ax)
136-
if s != size(a.array)
135+
if ! _same_axes(axes(a.array),ax)
136+
s = map(length,ax)
137137
if haskey(a.buffer,s)
138138
a.array = a.buffer[s]
139+
if ! _same_axes(axes(a.array),ax)
140+
a.array = similar(a.array,ax)
141+
a.buffer[s] = a.array
142+
end
139143
else
140144
a.array = similar(a.array,ax)
141145
a.buffer[s] = a.array
142146
end
143147
end
148+
nothing
149+
end
150+
151+
function _same_axes(a,b)
152+
a === b || a == b
153+
end
154+
155+
function _same_axes(a::NTuple{N,BlockedUnitRange},b::NTuple{N,BlockedUnitRange}) where N
156+
if a === b
157+
true
158+
else
159+
all(map(_same_axes_1d,a,b))
160+
end
161+
end
162+
163+
_same_axes_1d(a::BlockedUnitRange,b::BlockedUnitRange) = blocklasts(a) == blocklasts(b)
164+
165+
function _same_axes(a::NTuple{N,TwoLevelBlockedUnitRange},b::NTuple{N,TwoLevelBlockedUnitRange}) where N
166+
if a === b
167+
true
168+
else
169+
all(map(_same_axes_1d,a,b))
170+
end
171+
end
172+
173+
function _same_axes_1d(a::TwoLevelBlockedUnitRange,b::TwoLevelBlockedUnitRange)
174+
r = _same_axes_1d(a.global_range,b.global_range)
175+
la = length(a.local_ranges)
176+
lb = length(b.local_ranges)
177+
if la!=lb
178+
return false
179+
else
180+
for i in 1:la
181+
@inbounds ra = a.local_ranges[i]
182+
@inbounds rb = b.local_ranges[i]
183+
r = r && _same_axes_1d(ra,rb)
184+
end
185+
return r
186+
end
144187
end
145188

test/ArraysTests/BlockArraysCooTests.jl

+27
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,25 @@ cc = CachedArray(c)
9292

9393
axs = (blockedrange([2,3,3]), blockedrange([2,3,3]))
9494
setaxes!(cc,axs)
95+
@test map(blocklasts,axes(cc.array)) == map(blocklasts,axs)
9596
@test cc.array === c
9697

9798
axs = (blockedrange([4,5,3]), blockedrange([2,4,3]))
9899
setaxes!(cc,axs)
100+
@test map(blocklasts,axes(cc.array)) == map(blocklasts,axs)
99101
fill!(cc.array,0)
100102

103+
axs = (blockedrange([5,4,3]), blockedrange([4,2,3]))
104+
setaxes!(cc,axs)
105+
@test map(blocklasts,axes(cc.array)) == map(blocklasts,axs)
106+
107+
axs1 = (blockedrange([5,4,3]), blockedrange([4,2,3]))
108+
axs2 = (blockedrange([4,5,3]), blockedrange([2,4,3]))
109+
axs3 = (blockedrange([4,5,3]), blockedrange([2,4,3]))
110+
@test Arrays._same_axes(axs1,axs2) == false
111+
@test Arrays._same_axes(axs2,axs2)
112+
@test Arrays._same_axes(axs2,axs3)
113+
101114
blocks = [ 10*[1,2], 20*[1,2,3] ]
102115
blockids = [(1,),(3,)]
103116
axs = (blockedrange([2,4,3]),)
@@ -106,10 +119,12 @@ b = BlockArrayCoo(blocks,blockids,axs)
106119
cb = CachedArray(b)
107120

108121
setaxes!(cb,axs)
122+
@test map(blocklasts,axes(cb.array)) == map(blocklasts,axs)
109123
@test cb.array === b
110124

111125
axs = (blockedrange([3,2,3]),)
112126
setaxes!(cb,axs)
127+
@test map(blocklasts,axes(cb.array)) == map(blocklasts,axs)
113128
@test size(cb) == (8,)
114129

115130
c = copy(a)
@@ -254,4 +269,16 @@ mul!(rS,aS,bS,3,2)
254269
@test isa(rS,BlockArrayCoo)
255270
@test isa(rS[Block(2)],BlockArrayCoo)
256271

272+
axs1 = (blockedrange([5,4,3]), blockedrange([4,2,3]))
273+
axs2 = (blockedrange([4,5,3]), blockedrange([2,4,3]))
274+
axsA = (blockedrange([axs1[1],axs1[1]]),blockedrange([axs2[2],axs2[2]]))
275+
axsB = (blockedrange([axs2[1],axs2[1]]),blockedrange([axs1[2],axs1[2]]))
276+
axsC = (blockedrange([axs2[1],axs2[1]]),blockedrange([axs1[2],axs1[2]]))
277+
@test Arrays._same_axes(axsA,axsB) == false
278+
@test Arrays._same_axes(axsA,axsA)
279+
@test Arrays._same_axes(axsB,axsC)
280+
281+
#using BenchmarkTools
282+
#@btime Arrays._same_axes($axsA,$axsA)
283+
257284
end # module

0 commit comments

Comments
 (0)