Commit 5d0c37b

Shared variance via SplitLayer (#43)
implement learned/shared/fixed/unit variance and properly test all cases
1 parent e0cd6bc commit 5d0c37b

File tree: 6 files changed, +219 -124 lines

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "ConditionalDists"
 uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468"
 authors = ["Niklas Heim <niklas.heim@aic.fel.cvut.cz>"]
-version = "0.4.5"
+version = "0.4.6"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md (+10 -6)

@@ -3,17 +3,16 @@
 
 # ConditionalDists.jl
 
-Conditional probability distributions powered by Flux.jl and Distributions.jl.
+Conditional probability distributions powered by Flux.jl and DistributionsAD.jl.
 
 The conditional PDFs that are defined in this package can be used in
 conjunction with Flux models to provide trainable mappings. As an example,
 assume you want to learn the mapping from a conditional to an MvNormal. The
 mapping `m` takes a vector `x` and maps it to a mean `μ` and a variance `σ`,
 which can be achieved by using a `ConditionalDists.SplitLayer` as the last
-layer in your network like the one below: The `SplitLayer` is constructed from
-`N` `Dense` layers (with same input size) and outputs `N` vectors:
+layer in the network.
 ```julia
-julia> m = SplitLayer(2, [3,4])
+julia> m = Chain(Dense(2,2,σ), SplitLayer(2, [3,4]))
 julia> m(rand(2))
 (Float32[0.07946974, 0.13797458, 0.03939067], Float32[0.7006321, 0.37641272, 0.3586885, 0.82230335])
 ```
@@ -40,8 +39,13 @@ julia> z = rand(zlength, batchsize)
 julia> logpdf(p,x,z)
 julia> rand(p, randn(zlength, 10))
 ```
-The trainable parameters (of the `SplitLayer`) are accessible as usual
-through `Flux.params`. The next few lines show how to optimize `p` to match a
+The trainable parameters (of the `SplitLayer`) are accessible as usual through
+`Flux.params`. For the different variance configurations (fixed/unit variance,
+shared variance, etc.) check the docstrings via `julia>? ConditionalMvNormal` and
+`julia>? SplitLayer`.
+
+The next few lines show how to optimize `p` to match a
 given Gaussian by using the `kl_divergence` defined in
 [IPMeasures.jl](https://github.com/aicenter/IPMeasures.jl).

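The README stops short of showing the full optimization snippet in this diff, so here is a minimal, hedged sketch of such a training loop. It deliberately uses the negative `logpdf` of samples drawn from the target Gaussian as the loss instead of `kl_divergence`, because the exact `kl_divergence` signature lives in IPMeasures.jl and is not part of this commit; all sizes and hyperparameters below are illustrative.

```julia
# Hedged sketch: fit p(x|z) to a fixed target Gaussian by maximizing logpdf.
# Assumes SplitLayer/ConditionalMvNormal are exported as in the README above.
using Flux, Distributions, ConditionalDists

xlength, zlength, batchsize = 3, 2, 10
m = SplitLayer(zlength, [xlength, xlength], [identity, abs])  # mean and positive variance
p = ConditionalMvNormal(m)

target = MvNormal(zeros(xlength), 1.0)     # the Gaussian we want p(x|z) to match
ps  = Flux.params(p)                       # trainable parameters of the SplitLayer
opt = ADAM(1e-3)

for step in 1:1000
    x = Float32.(rand(target, batchsize))  # observations from the target
    z = randn(Float32, zlength, batchsize) # conditioning variables
    gs = Flux.gradient(() -> -sum(logpdf(p, x, z)), ps)
    Flux.update!(opt, ps, gs)
end
```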
src/cond_mvnormal.jl (+49 -9)

@@ -1,18 +1,13 @@
 """
     ConditionalMvNormal(m)
 
-Specialization of ConditionalDistribution for `MvNormal`s for performance.
-Does the same as ConditionalDistribution(MvNormal,m) for vector inputs (to e.g.
-mean/logpdf). For batches of inputs a `BatchMvNormal` is constructed that does
+Specialization of ConditionalDistribution for `MvNormal`s (for performance with
+batches of inputs). Does the same as ConditionalDistribution(MvNormal,m),
+but for batches of inputs a `BatchMvNormal` is constructed that does
 not just map over the batch but uses faster matrix multiplications.
 
-The mapping `m` must return either a `Tuple` with mean and variance, or just a
-mean vector. If the output of `m` is just a vector, the variance is assumed to
-be a fixed unit variance.
-
-# Examples
 ```julia-repl
-julia> m = ConditionalDists.SplitLayer(100,[100,100])
+julia> m = SplitLayer(100,[100,100])
 julia> p = ConditionalMvNormal(m)
 julia> @time rand(p, rand(100,10000));
 julia> @time rand(p, rand(100,10000));
@@ -26,6 +21,51 @@ julia> @time rand(p, rand(100,10000));
   3.626042 seconds (159.97 k allocations: 18.681 GiB, 34.92% gc time)
 ```
 
+The mapping `m` must return a `Tuple` with mean and variance.
+A `SplitLayer` is a convenient way of doing this.
+
+# Examples
+
+`ConditionalMvNormal` and `SplitLayer` together support three different variance
+configurations: fixed/unit variance, shared variance, and trained variance. The
+three configurations are explained below.
+
+## Fixed/unit variance
+
+Pass a function to the `SplitLayer` that returns the fixed variance with
+appropriate batch dimensions:
+```julia-repl
+julia> σ(x::Vector) = 2
+julia> σ(x::Matrix) = ones(Float32,size(x,2)) .* 2
+julia> m = SplitLayer(Dense(2,3), σ)
+julia> p = ConditionalMvNormal(m)
+julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringScalMvNormal
+```
+Passing a mapping with a single output array assumes unit variance.
+
+## Shared variance
+
+For a learned variance that is the same across the whole batch, simply pass
+a vector (or scalar) to the `SplitLayer`. The `SplitLayer` wraps vectors/scalars
+into `TrainableVector`s/`TrainableScalar`s.
+```julia-repl
+julia> m = SplitLayer(Dense(2,3), ones(Float32,3))
+julia> p = ConditionalMvNormal(m)
+julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringDiagMvNormal
+```
+
+## Trained variance
+
+Simply pass another trainable mapping for the variance. By supplying only input
+and output sizes to `SplitLayer` you can automatically create `Dense` layers with
+given activation functions. In this example the second activation function makes
+sure that the variance is always positive:
+```julia-repl
+julia> m = SplitLayer(2,[3,1],[identity,abs])
+julia> p = ConditionalMvNormal(m)
+julia> condition(p,rand(Float32,2)) isa DistributionsAD.TuringScalMvNormal
+```
 """
 struct ConditionalMvNormal{Tm} <: AbstractConditionalDistribution
     mapping::Tm

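To connect the docstring above with the tests further down, here is a hedged usage sketch of the trained-variance configuration on a batch. It relies only on calls that appear elsewhere in this commit (`condition`, `logpdf`, `rand`); the sizes and the expected return shapes in the comments are illustrative.

```julia
using Flux, Distributions, ConditionalDists

m = SplitLayer(2, [3, 1], [identity, abs])  # mean in R^3, scalar variance kept positive by abs
p = ConditionalMvNormal(m)

z = rand(Float32, 2, 10)   # batch of 10 conditioning vectors
x = rand(Float32, 3, 10)   # batch of 10 observations

d = condition(p, z)        # batched distribution (BatchScalMvNormal for scalar variance)
size(mean(d))              # (3, 10)
size(var(d))               # (3, 10)
sum(logpdf(p, x, z))       # scalar loss, differentiable w.r.t. Flux.params(p)
size(rand(p, z))           # (3, 10): one sample per batch element
```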
src/utils.jl (+62 -1)

@@ -1,15 +1,76 @@
+"""
+    SplitLayer(xs...)
+
+A layer that calls a number of sublayers/mappings with the same input and
+returns a tuple of their outputs. Can be used in a regular Flux model:
+
+```julia-repl
+julia> m = Chain(Dense(2,3), SplitLayer(Dense(3,2), x->x .* 2))
+julia> length(params(m)) == 4
+julia> (x,y) = m(rand(2))
+(Float32[-1.0541434, 1.1694773], Float32[-3.1472511, -0.86115724, -0.39665926])
+```
+Comes with a convenient constructor for a SplitLayer built from Dense layers
+with given activation(s):
+```julia-repl
+julia> m = Chain(Dense(2,3), SplitLayer(3, [2,5], σ))
+julia> (x,y) = m(rand(2))
+(Float32[0.3069554, 0.3362006], Float32[0.437131, 0.4982477, 0.6465078, 0.4523438, 0.5068563])
+```
+
+You can also provide just a vector/scalar that should be trained but have the
+same value for all inputs (like a lonely bias vector). This functionality is
+provided by the `TrainableVector`/`TrainableScalar` types. For vector inputs
+they simply return the array they are wrapping. For matrix (i.e. batch) inputs
+they return appropriately repeated arrays:
+```julia-repl
+julia> m = SplitLayer(Dense(2,3), ones(Float32,3))
+julia> length(params(m)) == 3
+julia> (x,y) = m(rand(2,5))
+julia> size(y) == (3,5)
+julia> y
+3×5 Array{Float32,2}:
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+```
+"""
 struct SplitLayer{T<:Tuple}
     layers::T
 end
 
-SplitLayer(xs...) = SplitLayer(xs)
+SplitLayer(layers...) = SplitLayer(map(maybe_trainable, layers))
 
 function (m::SplitLayer)(x)
     Tuple(layer(x) for layer in m.layers)
 end
 
 @functor SplitLayer
 
+
+# for use as e.g. shared variance
+struct TrainableVector{T<:AbstractArray}
+    v::T
+end
+(v::TrainableVector)(x::AbstractVector) = v.v
+(v::TrainableVector)(x::AbstractMatrix) = v.v .* reshape(fillsimilar(v.v,size(x,ndims(x)),1),1,:)
+(v::TrainableVector)() = v.v
+@functor TrainableVector
+
+# for use as e.g. shared variance
+struct TrainableScalar{T<:Real}
+    s::AbstractVector{T}
+    TrainableScalar{T}(x::T) where T<:Real = new{T}([x])
+end
+TrainableScalar(x::T) where T<:Real = TrainableScalar{T}(x)
+(s::TrainableScalar)(x::AbstractVector) = s.s[1]
+(s::TrainableScalar)(x::AbstractMatrix) = fillsimilar(x,size(x,ndims(x)),1) .* s.s
+@functor TrainableScalar
+
+maybe_trainable(x) = x
+maybe_trainable(x::AbstractArray) = TrainableVector(x)
+maybe_trainable(x::Real) = TrainableScalar(x)
+
 fillsimilar(x::AbstractArray, s::Tuple, value::Real) = fill!(similar(x, s...), value)
 fillsimilar(x::AbstractArray, s, value::Real) = fill!(similar(x, s), value)
 @non_differentiable fillsimilar(::Any, ::Any, ::Any)

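To make the `maybe_trainable` wrapping concrete, here is a hedged sketch of how the wrapper types above behave on single inputs versus batches. It is based only on the definitions in this diff; `TrainableVector`/`TrainableScalar` appear to be internal (not exported), hence the `ConditionalDists.` qualification.

```julia
using Flux, ConditionalDists

tv = ConditionalDists.TrainableVector(ones(Float32, 3))
tv(rand(Float32, 3))      # vector input: returns the wrapped 3-element vector as-is
tv(rand(Float32, 3, 5))   # matrix (batch) input: 3×5 matrix, the vector repeated per column

ts = ConditionalDists.TrainableScalar(1f0)
ts(rand(Float32, 3))      # vector input: returns the wrapped scalar
ts(rand(Float32, 3, 5))   # matrix (batch) input: length-5 vector, one value per column

# SplitLayer runs maybe_trainable on its arguments, so plain arrays/scalars
# passed to it become trainable parameters:
m = SplitLayer(Dense(2, 3), ones(Float32, 3))
length(Flux.params(m))    # 3: Dense weight and bias, plus the wrapped vector
```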
test/cond_mvnormal.jl (+76 -107)

@@ -3,115 +3,84 @@
 xlength = 3
 zlength = 2
 batchsize = 10
-m = SplitLayer(zlength, [xlength,xlength], [identity,abs])
-p = ConditionalMvNormal(m) |> gpu
 
-# MvNormal
-res = condition(p, rand(zlength) |> gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa TuringDiagMvNormal
-@test size(μ) == (xlength,)
-@test size(σ2) == (xlength,)
+σvector(x::AbstractVector) = ones(Float32,xlength) .* 3
+σvector(x::AbstractMatrix) = ones(Float32,xlength,size(x,2)) .* 3
+σscalar(x::AbstractVector) = 2
+σscalar(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
 
 x = rand(Float32, xlength) |> gpu
 z = rand(Float32, zlength) |> gpu
-loss() = logpdf(p,x,z)
-ps = Flux.params(p)
-@test_broken loss() isa Float32
-@test_nowarn Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_broken Flux.gradient(f, ps)
-
-# BatchDiagMvNormal
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchDiagMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 4
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
-
-# BatchScalMvNormal
-m = SplitLayer(zlength, [xlength,1])
-p = ConditionalMvNormal(m) |> gpu
-
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchScalMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 4
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
-
-# Unit variance
-m = Dense(zlength,xlength)
-p = ConditionalMvNormal(m) |> gpu
-
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchScalMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 2
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
-
-# Fixed scalar variance
-m = Dense(zlength,xlength)
-σ(x::AbstractVector) = 2
-σ(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
-p = ConditionalMvNormal(SplitLayer(m,σ)) |> gpu
-
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchScalMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 2
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
+X = rand(Float32, xlength, batchsize) |> gpu
+Z = rand(Float32, zlength, batchsize) |> gpu
+
+cases = [
+    ("vector μ / vector σ",
+     SplitLayer(zlength, [xlength,xlength], [identity,abs]), Vector, 4),
+    ("vector μ / scalar σ",
+     SplitLayer(zlength, [xlength,1], [identity,abs]), Real, 4),
+    ("vector μ / fixed vector σ",
+     SplitLayer(Dense(zlength,xlength), σvector), Vector, 2),
+    ("vector μ / fixed scalar σ",
+     SplitLayer(Dense(zlength,xlength), σscalar), Real, 2),
+    ("vector μ / unit σ",
+     Dense(zlength,xlength), Real, 2),
+    ("vector μ / shared, trainable vector σ",
+     SplitLayer(Dense(zlength,xlength), ones(Float32,xlength)), Vector, 3),
+    ("vector μ / shared, trainable scalar σ",
+     SplitLayer(Dense(zlength,xlength), 1f0), Real, 3)
+]
+
+disttypes(::Type{<:Vector}) = (TuringDiagMvNormal, ConditionalDists.BatchDiagMvNormal)
+disttypes(::Type{<:Real}) = (TuringScalMvNormal, ConditionalDists.BatchScalMvNormal)
+σsize(::Type{<:Vector}) = (xlength,)
+σsize(::Type{<:Real}) = ()
+
+for (name,mapping,T,nrps) in cases
+    @testset "$name" begin
+        p = ConditionalMvNormal(mapping) |> gpu
+        (Texample,Tbatch) = disttypes(T)
+
+        res = condition(p,z)
+        μ = mean(res)
+        σ2 = var(res)
+        @test res isa Texample
+        @test size(μ) == (xlength,)
+        @test size(σ2) == σsize(T)
+
+        loss() = logpdf(p,x,z)
+        ps = Flux.params(p)
+        @test length(ps) == nrps
+        @test loss() isa Float32
+        @test_nowarn Flux.gradient(loss, ps)
+
+        f() = sum(rand(p,z))
+        gs = Flux.gradient(f,ps)
+        for p in ps
+            g = gs[p]
+            @test all(isfinite.(g)) && all(g .!= 0)
+        end
+
+        # batch tests
+        res = condition(p,Z)
+        μ = mean(res)
+        σ2 = var(res)
+        @test res isa Tbatch
+        @test size(μ) == (xlength,batchsize)
+        @test size(σ2) == (xlength,batchsize)
+
+        loss() = sum(logpdf(p,X,Z))
+        @test loss() isa Float32
+        @test_nowarn Flux.gradient(loss, ps)
+
+        f() = sum(rand(p,Z))
+        gs = Flux.gradient(f,ps)
+        for p in ps
+            g = gs[p]
+            @test all(isfinite.(g)) && all(g .!= 0)
+        end
+    end
+end
 end
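The expected parameter counts in `cases` (4, 4, 2, 2, 2, 3, 3) follow from which pieces of each mapping are trainable: each `Dense` layer contributes a weight matrix and a bias vector, plain σ functions contribute nothing, and a shared vector or scalar contributes one wrapped array. A hedged sketch of that bookkeeping, using the same names as the tests above:

```julia
using Flux, ConditionalDists

zlength, xlength = 2, 3
σscalar(x::AbstractVector) = 2
σscalar(x::AbstractMatrix) = ones(Float32, size(x, 2)) .* 2

# two Dense layers: (W, b) each, so 4 parameter arrays
length(Flux.params(SplitLayer(zlength, [xlength, xlength], [identity, abs])))    # 4

# one Dense layer plus a non-trainable σ function: (W, b) only, so 2
length(Flux.params(SplitLayer(Dense(zlength, xlength), σscalar)))                # 2

# one Dense layer plus a shared trainable vector or scalar: 2 + 1 = 3
length(Flux.params(SplitLayer(Dense(zlength, xlength), ones(Float32, xlength)))) # 3
length(Flux.params(SplitLayer(Dense(zlength, xlength), 1f0)))                    # 3
```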
