|
@@ -3,115 +3,84 @@
 xlength = 3
 zlength = 2
 batchsize = 10
-m = SplitLayer(zlength, [xlength,xlength], [identity,abs])
-p = ConditionalMvNormal(m) |> gpu
 
-# MvNormal
-res = condition(p, rand(zlength) |> gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa TuringDiagMvNormal
-@test size(μ) == (xlength,)
-@test size(σ2) == (xlength,)
+σvector(x::AbstractVector) = ones(Float32,xlength) .* 3
+σvector(x::AbstractMatrix) = ones(Float32,xlength,size(x,2)) .* 3
+σscalar(x::AbstractVector) = 2
+σscalar(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
 
 x = rand(Float32, xlength) |> gpu
 z = rand(Float32, zlength) |> gpu
-loss() = logpdf(p,x,z)
-ps = Flux.params(p)
-@test_broken loss() isa Float32
-@test_nowarn Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_broken Flux.gradient(f, ps)
-
-# BatchDiagMvNormal
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchDiagMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 4
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
-
-# BatchScalMvNormal
-m = SplitLayer(zlength, [xlength,1])
-p = ConditionalMvNormal(m) |> gpu
-
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchScalMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 4
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
-
-# Unit variance
-m = Dense(zlength,xlength)
-p = ConditionalMvNormal(m) |> gpu
-
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchScalMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 2
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
-
-# Fixed scalar variance
-m = Dense(zlength,xlength)
-σ(x::AbstractVector) = 2
-σ(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
-p = ConditionalMvNormal(SplitLayer(m,σ)) |> gpu
-
-res = condition(p, rand(zlength,batchsize)|>gpu)
-μ = mean(res)
-σ2 = var(res)
-@test res isa ConditionalDists.BatchScalMvNormal
-@test size(μ) == (xlength,batchsize)
-@test size(σ2) == (xlength,batchsize)
-
-x = rand(Float32, xlength, batchsize) |> gpu
-z = rand(Float32, zlength, batchsize) |> gpu
-loss() = sum(logpdf(p,x,z))
-ps = Flux.params(p)
-@test length(ps) == 2
-@test loss() isa Float32
-@test_nowarn gs = Flux.gradient(loss, ps)
-
-f() = sum(rand(p,z))
-@test_nowarn Flux.gradient(f, ps)
-
+X = rand(Float32, xlength, batchsize) |> gpu
+Z = rand(Float32, zlength, batchsize) |> gpu
+
+cases = [
+    ("vector μ / vector σ",
+     SplitLayer(zlength, [xlength,xlength], [identity,abs]), Vector, 4),
+    ("vector μ / scalar σ",
+     SplitLayer(zlength, [xlength,1], [identity,abs]), Real, 4),
+    ("vector μ / fixed vector σ",
+     SplitLayer(Dense(zlength,xlength), σvector), Vector, 2),
+    ("vector μ / fixed scalar σ",
+     SplitLayer(Dense(zlength,xlength), σscalar), Real, 2),
+    ("vector μ / unit σ",
+     Dense(zlength,xlength), Real, 2),
+    ("vector μ / shared, trainable vector σ",
+     SplitLayer(Dense(zlength,xlength), ones(Float32,xlength)), Vector, 3),
+    ("vector μ / shared, trainable scalar σ",
+     SplitLayer(Dense(zlength,xlength), 1f0), Real, 3)
+]
+
+disttypes(::Type{<:Vector}) = (TuringDiagMvNormal,ConditionalDists.BatchDiagMvNormal)
+disttypes(::Type{<:Real}) = (TuringScalMvNormal,ConditionalDists.BatchScalMvNormal)
+σsize(::Type{<:Vector}) = (xlength,)
+σsize(::Type{<:Real}) = ()
+
+
+for (name,mapping,T,nrps) in cases
+    @testset "$name" begin
+        p = ConditionalMvNormal(mapping) |> gpu
+        (Texample,Tbatch) = disttypes(T)
+
+        res = condition(p,z)
+        μ = mean(res)
+        σ2 = var(res)
+        @test res isa Texample
+        @test size(μ) == (xlength,)
+        @test size(σ2) == σsize(T)
+
+        loss() = logpdf(p,x,z)
+        ps = Flux.params(p)
+        @test length(ps) == nrps
+        @test loss() isa Float32
+        @test_nowarn Flux.gradient(loss, ps)
+
+        f() = sum(rand(p,z))
+        gs = Flux.gradient(f,ps)
+        for p in ps
+            g = gs[p]
+            @test all(isfinite.(g)) && all(g .!= 0)
+        end
+
+
+        # batch tests
+        res = condition(p,Z)
+        μ = mean(res)
+        σ2 = var(res)
+        @test res isa Tbatch
+        @test size(μ) == (xlength,batchsize)
+        @test size(σ2) == (xlength,batchsize)
+
+        loss() = sum(logpdf(p,X,Z))
+        @test loss() isa Float32
+        @test_nowarn Flux.gradient(loss, ps)
+
+        f() = sum(rand(p,Z))
+        gs = Flux.gradient(f,ps)
+        for p in ps
+            g = gs[p]
+            @test all(isfinite.(g)) && all(g .!= 0)
+        end
+    end
+end
 end
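
For readers of this change, here is a minimal, self-contained sketch of the API the refactored tests exercise. It is an illustration only, not part of the commit; it assumes `SplitLayer`, `ConditionalMvNormal`, and `condition` are exported by ConditionalDists, and that `TuringDiagMvNormal` comes from DistributionsAD, as the unqualified names in the test bodies above suggest.

    # Sketch: a conditional Gaussian p(x|z) whose mean and elementwise σ are
    # both computed from z by a single SplitLayer, as in the first test case.
    using Flux, Distributions, ConditionalDists
    using DistributionsAD: TuringDiagMvNormal

    zlength, xlength = 2, 3
    m = SplitLayer(zlength, [xlength,xlength], [identity,abs])  # z ↦ (μ(z), σ(z))
    p = ConditionalMvNormal(m)

    z = rand(Float32, zlength)
    d = condition(p, z)       # per the tests above: a TuringDiagMvNormal
    x = rand(d)               # sample x ~ p(x|z)
    ll = logpdf(p, x, z)      # the quantity the test losses evaluate

    # Batched conditioning, as in the "batch tests" half of the loop:
    Z = rand(Float32, zlength, 10)
    D = condition(p, Z)       # a ConditionalDists.BatchDiagMvNormal
    size(mean(D))             # (xlength, batchsize) == (3, 10)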
|