From 149cc757d86f6964876631461afd184f31f8f8b6 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 11 Feb 2025 16:44:49 +0100
Subject: [PATCH] tests

---
 Project.toml      | 10 ----------
 src/interface.jl  |  6 +++---
 src/rules.jl      | 38 +++++++++++++++++++-------------------
 test/Project.toml |  8 ++++++++
 test/interface.jl | 24 ++++++++++++++++++++++++
 test/runtests.jl  |  3 +++
 6 files changed, 57 insertions(+), 32 deletions(-)
 create mode 100644 test/Project.toml
 create mode 100644 test/interface.jl

diff --git a/Project.toml b/Project.toml
index ea28243..e5e89f0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -26,14 +26,4 @@ ConstructionBase = "1.5.8"
 EnzymeCore = "0.8.5"
 Functors = "0.4.9, 0.5"
 Statistics = "1"
-Zygote = "0.6.40, 0.7.1"
 julia = "1.10"
-
-[extras]
-EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
-StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[targets]
-test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
diff --git a/src/interface.jl b/src/interface.jl
index d85299a..8f7d700 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -273,7 +273,7 @@ macro def(expr)
   # Positional-argument method, has defaults for all but the first arg:
   positional = :(function $rule($(names[1]), $(params[2:end]...))
     $check_sign_eta
-    vars = maybe_float.([$(names...)])
+    vars = $(maybe_float).(($(names...)),($(default_types...)))
     return new{typeof.(vars)...}(vars...)
   end)
   # Keyword-argument method. (Made an inner constructor only to allow
@@ -283,5 +283,5 @@ macro def(expr)
   return esc(expr)
 end
 
-maybe_float(x::Number) = float(x)
-maybe_float(x) = x
+maybe_float(x, T::Type{<:AbstractFloat}) = float(x)
+maybe_float(x, T) = x
diff --git a/src/rules.jl b/src/rules.jl
index 0cd8d30..408bd4b 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -114,16 +114,16 @@ gradients by an estimate of their variance, instead of their second moment.
 - Keyword `centred` (or `centered`): Indicates whether to use centred variant
   of the algorithm.
 """
-struct RMSProp <: AbstractRule
-  eta::Float64
-  rho::Float64
-  epsilon::Float64
+struct RMSProp{Teta,Trho,Teps} <: AbstractRule
+  eta::Teta
+  rho::Trho
+  epsilon::Teps
   centred::Bool
 end
 
 function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  RMSProp(η, ρ, ϵ, centred | centered)
+  return RMSProp(float(η), float(ρ), float(ϵ), centred | centered)
 end
 
 RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
@@ -155,7 +155,7 @@
 end
 
 """
-    Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
+    Rprop(η = 1e-3, ℓ = (0.5, 1.2), Γ = (1e-6, 50.0))
     Rprop(; [eta, ell, gamma])
 
 Optimizer using the
@@ -171,9 +171,9 @@ learning algorithm that depends only on the sign of the gradient.
 - Step sizes (`Γ::Tuple == gamma`): Minimal and maximal allowed step sizes.
 """
 @def struct Rprop <: AbstractRule
-  eta = 1f-3
-  ell = (5f-1, 1.2f0)
-  gamma = (1f-6, 50f0)
+  eta = 1e-3
+  ell = (0.5, 1.2)
+  gamma = (1e-6, 50.0)
 end
 
 init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
@@ -528,17 +528,17 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
 The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
 See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
 """
-struct AdamW <: AbstractRule
-  eta::Float64
-  beta::Tuple{Float64, Float64}
-  lambda::Float64
-  epsilon::Float64
+struct AdamW{Teta,Tbeta<:Tuple,Tlambda,Teps} <: AbstractRule
+  eta::Teta
+  beta::Tbeta
+  lambda::Tlambda
+  epsilon::Teps
   couple::Bool
 end
 
 function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
-  AdamW(η, β, λ, ϵ, couple)
+  return AdamW(float(η), β, float(λ), float(ϵ), couple)
 end
 
 AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda= 0.0, epsilon = 1e-8, kw...) =
@@ -704,12 +704,12 @@ Typically composed with other rules using [`OptimiserChain`](@ref).
 
 See also [`ClipGrad`](@ref).
 """
-struct ClipNorm <: AbstractRule
-  omega::Float64
-  p::Float64
+struct ClipNorm{To,Tp} <: AbstractRule
+  omega::To
+  p::Tp
   throw::Bool
 end
-ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(ω, p, throw)
+ClipNorm(ω = 10, p = 2; throw::Bool = true) = ClipNorm(float(ω), float(p), throw)
 
 init(o::ClipNorm, x::AbstractArray) = nothing
 
diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 0000000..a3acc09
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,8 @@
+[deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/interface.jl b/test/interface.jl
new file mode 100644
index 0000000..e94d2b4
--- /dev/null
+++ b/test/interface.jl
@@ -0,0 +1,24 @@
+@testset "@def" begin
+    Optimisers.@def struct DummyRule
+        a = 1
+        b1 = 1.5
+        b2 = 2.5f0
+        c = (1.0, 2.0)
+    end
+
+    # no args
+    r = DummyRule()
+    @test typeof(r.a) == Int
+    @test typeof(r.b1) == Float64
+    @test typeof(r.b2) == Float32
+    @test typeof(r.c) == Tuple{Float64, Float64}
+
+    # some positional args
+    r = DummyRule(2, 2, 4.5)
+    @test r.a == 2
+    @test r.b1 == 2
+    @test r.b2 == 4.5
+    @test typeof(r.b1) == Float64 # int promoted to float
+    @test typeof(r.b2) == Float64 # Float64 not converted to Float32
+    @test r.c == (1.0, 2.0)
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 5b2da55..8cdcffc 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -557,4 +557,7 @@ end
   @testset verbose=true "Optimisation Rules" begin
     include("rules.jl")
   end
+  @testset verbose=true "interface" begin
+    include("interface.jl")
+  end
 end