Commit: tests

CarloLucibello committed Feb 11, 2025
1 parent a9c5ef6 commit 149cc75
Showing 6 changed files with 57 additions and 32 deletions.
10 changes: 0 additions & 10 deletions Project.toml

```diff
@@ -26,14 +26,4 @@ ConstructionBase = "1.5.8"
 EnzymeCore = "0.8.5"
 Functors = "0.4.9, 0.5"
 Statistics = "1"
-Zygote = "0.6.40, 0.7.1"
 julia = "1.10"
-
-[extras]
-EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
-StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[targets]
-test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
```
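For context (my reading of the hunk counts, not stated explicitly in the diff): the removed `[extras]`/`[targets]` sections and the `Zygote` compat entry move to the dedicated `test/Project.toml` added below, which `Pkg.test()` picks up automatically on Julia ≥ 1.2.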
6 changes: 3 additions & 3 deletions src/interface.jl

```diff
@@ -273,7 +273,7 @@ macro def(expr)
   # Positional-argument method, has defaults for all but the first arg:
   positional = :(function $rule($(names[1]), $(params[2:end]...))
     $check_sign_eta
-    vars = maybe_float.([$(names...)])
+    vars = $(maybe_float).(($(names...)),($(default_types...)))
     return new{typeof.(vars)...}(vars...)
   end)
   # Keyword-argument method. (Made an inner constructor only to allow
@@ -283,5 +283,5 @@ macro def(expr)
   return esc(expr)
 end
 
-maybe_float(x::Number) = float(x)
-maybe_float(x) = x
+maybe_float(x, T::Type{<:AbstractFloat}) = float(x)
+maybe_float(x, T) = x
```
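To make the new two-argument `maybe_float` concrete, here is a small sketch (my illustration, not part of the commit): a value is floated only when the matching default's type is an `AbstractFloat`, so non-float fields pass through untouched.

```julia
maybe_float(x, T::Type{<:AbstractFloat}) = float(x)  # default was a float: float the value
maybe_float(x, T) = x                                # anything else passes through unchanged

maybe_float(2, Float64)       # 2.0     -- Int floated to match a Float64 default
maybe_float(1f-3, Float64)    # 0.001f0 -- float() preserves Float32 inputs
maybe_float(3, Int)           # 3       -- Int default, left alone
maybe_float((1.0, 2.0), Tuple{Float64, Float64})  # (1.0, 2.0) -- tuples hit the fallback
```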
38 changes: 19 additions & 19 deletions src/rules.jl

```diff
@@ -114,16 +114,16 @@ gradients by an estimate of their variance, instead of their second moment.
 - Keyword `centred` (or `centered`): Indicates whether to use the centred variant
   of the algorithm.
 """
-struct RMSProp <: AbstractRule
-  eta::Float64
-  rho::Float64
-  epsilon::Float64
+struct RMSProp{Teta,Trho,Teps} <: AbstractRule
+  eta::Teta
+  rho::Trho
+  epsilon::Teps
   centred::Bool
 end
 
 function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  RMSProp(η, ρ, ϵ, centred | centered)
+  return RMSProp(float(η), float(ρ), float(ϵ), centred | centered)
 end
 RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
 
```
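A quick sketch of what the parametric struct buys (my example, assuming the constructor above): hyperparameter types now follow the inputs instead of being forced to `Float64`.

```julia
using Optimisers

o = RMSProp(1f-3)       # Float32 learning rate survives, since float() preserves it
typeof(o.eta)           # Float32
typeof(o.rho)           # Float64, because the default 0.9 is a Float64 literal

typeof(RMSProp(1e-3))   # RMSProp{Float64, Float64, Float64}
```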

```diff
@@ -155,7 +155,7 @@ end
 
 
 """
-    Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
+    Rprop(η = 1e-3, ℓ = (0.5, 1.2), Γ = (1e-6, 50.0))
     Rprop(; [eta, ell, gamma])
 
 Optimizer using the
@@ -171,9 +171,9 @@ learning algorithm that depends only on the sign of the gradient.
 - Step sizes (`Γ::Tuple == gamma`): Minimal and maximal allowed step sizes.
 """
 @def struct Rprop <: AbstractRule
-  eta = 1f-3
-  ell = (5f-1, 1.2f0)
-  gamma = (1f-6, 50f0)
+  eta = 1e-3
+  ell = (0.5, 1.2)
+  gamma = (1e-6, 50.0)
 end
 
 init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
```
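Note the defaults move from `Float32` literals to plain `Float64` ones; under the `@def`-generated constructor sketched earlier, anyone wanting `Float32` hyperparameters would now pass them explicitly (a hedged sketch, not from the diff):

```julia
using Optimisers

Rprop()        # eta = 0.001 (Float64) after this change
Rprop(1f-3)    # eta stays Float32: maybe_float floats it against the Float64 default type
Rprop(1f-3, (5f-1, 1.2f0), (1f-6, 50f0))  # tuples hit the fallback, pass through as Float32
```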
```diff
@@ -528,17 +528,17 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref).
 The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
 See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
 """
-struct AdamW <: AbstractRule
-  eta::Float64
-  beta::Tuple{Float64, Float64}
-  lambda::Float64
-  epsilon::Float64
+struct AdamW{Teta,Tbeta<:Tuple,Tlambda,Teps} <: AbstractRule
+  eta::Teta
+  beta::Tbeta
+  lambda::Tlambda
+  epsilon::Teps
   couple::Bool
 end
 
 function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  AdamW(η, β, λ, ϵ, couple)
+  return AdamW(float(η), β, float(λ), float(ϵ), couple)
 end
 
 AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0.0, epsilon = 1e-8, kw...) =
```
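As with `RMSProp`, a short illustration (mine, not from the diff) of the now-parametric `AdamW`:

```julia
using Optimisers

o = AdamW(1f-3)
typeof(o.eta)   # Float32 -- no longer widened to Float64
o.beta          # (0.9, 0.999), stored as the tuple that was passed, un-floated
typeof(o)       # AdamW{Float32, Tuple{Float64, Float64}, Float64, Float64}
```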
```diff
@@ -704,12 +704,12 @@ Typically composed with other rules using [`OptimiserChain`](@ref).
 See also [`ClipGrad`](@ref).
 """
-struct ClipNorm <: AbstractRule
-  omega::Float64
-  p::Float64
+struct ClipNorm{To,Tp} <: AbstractRule
+  omega::To
+  p::Tp
   throw::Bool
 end
-ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(ω, p, throw)
+ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(float(ω), float(p), throw)
 
 init(o::ClipNorm, x::AbstractArray) = nothing
```
8 changes: 8 additions & 0 deletions test/Project.toml

```diff
@@ -0,0 +1,8 @@
+[deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
```
24 changes: 24 additions & 0 deletions test/interface.jl

```diff
@@ -0,0 +1,24 @@
+@testset "@def" begin
+    Optimisers.@def struct DummyRule
+        a = 1
+        b1 = 1.5
+        b2 = 2.5f0
+        c = (1.0, 2.0)
+    end
+
+    # no args
+    r = DummyRule()
+    @test typeof(r.a) == Int
+    @test typeof(r.b1) == Float64
+    @test typeof(r.b2) == Float32
+    @test typeof(r.c) == Tuple{Float64, Float64}
+
+    # some positional args
+    r = DummyRule(2, 2, 4.5)
+    @test r.a == 2
+    @test r.b1 == 2
+    @test r.b2 == 4.5
+    @test typeof(r.b1) == Float64 # int promoted to float
+    @test typeof(r.b2) == Float64 # Float64 not converted to Float32
+    @test r.c == (1.0, 2.0)
+end
```
3 changes: 3 additions & 0 deletions test/runtests.jl

```diff
@@ -557,4 +557,7 @@ end
   @testset verbose=true "Optimisation Rules" begin
     include("rules.jl")
   end
+  @testset verbose=true "interface" begin
+    include("interface.jl")
+  end
 end
```
