Skip to content

Commit 0e64e9c

Browse files
authored
Export SplitLayer (#41)
also fix constructors and add tests
1 parent 5dce6c6 commit 0e64e9c

File tree

5 files changed

+23
-7
lines changed

5 files changed

+23
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ConditionalDists"
22
uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468"
33
authors = ["Niklas Heim <niklas.heim@aic.fel.cvut.cz>"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ConditionalDists.jl

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export condition
1212

1313
export ConditionalDistribution
1414
export ConditionalMvNormal
15+
export SplitLayer
1516

1617
include("cond_dist.jl")
1718

@@ -25,6 +26,11 @@ function __init__()
2526
function SplitLayer(in::Int, outs::Vector{Int}, acts::Vector)
2627
SplitLayer(Tuple(Dense(in,o,a) for (o,a) in zip(outs,acts)))
2728
end
29+
30+
function SplitLayer(in::Int, outs::Vector{Int}, act=identity)
31+
acts = [act for _ in 1:length(outs)]
32+
SplitLayer(in, outs, acts)
33+
end
2834
end
2935
end
3036

src/utils.jl

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
struct SplitLayer
2-
layers::Tuple
1+
struct SplitLayer{T<:Tuple}
2+
layers::T
33
end
44

5-
function SplitLayer(in::Int, outs::Vector{Int}, act=identity)
6-
acts = [act for _ in 1:length(outs)]
7-
SplitLayer(in, outs, acts)
8-
end
5+
SplitLayer(xs...) = SplitLayer(xs)
96

107
function (m::SplitLayer)(x)
118
Tuple(layer(x) for layer in m.layers)

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using ConditionalDists: BatchMvNormal, SplitLayer
1111

1212
include("cond_dist.jl")
1313
include("cond_mvnormal.jl")
14+
include("utils.jl")
1415

1516
# for the BatchMvNormal tests to work BatchMvNormals have to be functors!
1617
include("batch_mvnormal.jl")

test/utils.jl

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
@testset "SplitLayer" begin
2+
l = SplitLayer(x->x .+ 1, _->1)
3+
x = rand(3)
4+
(a,b) = l(x)
5+
@test all(a .≈ x .+ 1)
6+
@test b == 1
7+
8+
l = SplitLayer(3,[2,4])
9+
(a,b) = l(x)
10+
@test size(a) == (2,)
11+
@test size(b) == (4,)
12+
end

0 commit comments

Comments
 (0)