Fix type instability of entropy and generalize crossentropy and kldivergence #714
Conversation
Thanks, that indeed simplifies the code a lot. Have you checked that performance doesn't regress? Also, does it still work for empty inputs?
With the latest release:

julia> using StatsBase, BenchmarkTools, Random
julia> Random.seed!(1234);
julia> p, q = rand(100_000), rand(100_000);
julia> @btime entropy($p);
753.493 μs (0 allocations: 0 bytes)
julia> @btime crossentropy($p, $q);
779.563 μs (0 allocations: 0 bytes)
julia> @btime kldivergence($p, $q);
1.540 ms (0 allocations: 0 bytes)
julia> entropy(Float64[])
ERROR: ArgumentError: reducing over an empty collection is not allowed
Stacktrace:
...
julia> crossentropy(Float64[], Float64[])
-0.0
julia> kldivergence(Float64[], Float64[])
0.0

With this PR:

julia> using StatsBase, BenchmarkTools, Random
julia> Random.seed!(1234);
julia> p, q = rand(100_000), rand(100_000);
julia> @btime entropy($p);
837.509 μs (0 allocations: 0 bytes)
julia> @btime crossentropy($p, $q);
1.133 ms (0 allocations: 0 bytes)
julia> @btime kldivergence($p, $q);
1.680 ms (0 allocations: 0 bytes)
julia> entropy(Float64[])
ERROR: ArgumentError: reducing over an empty collection is not allowed
Stacktrace:
...
julia> crossentropy(Float64[], Float64[])
ERROR: ArgumentError: reducing over an empty collection is not allowed
Stacktrace:
...
julia> kldivergence(Float64[], Float64[])
ERROR: ArgumentError: reducing over an empty collection is not allowed
Stacktrace:
...

So there seems to be some performance regression, and empty inputs don't work, but they also don't work with entropy in the latest release.
On second thought, it is not completely clear to me anymore whether it is reasonable to define these functions for empty inputs at all.
Given that these are sums, isn't it logical to return zero for empty inputs? Anyway, throwing an error instead would be breaking, so we would at least need to keep returning zero, possibly with a deprecation warning.
My main motivation for throwing an error would be that empty vectors don't represent probability distributions.
So what's the plan here? Return a type-stable 0 for empty inputs also for entropy?
I find that convincing.
As you prefer, but I just think we shouldn't add new errors for now, to avoid being breaking. A deprecation warning would be OK, with a PR to turn the warnings into errors in the next breaking release. BTW, should we check that the values sum to one if we want to ensure they are probability distributions (with a way to skip the check)?
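A hypothetical sketch of what such an opt-out check could look like (the entropy_checked name, the check keyword, and the tolerance are illustrative, not part of this PR):

using LogExpFunctions: xlogx

# Hypothetical sketch: validate that the values sum to one, with a way to
# skip the check. Name, keyword, and tolerance are illustrative only.
function entropy_checked(p; check::Bool=true)
    if check && !isapprox(sum(p), 1; atol=sqrt(eps(float(one(eltype(p))))))
        throw(ArgumentError("values do not sum to one, so they do not represent a probability distribution"))
    end
    return -sum(xlogx, p)
end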
src/scalarstats.jl
Outdated
        end
    end
-    return -s
+    return -sum(xlogy(pi, qi) for (pi, qi) in zip(p, q))
Note that using lazy broadcasting (on recent Julia versions) is more efficient and has better precision, as it uses pairwise summation. Maybe that would fix the performance regression?
return -sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, p, q)))
EDIT: probably need to use vec to keep the same behavior as currently.
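A sketch of the suggested variant, assuming array arguments (the crossentropy_broadcasted name is made up for illustration):

using LogExpFunctions: xlogy

# Sketch: Broadcast.instantiate makes the lazy broadcast reducible with
# pairwise summation, and vec flattens the arrays so that equal-length
# inputs of different shapes still pair up elementwise, as with zip.
crossentropy_broadcasted(p::AbstractArray, q::AbstractArray) =
    -sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, vec(p), vec(q))))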
I've seen the Julia PR and played around with sum(Broadcast....)
in some local branch of Distributions, but in my benchmarks there another implementation was actually faster. I don't remember the details right now, and I assume it was not really comparable to this example here anyway. I'll check if it helps, but I guess we have to live with a slight performance regression: the existing implementations are not type stable and less general, so it is not completely fair to compare the performance of the bugfixes in this PR with the performance of the incorrect existing implementation (of course, we should still try to improve the performance of the bugfixes as much as possible).
Improved precision may still be valuable anyway.
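To illustrate the precision point, a small experiment (not from the PR): naively accumulating Float32 values drifts far from the true value 1.0f6, while Base's sum, which uses pairwise summation, stays close.

xs = fill(0.1f0, 10_000_000)

# Naive left-to-right accumulation in Float32 loses precision quickly.
naive = let s = zero(Float32)
    for x in xs
        s += x
    end
    s
end

naive    # noticeably off from 1.0f6
sum(xs)  # pairwise summation, much closer to 1.0f6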
I reran the benchmarks with Julia 1.6.3 (this time also with BigFloat inputs and @inferred checks).

With StatsBase#master:

julia> using StatsBase, BenchmarkTools, Random, Test
julia> Random.seed!(1234);
julia> p, q = rand(100_000), rand(100_000);
julia> @btime entropy($p);
721.254 μs (0 allocations: 0 bytes)
julia> @btime crossentropy($p, $q);
765.894 μs (0 allocations: 0 bytes)
julia> @btime kldivergence($p, $q);
1.440 ms (0 allocations: 0 bytes)
julia> pbig, qbig = rand(BigFloat, 100_000), rand(BigFloat, 100_000);
julia> @btime entropy($pbig);
616.253 ms (700184 allocations: 35.11 MiB)
julia> @btime crossentropy($pbig, $qbig);
612.427 ms (700234 allocations: 35.11 MiB)
julia> @btime kldivergence($pbig, $qbig);
633.777 ms (900163 allocations: 45.79 MiB)
julia> @inferred entropy(pbig);
julia> @inferred crossentropy(pbig, qbig);
ERROR: return type BigFloat does not match inferred return type Union{Float64, BigFloat}
julia> @inferred kldivergence(pbig, qbig);
ERROR: return type BigFloat does not match inferred return type Union{Float64, BigFloat}

With this PR (commit 97d3bfa):

julia> using StatsBase, BenchmarkTools, Random, Test
julia> Random.seed!(1234);
julia> p, q = rand(100_000), rand(100_000);
julia> @btime entropy($p);
763.058 μs (0 allocations: 0 bytes)
julia> @btime crossentropy($p, $q);
1.100 ms (0 allocations: 0 bytes)
julia> @btime kldivergence($p, $q);
1.568 ms (0 allocations: 0 bytes)
julia> pbig, qbig = rand(BigFloat, 100_000), rand(BigFloat, 100_000);
julia> @btime entropy($pbig);
613.844 ms (700184 allocations: 35.11 MiB)
julia> @btime crossentropy($pbig, $qbig);
597.341 ms (700230 allocations: 35.11 MiB)
julia> @btime kldivergence($pbig, $qbig);
636.595 ms (900159 allocations: 45.79 MiB)
julia> @inferred entropy(pbig);
julia> @inferred crossentropy(pbig, qbig);
julia> @inferred kldivergence(pbig, qbig);

With the suggested lazy broadcasting implementation:

julia> using StatsBase, BenchmarkTools, Random, Test
julia> Random.seed!(1234);
julia> p, q = rand(100_000), rand(100_000);
julia> @btime entropy($p);
763.252 μs (0 allocations: 0 bytes)
julia> @btime crossentropy($p, $q);
882.439 μs (0 allocations: 0 bytes)
julia> @btime kldivergence($p, $q);
1.585 ms (0 allocations: 0 bytes)
julia> pbig, qbig = rand(BigFloat, 100_000), rand(BigFloat, 100_000);
julia> @btime entropy($pbig);
599.742 ms (700184 allocations: 35.11 MiB)
julia> @btime crossentropy($pbig, $qbig);
600.415 ms (700230 allocations: 35.11 MiB)
julia> @btime kldivergence($pbig, $qbig);
621.025 ms (900161 allocations: 45.79 MiB)
julia> @inferred entropy(pbig);
julia> @inferred crossentropy(pbig, qbig);
julia> @inferred kldivergence(pbig, qbig);
I updated the PR: it handles (and deprecates) empty collections of probabilities and uses pairwise summation now.
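A sketch of what such a deprecation path could look like (the entropy_deprecating name and the message are illustrative; the PR's actual code may differ):

using LogExpFunctions: xlogx

# Sketch: warn (instead of erroring) on empty input and return a zero of
# the appropriate float type; otherwise reduce a lazily instantiated
# broadcast so that sum can use pairwise summation.
function entropy_deprecating(p)
    if isempty(p)
        Base.depwarn("support for empty collections will be removed; they do " *
                     "not represent a proper probability distribution",
                     :entropy_deprecating)
        return float(zero(eltype(p)))
    end
    return -sum(Broadcast.instantiate(Broadcast.broadcasted(xlogx, p)))
end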
src/scalarstats.jl
Outdated
-entropy(p) = -sum(pᵢ -> iszero(pᵢ) ? zero(pᵢ) : pᵢ * log(pᵢ), p)
+function entropy(p)
+    if isempty(p)
+        throw(ArgumentError("empty collections of probabilities are not supported"))
Maybe more explicit?

Suggested change:
-    throw(ArgumentError("empty collections of probabilities are not supported"))
+    throw(ArgumentError("empty collections are not supported as they do not " *
+                        "represent a proper probability distribution"))
Done.
Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
Any additional comments or suggestions?
While I was working on a PR to Distributions, I noticed that entropy is not type stable. This PR fixes the issue by using LogExpFunctions.xlogx instead of x -> iszero(x) ? zero(x) : x * log(x). Moreover, it bases crossentropy and kldivergence on LogExpFunctions.xlogy and generalizes them to non-Float64 return types and arguments with different element types.
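As a rough, self-contained sketch of that approach (hypothetical names, not the PR's exact code):

using LogExpFunctions: xlogx, xlogy

# xlogx(x) returns x * log(x) with the convention xlogx(0) == 0, and
# xlogy(x, y) returns x * log(y) with xlogy(0, y) == 0, so both branches
# of the old anonymous function collapse into one type-stable call.
entropy_xlogx(p) = -sum(xlogx, p)
crossentropy_xlogy(p, q) = -sum(xlogy(pi, qi) for (pi, qi) in zip(p, q))
kldivergence_xlogy(p, q) = sum(xlogy(pi, pi / qi) for (pi, qi) in zip(p, q))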