Skip to content

Commit

Permalink
fix isagg to correctly use a fast path (#2357)
Browse files Browse the repository at this point in the history
  • Loading branch information
bkamins authored Aug 13, 2020
1 parent 07c3bc9 commit 4c601bc
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 35 deletions.
122 changes: 87 additions & 35 deletions src/groupeddataframe/splitapplycombine.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Types of values that, when returned by an aggregation function passed to `combine`,
# are considered to produce multiple columns in the resulting data frame.
const MULTI_COLS_TYPE = Union{AbstractDataFrame, NamedTuple, DataFrameRow, AbstractMatrix}

"""
groupby(d::AbstractDataFrame, cols; sort=false, skipmissing=false)
Expand Down Expand Up @@ -452,7 +456,7 @@ function combine(p::Pair, gd::GroupedDataFrame;
# verify if it is not better to use a fast path, which we achieve
# by moving to combine(::GroupedDataFrame, ::AbstractVector) method
# note that even if length(gd) == 0 we can do this step
if isagg(p_from => (p_to isa Pair ? first(p_to) : p_to)) || p_from === nrow
if isagg(p_from => (p_to isa Pair ? first(p_to) : p_to), gd) || p_from === nrow
return combine(gd, [p], keepkeys=keepkeys, ungroup=ungroup)
end

Expand Down Expand Up @@ -760,17 +764,35 @@ struct Reduce{O, C, A} <: AbstractAggregate
end
# Convenience constructor: forwards to the four-argument constructor,
# passing `false` as the fourth argument.
function Reduce(f, condf=nothing, adjust=nothing)
    return Reduce(f, condf, adjust, false)
end

# Single-argument `check_aggregate`: map a known reduction function to its `Reduce`
# descriptor; any other function is returned unchanged.
# NOTE(review): this text comes from a commit diff view; these one-argument methods look
# like the pre-change side, superseded by the two-argument methods that follow - confirm.
check_aggregate(f::Any) = f
check_aggregate(::typeof(sum)) = Reduce(Base.add_sum)
check_aggregate(::typeof(prod)) = Reduce(Base.mul_prod)
check_aggregate(::typeof(maximum)) = Reduce(max)
check_aggregate(::typeof(minimum)) = Reduce(min)
check_aggregate(::typeof(mean)) = Reduce(Base.add_sum, nothing, /)
check_aggregate(::typeof(sumskipmissing)) = Reduce(Base.add_sum, !ismissing)
check_aggregate(::typeof(prodskipmissing)) = Reduce(Base.mul_prod, !ismissing)
check_aggregate(::typeof(meanskipmissing)) = Reduce(Base.add_sum, !ismissing, /)
check_aggregate(::typeof(maximumskipmissing)) = Reduce(max, !ismissing, nothing, true)
check_aggregate(::typeof(minimumskipmissing)) = Reduce(min, !ismissing, nothing, true)
# Two-argument `check_aggregate(f, col)`: return a `Reduce` describing a specialized
# reduction when `f` can use it for a column of the type of `col`; otherwise return `f`
# unchanged.
check_aggregate(f::Any, ::AbstractVector) = f

# Sums and products: specialized only for numeric columns (optionally with missings).
check_aggregate(f::typeof(sum), ::AbstractVector{<:Union{Missing, Number}}) =
    Reduce(Base.add_sum)
check_aggregate(f::typeof(sumskipmissing), ::AbstractVector{<:Union{Missing, Number}}) =
    Reduce(Base.add_sum, !ismissing)
check_aggregate(f::typeof(prod), ::AbstractVector{<:Union{Missing, Number}}) =
    Reduce(Base.mul_prod)
check_aggregate(f::typeof(prodskipmissing), ::AbstractVector{<:Union{Missing, Number}}) =
    Reduce(Base.mul_prod, !ismissing)

# Extrema: specialized for `Real` columns (optionally with missings); columns holding
# multi-column values or vectors always keep `f`.
#
# A column whose eltype is a subtype of `Missing` matches both signatures of each function
# below and neither signature is more specific than the other, so without the
# `AbstractVector{<:Missing}` methods such a call raises an ambiguity `MethodError`.
# These methods resolve the ambiguity by keeping `f`.
check_aggregate(f::typeof(maximum),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(maximum), v::AbstractVector{<:Union{Missing, Real}}) =
    eltype(v) === Any ? f : Reduce(max)
check_aggregate(f::typeof(maximum), ::AbstractVector{<:Missing}) = f

check_aggregate(f::typeof(maximumskipmissing),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(maximumskipmissing), v::AbstractVector{<:Union{Missing, Real}}) =
    eltype(v) === Any ? f : Reduce(max, !ismissing, nothing, true)
check_aggregate(f::typeof(maximumskipmissing), ::AbstractVector{<:Missing}) = f

check_aggregate(f::typeof(minimum),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(minimum), v::AbstractVector{<:Union{Missing, Real}}) =
    eltype(v) === Any ? f : Reduce(min)
check_aggregate(f::typeof(minimum), ::AbstractVector{<:Missing}) = f

check_aggregate(f::typeof(minimumskipmissing),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(minimumskipmissing), v::AbstractVector{<:Union{Missing, Real}}) =
    eltype(v) === Any ? f : Reduce(min, !ismissing, nothing, true)
check_aggregate(f::typeof(minimumskipmissing), ::AbstractVector{<:Missing}) = f

# Means: a sum followed by the `/` adjustment; numeric columns only.
check_aggregate(f::typeof(mean), ::AbstractVector{<:Union{Missing, Number}}) =
    Reduce(Base.add_sum, nothing, /)
check_aggregate(f::typeof(meanskipmissing), ::AbstractVector{<:Union{Missing, Number}}) =
    Reduce(Base.add_sum, !ismissing, /)

# Other aggregate functions which are not strictly reductions
struct Aggregate{F, C} <: AbstractAggregate
Expand All @@ -779,15 +801,32 @@ struct Aggregate{F, C} <: AbstractAggregate
end
# Convenience constructor: an aggregate with `nothing` as its second argument.
function Aggregate(f)
    return Aggregate(f, nothing)
end

# Single-argument `check_aggregate` for functions that are not strict reductions:
# each known function is mapped to the corresponding `Aggregate` descriptor.
# NOTE(review): this text comes from a commit diff view; these one-argument methods look
# like the pre-change side, superseded by the two-argument methods that follow - confirm.
check_aggregate(::typeof(var)) = Aggregate(var)
check_aggregate(::typeof(varskipmissing)) = Aggregate(var, !ismissing)
check_aggregate(::typeof(std)) = Aggregate(std)
check_aggregate(::typeof(stdskipmissing)) = Aggregate(std, !ismissing)
check_aggregate(::typeof(first)) = Aggregate(first)
check_aggregate(::typeof(firstskipmissing)) = Aggregate(first, !ismissing)
check_aggregate(::typeof(last)) = Aggregate(last)
check_aggregate(::typeof(lastskipmissing)) = Aggregate(last, !ismissing)
check_aggregate(::typeof(length)) = Aggregate(length)
# Two-argument `check_aggregate(f, col)` for functions that are not strict reductions:
# return an `Aggregate` when the specialized implementation applies to the column,
# otherwise return `f` unchanged.

# Variance and standard deviation: numeric columns only (optionally with missings).
check_aggregate(f::typeof(var), ::AbstractVector{<:Union{Missing, Number}}) =
    Aggregate(f)
check_aggregate(f::typeof(varskipmissing), ::AbstractVector{<:Union{Missing, Number}}) =
    Aggregate(var, !ismissing)
check_aggregate(f::typeof(std), ::AbstractVector{<:Union{Missing, Number}}) =
    Aggregate(f)
check_aggregate(f::typeof(stdskipmissing), ::AbstractVector{<:Union{Missing, Number}}) =
    Aggregate(std, !ismissing)

# First/last element: columns holding multi-column values or vectors, and columns with
# eltype `Any`, keep `f`; every other column gets the specialized aggregate.
check_aggregate(f::typeof(first),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(first), v::AbstractVector) =
    eltype(v) !== Any ? Aggregate(f) : f
check_aggregate(f::typeof(firstskipmissing),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(firstskipmissing), v::AbstractVector) =
    eltype(v) !== Any ? Aggregate(first, !ismissing) : f
check_aggregate(f::typeof(last),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(last), v::AbstractVector) =
    eltype(v) !== Any ? Aggregate(f) : f
check_aggregate(f::typeof(lastskipmissing),
                ::AbstractVector{<:Union{Missing, MULTI_COLS_TYPE, AbstractVector}}) = f
check_aggregate(f::typeof(lastskipmissing), v::AbstractVector) =
    eltype(v) !== Any ? Aggregate(last, !ismissing) : f

# Group length: specialized for every column type.
check_aggregate(f::typeof(length), ::AbstractVector) = Aggregate(f)

# SkipMissing does not support length

# Find first value matching condition for each group
Expand Down Expand Up @@ -864,7 +903,11 @@ function groupreduce_init(op, condf, adjust,
if isconcretetype(Tnm) && applicable(initf, Tnm)
tmpv = initf(Tnm)
initv = op(tmpv, tmpv)
x = adjust isa Nothing ? initv : adjust(initv, 1)
if adjust isa Nothing
x = Tnm <: AbstractIrrational ? float(initv) : initv
else
x = adjust(initv, 1)
end
if condf === !ismissing
V = typeof(x)
else
Expand Down Expand Up @@ -900,7 +943,8 @@ for (op, initf) in ((:max, :typemin), (:min, :typemax))
# It is safe to use a non-missing init value
# since missing will poison the result if present
# we assume here that groups are non-empty (current design assures this)
if isconcretetype(S) && hasmethod($initf, Tuple{S})
# + workaround for https://github.com/JuliaLang/julia/issues/36978
if isconcretetype(S) && hasmethod($initf, Tuple{S}) && !(S <: Irrational)
fill!(outcol, $initf(S))
else
fillfirst!(condf, outcol, incol, gd)
Expand Down Expand Up @@ -994,6 +1038,12 @@ groupreduce(f, op, condf::typeof(!ismissing), adjust, checkempty::Bool,
# Apply the reduction described by `r` to `incol`, grouped according to `gd`.
function (r::Reduce)(incol::AbstractVector, gd::GroupedDataFrame)
    # values are passed through unchanged; the second argument is ignored
    passthrough = (value, _) -> value
    return groupreduce(passthrough, r.op, r.condf, r.adjust, r.checkempty, incol, gd)
end

# `zero(Missing)` is not defined in Julia 1.0 LTS, but the aggregation for `var`
# requires it, so it is defined here for Julia versions older than 1.1.
# TODO: remove this when we drop 1.0 support
if VERSION < v"1.1"
Base.zero(::Type{Missing}) = missing
end

function (agg::Aggregate{typeof(var)})(incol::AbstractVector, gd::GroupedDataFrame)
means = groupreduce((x, i) -> x, Base.add_sum, agg.condf, /, false, incol, gd)
# !ismissing check is purely an optimization to avoid a copy later
Expand All @@ -1003,14 +1053,18 @@ function (agg::Aggregate{typeof(var)})(incol::AbstractVector, gd::GroupedDataFra
T = real(eltype(means))
end
res = zeros(T, length(gd))
groupreduce!(res, (x, i) -> @inbounds(abs2(x - means[i])), +, agg.condf,
(x, l) -> l <= 1 ? oftype(x / (l-1), NaN) : x / (l-1),
false, incol, gd)
return groupreduce!(res, (x, i) -> @inbounds(abs2(x - means[i])), +, agg.condf,
(x, l) -> l <= 1 ? oftype(x / (l-1), NaN) : x / (l-1),
false, incol, gd)
end

# Group-wise standard deviation: the square root of the group-wise variance computed by
# the `var` aggregate with the same condition function.
function (agg::Aggregate{typeof(std)})(incol::AbstractVector, gd::GroupedDataFrame)
    outcol = Aggregate(var, agg.condf)(incol, gd)
    # Take the square root exactly once (no in-place `map!` before this branch).
    # `sqrt` of a `Rational` yields a floating-point value, so for columns whose eltype
    # is within `Union{Missing, Rational}` a new column is allocated by broadcasting;
    # otherwise `outcol` is overwritten in place to avoid a copy.
    if eltype(outcol) <: Union{Missing, Rational}
        return sqrt.(outcol)
    else
        return map!(sqrt, outcol, outcol)
    end
end

for f in (first, last)
Expand Down Expand Up @@ -1038,10 +1092,8 @@ function (agg::Aggregate{typeof(length)})(incol::AbstractVector, gd::GroupedData
end
end

# Return `true` when the function in pair `p` maps to an `AbstractAggregate` and the
# source of the pair is a single column index.
# NOTE(review): this text comes from a commit diff view; this one-argument method looks
# like the pre-change side, superseded by the two-argument `isagg` that follows - confirm.
isagg(p::Pair) =
check_aggregate(last(p)) isa AbstractAggregate && first(p) isa ColumnIndex

# NOTE(review): the same `MULTI_COLS_TYPE` union is also defined near the top of this
# view; this occurrence looks like its pre-change position in the commit diff - confirm.
const MULTI_COLS_TYPE = Union{AbstractDataFrame, NamedTuple, DataFrameRow, AbstractMatrix}
# Return `true` when `col` is a single column index and `fun`, checked against that
# column of the parent data frame of `gdf`, maps to an `AbstractAggregate`.
function isagg((col, fun)::Pair, gdf::GroupedDataFrame)
    # the column is only looked up when `col` is a valid single-column index
    col isa ColumnIndex || return false
    return check_aggregate(fun, parent(gdf)[!, col]) isa AbstractAggregate
end

function _agg2idx_map_helper(idx, idx_agg)
agg2idx_map = fill(-1, length(idx))
Expand Down Expand Up @@ -1101,11 +1153,11 @@ function _combine(f::AbstractVector{<:Pair},
end

idx_agg = nothing
if length(gd) > 0 && any(isagg, f)
if length(gd) > 0 && any(x -> isagg(x, gd), f)
# Compute indices of representative rows only once for all AbstractAggregates
idx_agg = Vector{Int}(undef, length(gd))
fillfirst!(nothing, idx_agg, 1:length(gd.groups), gd)
elseif length(gd) == 0 || !all(isagg, f)
elseif length(gd) == 0 || !all(x -> isagg(x, gd), f)
# Trigger computation of indices
# This can speed up some aggregates that would not trigger this on their own
@assert gd.idx !== nothing
Expand All @@ -1114,9 +1166,9 @@ function _combine(f::AbstractVector{<:Pair},
parentdf = parent(gd)
for (i, p) in enumerate(f)
source_cols, fun = p
if length(gd) > 0 && isagg(p)
if length(gd) > 0 && isagg(p, gd)
incol = parentdf[!, source_cols]
agg = check_aggregate(last(p))
agg = check_aggregate(last(p), incol)
outcol = agg(incol, gd)
res[i] = idx_agg, outcol
elseif keeprows && fun === identity && !(source_cols isa AsTable)
Expand Down
124 changes: 124 additions & 0 deletions test/grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ const ≅ = isequal
function isequal_typed(df1::AbstractDataFrame, df2::AbstractDataFrame)
    # contents must be `isequal` first; element types are compared only afterwards
    isequal(df1, df2) || return false
    eltypes1 = eltype.(eachcol(df1))
    eltypes2 = eltype.(eachcol(df2))
    return eltypes1 == eltypes2
end

"""Check if passed data frames are `isequal` and have the same types of columns"""
function isequal_coltyped(df1::AbstractDataFrame, df2::AbstractDataFrame)
    # contents must be `isequal` first; column container types are compared afterwards
    isequal(df1, df2) || return false
    coltypes1 = typeof.(eachcol(df1))
    coltypes2 = typeof.(eachcol(df2))
    return coltypes1 == coltypes2
end

"""Check that groups in gd are equal to provided data frames, ignoring order"""
function isequal_unordered(gd::GroupedDataFrame,
dfs::AbstractVector{<:AbstractDataFrame})
Expand Down Expand Up @@ -2691,4 +2695,124 @@ end
@test map(nrow, gdf) == [1, 1, 1]
end

@testset "check isagg correctly uses fast path only when it should" begin
# Numeric aggregations: passing the function directly must give the same values and
# column types as passing it wrapped in an anonymous function.
for f in (sum, prod, mean, var, std, sumskipmissing, prodskipmissing,
          meanskipmissing, varskipmissing, stdskipmissing)
    for column in ([1, 2, 3], [big(1.5), big(2.5), big(3.5)],
                   [1 + 0.5im, 2 + 0.5im, 3 + 0.5im],
                   [true, false, true], [pi, pi, pi], [1//2, 1//3, 1//4],
                   Real[1, 1.5, 1//2], Number[1, 1.5, 1//2], Any[1, 1.5, 1//2],
                   [1, 2, missing], [big(1.5), big(2.5), missing],
                   [1 + 0.5im, 2 + 0.5im, missing],
                   [true, false, missing], [pi, pi, missing], [1//2, 1//3, missing],
                   Union{Missing,Real}[1, 1.5, missing],
                   Union{Missing,Number}[1, 1.5, missing], Any[1, 1.5, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        @test isequal_coltyped(combine(grouped, :x => f => :y),
                               combine(grouped, :x => (v -> f(v)) => :y))
    end
end

# Extrema: same comparison of direct vs. wrapped function, without complex columns.
for f in (maximum, minimum, maximumskipmissing, minimumskipmissing)
    for column in ([1, 2, 3], [big(1.5), big(2.5), big(3.5)],
                   [true, false, true], [pi, pi, pi], [1//2, 1//3, 1//4],
                   Real[1, 1.5, 1//2], Number[1, 1.5, 1//2], Any[1, 1.5, 1//2],
                   [1, 2, missing], [big(1.5), big(2.5), missing],
                   [true, false, missing], [pi, pi, missing], [1//2, 1//3, missing],
                   Union{Missing,Real}[1, 1.5, missing],
                   Union{Missing,Number}[1, 1.5, missing], Any[1, 1.5, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        @test isequal_coltyped(combine(grouped, :x => f => :y),
                               combine(grouped, :x => (v -> f(v)) => :y))
    end
end

# first/last/length and their skipmissing variants on scalar columns.
for f in (first, last, length, firstskipmissing, lastskipmissing)
    for column in ([1, 2, 3], [big(1.5), big(2.5), big(3.5)],
                   [1 + 0.5im, 2 + 0.5im, 3 + 0.5im],
                   [true, false, true], [pi, pi, pi], [1//2, 1//3, 1//4],
                   Real[1, 1.5, 1//2], Number[1, 1.5, 1//2], Any[1, 1.5, 1//2],
                   [1, 2, missing], [big(1.5), big(2.5), missing],
                   [1 + 0.5im, 2 + 0.5im, missing],
                   [true, false, missing], [pi, pi, missing], [1//2, 1//3, missing],
                   Union{Missing,Real}[1, 1.5, missing],
                   Union{Missing,Number}[1, 1.5, missing], Any[1, 1.5, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        if f !== lastskipmissing
            @test isequal_coltyped(combine(grouped, :x => f => :y),
                                   combine(grouped, :x => (v -> f(v)) => :y))
        else
            # corner case - it fails in slow path, but works in fast path
            if eltype(column) === Any
                @test_throws MethodError combine(grouped, :x => f => :y)
            else
                @test isequal_coltyped(combine(grouped, :x => f => :y),
                                       combine(groupby_checked(dropmissing(parent(grouped)), :g),
                                               :x => f => :y))
            end
            @test_throws MethodError combine(grouped, :x => (v -> f(v)) => :y)
        end
    end
end

# Columns of ranges: both paths agree without missings and both throw with a missing.
for f in (sum, mean, var, std)
    for column in ([1:3, 4:6, 7:9], [1:3, 4:6, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        if Missing <: eltype(column)
            @test_throws MethodError combine(grouped, :x => f => :y)
            @test_throws MethodError combine(grouped, :x => (v -> f(v)) => :y)
        else
            @test isequal_coltyped(combine(grouped, :x => f => :y),
                                   combine(grouped, :x => (v -> f(v)) => :y))
        end
    end
end

# Columns of ranges with the skipmissing sum/mean: both paths must agree.
for f in (sumskipmissing, meanskipmissing)
    for column in ([1:3, 4:6, 7:9], [1:3, 4:6, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        @test isequal_coltyped(combine(grouped, :x => f => :y),
                               combine(grouped, :x => (v -> f(v)) => :y))
    end
end

# Columns of ranges with the skipmissing var/std: both paths are expected to throw.
# see https://github.com/JuliaLang/julia/issues/36979
for f in (varskipmissing, stdskipmissing)
    for column in ([1:3, 4:6, 7:9], [1:3, 4:6, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        @test_throws MethodError combine(grouped, :x => f => :y)
        @test_throws MethodError combine(grouped, :x => (v -> f(v)) => :y)
    end
end

# Columns of ranges with extrema/first/last/length: both paths agree, except for
# `lastskipmissing`, where both are expected to throw.
for f in (maximum, minimum, maximumskipmissing, minimumskipmissing,
          first, last, length, firstskipmissing, lastskipmissing)
    for column in ([1:3, 4:6, 7:9], [1:3, 4:6, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        if !(f isa typeof(lastskipmissing))
            @test isequal_coltyped(combine(grouped, :x => f => :y),
                                   combine(grouped, :x => (v -> f(v)) => :y))
        else
            @test_throws MethodError combine(grouped, :x => f => :y)
            @test_throws MethodError combine(grouped, :x => (v -> f(v)) => :y)
        end
    end
end

# Columns of ranges with products: both paths are expected to throw.
for f in (prod, prodskipmissing)
    for column in ([1:3, 4:6, 7:9], [1:3, 4:6, missing])
        grouped = groupby_checked(DataFrame(g=[1, 1, 1], x=column), :g)
        @test_throws MethodError combine(grouped, :x => f => :y)
        @test_throws MethodError combine(grouped, :x => (v -> f(v)) => :y)
    end
end

# Columns whose elements are matrices, data frames or named tuples (optionally mixed
# with `missing`): the direct and wrapped forms must behave the same way.
for fun in (sum, prod, mean, var, std, sumskipmissing, prodskipmissing,
            meanskipmissing, varskipmissing, stdskipmissing,
            maximum, minimum, maximumskipmissing, minimumskipmissing,
            first, last, length, firstskipmissing, lastskipmissing),
    col in ([ones(2,2), zeros(2,2), ones(2,2)], [ones(2,2), zeros(2,2), missing],
            [DataFrame(ones(2,2)), DataFrame(zeros(2,2)), DataFrame(ones(2,2))],
            [DataFrame(ones(2,2)), DataFrame(zeros(2,2)), ones(2,2)],
            [DataFrame(ones(2,2)), DataFrame(zeros(2,2)), missing],
            [(a=1, b=2), (a=3, b=4), (a=5, b=6)], [(a=1, b=2), (a=3, b=4), missing])
    gdf = groupby_checked(DataFrame(g=[1, 1, 1], x=col), :g)
    if fun === length
        @test isequal_coltyped(combine(gdf, :x => fun => :y), DataFrame(g=1, y=3))
        @test isequal_coltyped(combine(gdf, :x => (x -> fun(x)) => :y), DataFrame(g=1, y=3))
    elseif (fun === last && ismissing(last(col))) ||
           (fun in (maximum, minimum) &&
            isequal(col, [(a=1, b=2), (a=3, b=4), missing]))
        # `isequal` is required for the column comparison because the column
        # contains `missing`.
        # this case is a situation when the vector type would not be accepted in
        # general as it contains entries that we do not allow but accidentally
        # its last element is accepted because it is missing
        @test isequal_coltyped(combine(gdf, :x => fun => :y), DataFrame(g=1, y=missing))
        @test isequal_coltyped(combine(gdf, :x => (x -> fun(x)) => :y), DataFrame(g=1, y=missing))
    else
        @test_throws Union{ArgumentError, MethodError} combine(gdf, :x => fun => :y)
        @test_throws Union{ArgumentError, MethodError} combine(gdf, :x => (x -> fun(x)) => :y)
    end
end
end

end # module

0 comments on commit 4c601bc

Please sign in to comment.