
Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) #196

Open · wants to merge 14 commits into base: master
4 changes: 3 additions & 1 deletion src/Optimisers.jl
@@ -5,6 +5,8 @@ using Functors: functor, fmap, fmap_with_path,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra

using Random: randn!

include("interface.jl")
export AbstractRule

@@ -23,7 +25,7 @@ include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad
AccumGrad, Apollo, NormGrowthCap

VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

143 changes: 143 additions & 0 deletions src/rules.jl
@@ -599,6 +599,149 @@
return (mt, st, βt .* β), dx′
end

"""
NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true)

Gradient norm growth limiter. `τ` controls the maximum that the gradient norm can grow from one step to the next, such that
if `||dx||/||dx_prev|| > τ` & `||dx|| > lb`, then `dx = dx * τ*||dx_prev||/(||dx||+ϵ)`
Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
with Optimisers.jl this will apply per-tensor instead of per-model. This implementation also introduces `lb` as a hard minimum on the gradient norm threshold,
and never rescales grads below this, preventing a tensor from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the
number of parameters in the tensor (with `scale = true`).
"""
struct NormGrowthCap <: AbstractRule
tau::Float64
epsilon::Float64
lb::Float64 #Min grad norm, to stop a tensor getting stuck near zero
throw::Bool
scale::Bool
end

NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true) = NormGrowthCap(τ, ϵ, lb, throw, scale)

init(o::NormGrowthCap, x::AbstractArray{T}) where T = T(0)

function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
current_norm = _norm(dx, 2)
if o.throw && !isfinite(current_norm)
throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
end
if state == 0
return (current_norm), dx
else
#If you're below the hard min, then don't scale
if o.scale
minthresh = o.lb * sqrt(length(dx))
else
minthresh = o.lb
end
if current_norm < minthresh
return current_norm, dx
end
ratio = current_norm / (state + o.epsilon)
if ratio > o.tau
lambda = T((o.tau * state) / (current_norm + o.epsilon))
return current_norm * lambda, dx * lambda
else
return current_norm, dx
end
end
end
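
A minimal usage sketch, not part of this diff, assuming only the exported names added above; it shows the intended pattern of chaining `NormGrowthCap` in front of another rule, as the new test below also does:

```julia
using Optimisers

# Allow each tensor's gradient norm to grow by at most 10% per step,
# then hand the (possibly rescaled) gradient to Adam.
rule = OptimiserChain(NormGrowthCap(1.1), Adam(1e-3))

x = randn(Float32, 10)
st = Optimisers.setup(rule, x)
st, x = Optimisers.update(st, x, randn(Float32, 10))
```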

nonfirstdims(x) = prod(size(x)[2:end]) #Product of all dims after the first, i.e. the column count when x is reshaped to a matrix

"""
Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true)
Apollo(η::Real, args...; kw...)
Apollo(arg, rank::Int; kw...)
Apollo(η::Real, rank::Int; kw...)

Apollo optimizer from Zhu et al. (https://arxiv.org/abs/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage.
First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function
to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
"""
struct Apollo{T1, T2} <: AbstractRule
opt::T1
r::T2 #Maps non-first dims to rank
u::Int #Subspace update frequency (T in paper)
sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first
end

function adjust(r::Apollo; kw...)
if (:u in keys(kw)) || (:r in keys(kw)) || (:sort_dims in keys(kw))
@error "Apollo does not support adjusting: u, r, sort_dims"

Check warning on line 672 in src/rules.jl

View check run for this annotation

Codecov / codecov/patch

src/rules.jl#L670-L672

Added lines #L670 - L672 were not covered by tests
end
return Apollo(_adjust(r.opt, NamedTuple(kw)), r.r, r.u, r.sort_dims)
end
adjust(r::Apollo, η::Real) = Apollo(adjust(r.opt, η), r.r, r.u, r.sort_dims)
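
A short sketch (illustrative only, after `using Optimisers`) of what `adjust` supports for `Apollo`: hyperparameters of the wrapped `AdamW` can be changed, while `u`, `r` and `sort_dims` are fixed at construction and attempting to adjust them logs an error.

```julia
st = Optimisers.setup(Apollo(1e-3), randn(Float32, 8, 8))
st = Optimisers.adjust(st, 1e-4)           # new learning rate for the inner AdamW
st = Optimisers.adjust(st; lambda = 1e-2)  # other AdamW fields are forwarded too
# Adjusting u, r or sort_dims is rejected with an error message (see above).
```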

Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true) = Apollo(opt, r, u, sort_dims)
Apollo(η::Real, args...; kw...) = Apollo(AdamW(η), args...; kw...)
Apollo(arg, rank::Int; kw...) = Apollo(arg, dim -> min(dim, rank); kw...)
Apollo(η::Real, rank::Int; kw...) = Apollo(AdamW(η), rank; kw...)
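
For illustration (not part of the diff, after `using Optimisers`), the constructor forms above allow, for example:

```julia
Apollo()                       # default AdamW(), rank(dim) = ceil(Int, sqrt(dim)), resample every 100 steps
Apollo(1e-3)                   # AdamW with learning rate 1e-3, default rank function
Apollo(1e-3, 64)               # fixed rank 64, capped at the relevant dimension by min(dim, rank)
Apollo(AdamW(1e-3), 32; u = 200, sort_dims = false)
```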


#Use the base init and apply for 1D arrays
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx)

function init(o::Apollo, x::AbstractArray{T}) where T
first_dim, second_dim = size(x,1), nonfirstdims(x)
if o.sort_dims && second_dim < first_dim
first_dim, second_dim = second_dim, first_dim
end
rank = o.r(second_dim)
P = similar(x, rank, first_dim)
randn!(P)
P .*= T(sqrt(1/rank))
((similar(x, rank, second_dim) .= 0, similar(x, rank, second_dim) .= 0, o.opt.beta), 1, P)
end


function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
swapped = false
original_size = size(x)
x = reshape(x, size(x,1), nonfirstdims(x))

dx = Broadcast.materialize(dx) #This is to stop the "gradient type" @lazy test from failing due to reshape.
dx = reshape(dx, size(x,1), nonfirstdims(x))
Comment on lines +707 to +708

Member: Do you need to materialize in the matrix case?

Author: For everything except whatever comes in during the "gradient type" test, you don't need `materialize`. I wasn't 100% sure exactly what comes in during those tests, so I wasn't sure how to separate them from regular matrices/tensors. What do you suggest here?
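
One way to make that distinction explicit, sketched here as a suggestion only (the helper name `_maybe_materialize` is made up, and as far as I can tell `Broadcast.materialize` already passes plain arrays through unchanged, so the unconditional call in the diff is harmless):

```julia
# Hypothetical helper (not in this PR): lazy Broadcasted gradients are
# materialized before the reshape; ordinary arrays pass through unchanged.
_maybe_materialize(dx::Broadcast.Broadcasted) = Broadcast.materialize(dx)
_maybe_materialize(dx) = dx
```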


first_dim, second_dim = size(x,1), size(x,2)
if o.sort_dims && second_dim < first_dim
first_dim, second_dim = second_dim, first_dim
x = x'
dx = dx'
swapped = true
end
(mt, vt, βt), t, P = state
η = T(o.opt.eta) #This is what will get modified by adjust
λ = T(o.opt.lambda)
β = T.(o.opt.beta)
ϵ = T(o.opt.epsilon)
βt = T.(βt)
if mod(t, o.u) == 0
rank = o.r(second_dim)
randn!(P)
P .*= T(sqrt(1/rank))
end
R = P * dx
@.. mt = β[1] * mt + (1 - β[1]) * R
@.. vt = β[2] * vt + (1 - β[2]) * abs2(R)
Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)

R2sum = sum(abs2, R; dims=1)
Rhat2sum = sum(abs2, Rhat; dims=1)
s = @. sqrt(Rhat2sum) / (sqrt(R2sum) + ϵ)
dx′′ = η * (dx .* s) + λ * x

if swapped
dx′′ = transpose(dx′′)
end
return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size)
end
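
A minimal end-to-end sketch of the new rules with the usual `setup`/`update!` API (illustrative only; the "model" here is a single weight matrix and the gradient is random noise):

```julia
using Optimisers

W   = randn(Float32, 128, 64)                           # one weight tensor as the model
opt = OptimiserChain(NormGrowthCap(1.1), Apollo(1e-3))  # cap norm growth, then Apollo
st  = Optimisers.setup(opt, W)

dW = randn(Float32, 128, 64)                            # stand-in for a real gradient
st, W = Optimisers.update!(st, W, dW)
```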


"""
WeightDecay(λ = 5e-4)
WeightDecay(; [lambda])
3 changes: 2 additions & 1 deletion test/rules.jl
@@ -8,12 +8,13 @@ RULES = [
# All the rules at default settings:
Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), Apollo(),
# A few chained combinations:
OptimiserChain(SignDecay(0.001), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
OptimiserChain(ClipGrad(0.5), Momentum()),
OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
OptimiserChain(NormGrowthCap(1.1), Apollo()),
# Not the default:
RMSProp(centred = true), AdamW(couple=false),
]