
Commit 4cae554

rescale_weights!
1 parent 9eeb7cf commit 4cae554

6 files changed, +48 -83 lines changed

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "RestrictedBoltzmannMachines"
 uuid = "12e6b396-7db5-4506-8cb6-664a4fe1e50e"
 authors = ["Jorge Fernandez-de-Cossio-Diaz <j.cossio.diaz@gmail.com>"]
-version = "1.0.0"
+version = "2.0.0-DEV"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/from_grad.jl (-42)

@@ -2,45 +2,3 @@
 
 grad2ave(::Union{Binary,Spin,Potts,Gaussian,ReLU,pReLU,xReLU}, ∂::AbstractArray) = ∂[1, ..]
 grad2ave(::dReLU, ∂::AbstractArray) = ∂[1, ..] + ∂[2, ..]
-
-grad2var(::Union{Binary,Potts}, ∂::AbstractArray) = ∂[1, ..] .* (1 .- ∂[1, ..])
-grad2var(::Spin, ∂::AbstractArray) = (1 .- ∂[1, ..]) .* (1 .+ ∂[1, ..])
-
-function grad2var(l::Union{Gaussian,ReLU}, ∂::AbstractArray)
-    ∂θ = @view ∂[1, ..]
-    ∂γ = @view ∂[2, ..]
-    return -2∂γ .* sign.(l.γ) - ∂θ.^2
-end
-
-function grad2var(l::dReLU, ∂::AbstractArray)
-    ∂θp = ∂[1, ..]
-    ∂θn = ∂[2, ..]
-    ∂γp = ∂[3, ..]
-    ∂γn = ∂[4, ..]
-    return -2 * (∂γp .* sign.(l.γp) + ∂γn .* sign.(l.γn)) - (∂θp + ∂θn).^2
-end
-
-function grad2var(l::pReLU, ∂::AbstractArray)
-    ∂θ = -∂[1, ..]
-    ∂γ = -∂[2, ..]
-    ∂Δ = -∂[3, ..]
-    ∂η = -∂[4, ..]
-
-    abs_γ = abs.(l.γ)
-    ∂absγ = ∂γ .* sign.(l.γ)
-
-    return @. 2l.η/abs_γ * ((2l.Δ * ∂Δ + l.η * ∂η) * l.η - ∂η - l.Δ * ∂θ) + 2∂absγ * (1 + l.η^2) - ∂θ^2
-end
-
-function grad2var(l::xReLU, ∂::AbstractArray)
-    ∂θ = -∂[1, ..]
-    ∂γ = -∂[2, ..]
-    ∂Δ = -∂[3, ..]
-    ∂ξ = -∂[4, ..]
-
-    abs_γ = abs.(l.γ)
-    ∂absγ = ∂γ .* sign.(l.γ)
-
-    ν = @. 2∂absγ - ∂θ^2
-    return @. (ν * abs_γ - 2 * (∂ξ + ∂θ * l.Δ) * l.ξ + ((ν + 2∂absγ) * abs_γ + 4 * ∂Δ * l.Δ) * l.ξ^2 - 4∂ξ * l.ξ^3 + 2abs(l.ξ) * (ν * abs_γ - 3∂ξ * l.ξ - ∂θ * l.Δ * l.ξ)) / (abs_γ * (1 + abs(l.ξ))^2)
-end
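
For context on what is being deleted: `grad2var` converted a cgf gradient into the variance it implies for each unit. For `Binary` units, for instance, the gradient with respect to `θ` is the unit mean `p`, and a Bernoulli variable with mean `p` has variance `p * (1 - p)`, which is exactly the `∂[1, ..] .* (1 .- ∂[1, ..])` removed above. A minimal sketch of that identity, using only functions that remain in the package (the layer size here is arbitrary, not part of this commit):

    using Random: randn!
    using RestrictedBoltzmannMachines: Binary, ∂cgf, grad2ave, mean_from_inputs, var_from_inputs

    layer = Binary((5,))            # 5 binary units
    randn!(layer.θ)                 # random fields
    ∂ = ∂cgf(layer)                 # gradient of the cumulant generating function
    p = grad2ave(layer, ∂)          # ≈ mean_from_inputs(layer)
    p .* (1 .- p) ≈ var_from_inputs(layer)   # what grad2var(layer, ∂) used to compute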

src/gauge/rescale_hidden.jl (+25 -11)

@@ -1,11 +1,9 @@
 """
     rescale_hidden!(rbm, λ::AbstractArray)
 
-For continuous hidden units with a scale parameter,
-scales parameters such that hidden unit activations
-are divided by `λ`. For other hidden units does
-nothing. The resulting RBM is equivalent to the
-original one.
+For continuous hidden units with a scale parameter, scales parameters such that hidden
+unit activations are divided by `λ`. For other hidden units does nothing. The resulting RBM
+is equivalent to the original one.
 """
 function rescale_hidden!(rbm::RBM, λ::AbstractArray)
     @assert size(rbm.hidden) == size(λ)
@@ -15,12 +13,28 @@ function rescale_hidden!(rbm::RBM, λ::AbstractArray)
     return rbm
 end
 
+"""
+    rescale_weights!(rbm)
+
+For continuous hidden units with a scale parameter, scales parameters such that the weights
+attached to each hidden unit have norm 1.
+"""
+function rescale_weights!(rbm::RBM)
+    ω = weight_norms(rbm)
+    λ = inv.(ω)
+    return rescale_hidden!(rbm, λ)
+end
+
+function weight_norms(rbm::RBM)
+    w2 = sum(abs2, rbm.w; dims=1:ndims(rbm.visible))
+    return reshape(sqrt.(w2), size(rbm.hidden))
+end
+
 """
     rescale_activations!(layer, λ::AbstractArray)
 
-For continuous layers with scale parameters, re-parameterizes
-such that unit activations are divided by `λ`, and returns `true`.
-For other layers just returns `false`.
+For continuous layers with scale parameters, re-parameterizes such that unit activations
+are divided by `λ`, and returns `true`. For other layers, does nothing and returns `false`.
 """
 rescale_activations!(layer::Union{Binary,Spin,Potts}, λ::AbstractArray) = false
 
@@ -29,15 +43,15 @@ must have positive activations. So we disallow it below. =#
 
 function rescale_activations!(layer::Union{Gaussian,ReLU}, λ::AbstractArray)
     @assert size(layer) == size(λ)
-    @assert all(λ .> 0)
+    @assert all(>(0), λ)
     layer.θ .*= λ
     layer.γ .*= λ.^2
     return true
 end
 
 function rescale_activations!(layer::dReLU, λ::AbstractArray)
     @assert size(layer) == size(λ)
-    @assert all(λ .> 0)
+    @assert all(>(0), λ)
     layer.θp .*= λ
     layer.θn .*= λ
     layer.γp .*= λ.^2
@@ -47,7 +61,7 @@ end
 
 function rescale_activations!(layer::Union{pReLU,xReLU}, λ::AbstractArray)
     @assert size(layer) == size(λ)
-    @assert all(λ .> 0) # makes life simpler
+    @assert all(>(0), λ) # it's just simpler
     layer.θ .*= λ
     layer.Δ .*= λ
     layer.γ .*= λ.^2
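
The new `rescale_weights!` uses the same gauge freedom as `rescale_hidden!`: since `λ = inv.(ω)`, each hidden unit's activations are multiplied by its weight norm while the attached weights end up with norm 1, leaving the model distribution unchanged up to a constant shift of the free energy. A short usage sketch, mirroring the test added in `test/gauge/rescale_hidden.jl` below (sizes and sample counts are arbitrary):

    using Random: bitrand, rand!, randn!
    using RestrictedBoltzmannMachines: RBM, Binary, ReLU, free_energy, rescale_weights!, weight_norms

    rbm = RBM(Binary((2,)), ReLU((1,)), randn(2, 1))
    randn!(rbm.visible.θ)
    randn!(rbm.hidden.θ)
    rand!(rbm.hidden.γ)
    rbm.hidden.γ .+= 0.5                      # keep the scale parameter positive

    v = bitrand(size(rbm.visible)..., 100)    # some visible configurations
    F = free_energy(rbm, v)
    ω = weight_norms(rbm)                     # norm of the weights attached to each hidden unit
    rescale_weights!(rbm)                     # afterwards weight_norms(rbm) ≈ [1.0]
    free_energy(rbm, v) ≈ F .- sum(log, ω)    # same model, free energies shifted by a constant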

src/train/pcd.jl (+3 -18)

@@ -21,11 +21,7 @@ function pcd!(
 
     # gauge
     zerosum::Bool = true, # zerosum gauge for Potts layers
-    rescale::Bool = true, # normalize continuous hidden units to var(h) = 1
-
-    # momentum for hidden activity statistics tracking
-    ρh::Real = 99//100,
-    ϵh::Real = 1//100, # prevent vanishing var(h) estimate
+    rescale::Bool = true, # normalize weights to unit norm (for continuous hidden units only)
 
     callback = Returns(nothing), # called for every batch
 
@@ -36,15 +32,10 @@ function pcd!(
 )
     @assert size(data) == (size(rbm.visible)..., size(data)[end])
     @assert isnothing(wts) || size(data)[end] == length(wts)
-    @assert ϵh ≥ 0
-
-    # used to scale hidden unit activities
-    var_h = total_var_from_inputs(rbm.hidden, inputs_h_from_v(rbm, data); wts)
-    @assert all(var_h .+ ϵh .> 0)
 
     # gauge constraints
     zerosum && zerosum!(rbm)
-    rescale && rescale_hidden!(rbm, sqrt.(var_h .+ ϵh))
+    rescale && rescale_weights!(rbm)
 
     # store average weight of each data point
     wts_mean = isnothing(wts) ? 1 : mean(wts)
@@ -73,14 +64,8 @@
         batch_weight = isnothing(wts) ? 1 : mean(wd) / wts_mean
         ∂ *= batch_weight
 
-        # Exponential moving average of variance of hidden unit activations.
-        ρh_eff = ρh ^ batch_weight # effective damp after 'batch_weight' updates
-        var_h_batch = grad2var(rbm.hidden, -∂d.hidden) # extract hidden unit statistics from gradient
-        var_h .= ρh_eff * var_h .+ (1 - ρh_eff) * var_h_batch
-        @assert all(var_h .+ ϵh .> 0)
-
         # reset gauge
-        rescale && rescale_hidden!(rbm, sqrt.(var_h .+ ϵh))
+        rescale && rescale_weights!(rbm)
         zerosum && zerosum!(rbm)
 
         callback(; rbm, optim, iter, vm, vd, wd)
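
With this change, `pcd!` no longer tracks a running estimate of `var(h)`: when `rescale = true` it simply calls `rescale_weights!` before training and again after every update, so the per-hidden-unit weight norms stay at 1 throughout. A hedged sketch of monitoring that through the `callback` keyword shown above, assuming the remaining keyword arguments of `pcd!` (optimiser, batch size, number of updates) can be left at their defaults:

    using Random: bitrand
    using RestrictedBoltzmannMachines: RBM, Binary, ReLU, pcd!, weight_norms

    rbm = RBM(Binary((2,)), ReLU((1,)), randn(2, 1))
    data = bitrand(size(rbm.visible)..., 1000)           # toy binary data

    norms = Vector{Float64}[]                            # weight norms recorded at each batch
    track(; rbm, kw...) = push!(norms, vec(weight_norms(rbm)))

    pcd!(rbm, data; rescale = true, zerosum = true, callback = track)
    all(ns -> all(≈(1), ns), norms)                      # gauge is reset after every update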

test/gauge/rescale_hidden.jl (+18 -1)

@@ -2,9 +2,10 @@ import Random
 using Test: @test, @testset, @inferred
 using Statistics: mean, var
 using Random: bitrand, rand!, randn!
+using LinearAlgebra: norm
 using RestrictedBoltzmannMachines: RBM, Binary, free_energy, Gaussian, ReLU, dReLU, pReLU, xReLU,
     sample_v_from_v, sample_h_from_h, mean_from_inputs, var_from_inputs,
-    rescale_hidden!, rescale_activations!
+    rescale_hidden!, rescale_activations!, rescale_weights!, weight_norms
 
 Random.seed!(23)
 
@@ -55,3 +56,19 @@ end
     @test var(v; dims=2) ≈ var_v rtol=0.1
     @test var(h; dims=2) ≈ var_h ./ λ.^2 rtol=0.1
 end
+
+@testset "rescale_weights!" begin
+    rbm = RBM(Binary((2,)), ReLU((1,)), randn(2,1))
+    randn!(rbm.visible.θ)
+    randn!(rbm.hidden.θ)
+    rand!(rbm.hidden.γ)
+    rbm.hidden.γ .+= 0.5
+
+    v = sample_v_from_v(rbm, bitrand(size(rbm.visible)..., 1000); steps=100)
+    F = free_energy(rbm, v)
+
+    ω = @inferred weight_norms(rbm)
+    @test ω ≈ [norm(rbm.w)]
+    @inferred rescale_weights!(rbm)
+    @test free_energy(rbm, v) ≈ F .- sum(log, ω)
+end

test/layers.jl (+1 -10)

@@ -8,7 +8,7 @@ using LogExpFunctions: logistic
 using EllipsisNotation: (..)
 using QuadGK: quadgk
 using RestrictedBoltzmannMachines: RBM, Binary, Spin, Potts, Gaussian, ReLU, dReLU, xReLU, pReLU,
-    flatten, batch_size, batchmean, batchvar, batchcov, grad2ave, grad2var, drelu_energy,
+    flatten, batch_size, batchmean, batchvar, batchcov, grad2ave, drelu_energy,
     mean_from_inputs, var_from_inputs, meanvar_from_inputs, batchdims, gauss_energy, relu_energy,
     std_from_inputs, mean_abs_from_inputs, sample_from_inputs, mode_from_inputs,
     energy, cgf, free_energy, cgfs, energies, ∂cgf, vstack, ∂energy, ∂free_energy, binary_rand,
@@ -156,7 +156,6 @@ end
     ∂ = ∂cgf(layer)
     @test only(gs).par ≈ vstack((mean_from_inputs(layer),))
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "Spin" begin
@@ -170,7 +169,6 @@ end
     ∂ = ∂cgf(layer)
     @test only(gs).par ≈ vstack((mean_from_inputs(layer),))
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "Potts" begin
@@ -189,7 +187,6 @@ end
     ∂ = ∂cgf(layer)
     @test only(gs).par ≈ vstack((mean_from_inputs(layer),))
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "Gaussian" begin
@@ -220,7 +217,6 @@ end
     @test ∂[1, ..] ≈ μ
     @test ∂[2, ..] ≈ -sign.(layer.γ) .* μ2/2
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "ReLU" begin
@@ -251,7 +247,6 @@ end
     @test ∂[1, ..] ≈ μ
     @test ∂[2, ..] ≈ -sign.(layer.γ) .* μ2/2
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "pReLU / xReLU / dReLU convert" begin
@@ -372,7 +367,6 @@ end
     ∂ = @inferred ∂cgf(layer)
     @test only(gs).par ≈ ∂
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 
     # check law of total variance
     inputs = randn(size(layer)..., 1000)
@@ -384,7 +378,6 @@ end
     ν_ext = batchvar(layer, h_ave; mean = μ)
     ν = ν_int + ν_ext # law of total variance
     @test grad2ave(layer, ∂) ≈ μ
-    @test grad2var(layer, ∂) ≈ ν
     μ1, ν1 = total_meanvar_from_inputs(layer, inputs)
     @test μ1 ≈ μ ≈ total_mean_from_inputs(layer, inputs)
     @test ν1 ≈ ν ≈ total_var_from_inputs(layer, inputs)
@@ -399,7 +392,6 @@ end
     ∂ = ∂cgf(layer)
     @test only(gs).par ≈ ∂
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "xReLU" begin
@@ -411,7 +403,6 @@ end
     ∂ = ∂cgf(layer)
     @test only(gs).par ≈ ∂
     @test grad2ave(layer, ∂) ≈ mean_from_inputs(layer)
-    @test grad2var(layer, ∂) ≈ var_from_inputs(layer)
 end
 
 @testset "grad2ave $Layer" for Layer in _layers
