Skip to content

Commit 53f75c0

Browse files
authored
Merge pull request #15 from nossleinad/main
Add Brent's method for minimization
2 parents 6930e8d + 3abdcf3 commit 53f75c0

File tree

5 files changed

+156
-9
lines changed

5 files changed

+156
-9
lines changed

src/MolecularEvolution.jl

+6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ abstract type SimulationModel <: BranchModel end #Simulation models typically ca
2929

3030
abstract type StatePath end
3131

32+
abstract type UnivariateOpt end
33+
3234
#include("core/core.jl")
3335
include("core/nodes/nodes.jl")
3436
include("core/algorithms/algorithms.jl")
@@ -116,6 +118,10 @@ export
116118
one_hot_sample,
117119
scaled_prob_domain,
118120
golden_section_maximize,
121+
GoldenSectionOpt,
122+
brents_method_minimize,
123+
BrentsMethodOpt,
124+
univariate_maximize,
119125
unit_transform,
120126
HKY85,
121127
P_from_diagonalized_Q,

src/core/algorithms/branchlength_optim.jl

+13-9
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ function branchlength_optim!(
2727
node::FelNode,
2828
models,
2929
partition_list,
30-
tol,
30+
tol;
31+
bl_optimizer::UnivariateOpt = GoldenSectionOpt()
3132
)
3233

3334
#This bit of code should be identical to the regular downward pass...
@@ -60,6 +61,7 @@ function branchlength_optim!(
6061
models,
6162
partition_list,
6263
tol,
64+
bl_optimizer=bl_optimizer
6365
)
6466
end
6567
#Then combine node.child_messages into node.message...
@@ -72,7 +74,7 @@ function branchlength_optim!(
7274
if !isroot(node)
7375
model_list = models(node)
7476
fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
75-
opt = golden_section_maximize(fun, 0 + tol, 1 - tol, unit_transform, tol)
77+
opt = univariate_maximize(fun, 0 + tol, 1 - tol, unit_transform, bl_optimizer, tol)
7678
if fun(opt) > fun(node.branchlength)
7779
node.branchlength = opt
7880
end
@@ -88,24 +90,24 @@ end
8890

8991
#BM: Check if running felsenstein_down! makes a difference.
9092
"""
91-
branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5)
93+
branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt())
9294
93-
Uses golden section search to optimize all branches recursively, maintaining the integrity of the messages.
95+
Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages.
9496
Requires felsenstein!() to have been run first.
9597
models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or
9698
a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another.
9799
partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models).
98-
tol is the tolerance for the golden section search.
100+
tol is the absolute tolerance for the bl_optimizer which defaults to golden section search, and has Brent's method as an option by setting bl_optimizer=BrentsMethodOpt().
99101
"""
100-
function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5)
102+
function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt())
101103
temp_message = deepcopy(tree.message)
102104
message_to_set = deepcopy(tree.message)
103105

104106
if partition_list === nothing
105107
partition_list = 1:length(tree.message)
106108
end
107109

108-
branchlength_optim!(temp_message, message_to_set, tree, models, partition_list, tol)
110+
branchlength_optim!(temp_message, message_to_set, tree, models, partition_list, tol, bl_optimizer=bl_optimizer)
109111
end
110112

111113
#Overloading to allow for direct model and model vec inputs
@@ -114,10 +116,12 @@ branchlength_optim!(
114116
models::Vector{<:BranchModel};
115117
partition_list = nothing,
116118
tol = 1e-5,
117-
) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol)
119+
bl_optimizer::UnivariateOpt = GoldenSectionOpt()
120+
) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer)
118121
branchlength_optim!(
119122
tree::FelNode,
120123
model::BranchModel;
121124
partition_list = nothing,
122125
tol = 1e-5,
123-
) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol)
126+
bl_optimizer::UnivariateOpt = GoldenSectionOpt()
127+
) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer)

src/utils/simple_optim.jl

+133
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ function unit_inv_transform(x::Real; k = 1.0)
1212
x / (x + k)
1313
end
1414

15+
struct GoldenSectionOpt <: UnivariateOpt end
16+
struct BrentsMethodOpt <: UnivariateOpt end
17+
1518
"""
1619
Golden section search.
1720
@@ -67,6 +70,136 @@ function golden_section_maximize(f, a::Real, b::Real, transform, tol::Real)
6770
end
6871
end
6972

73+
"""
74+
univariate_maximize(f, a::Real, b::Real, transform, optimizer::GoldenSectionOpt, tol::Real)
75+
Maximizes `f(x)` using a Golden Section Search. See `?golden_section_maximize`.
76+
# Examples
77+
78+
```jldoctest
79+
julia> f(x) = -(x-2)^2
80+
f (generic function with 1 method)
81+
82+
julia> m = univariate_maximize(f, 1, 5, identity, GoldenSectionOpt(), 1e-10)
83+
2.0000000000051843
84+
```
85+
"""
86+
function univariate_maximize(f, a::Real, b::Real, transform, optimizer::GoldenSectionOpt, tol::Real)
87+
return golden_section_maximize(f, a, b, transform, tol)
88+
end
89+
90+
function brents_pq(x, w, v, fx, fw, fv)
91+
#These are some values used by the SPI in Brent's method
92+
#x_new = x + p / q
93+
p = (x - v)^2 * (fx - fw) - (x - w)^2 * (fx - fv)
94+
q = 2 * ((x - v) * (fx - fw) - (x - w) * (fx - fv))
95+
if q > 0
96+
p = -p
97+
end
98+
q = abs(q)
99+
return p, q
100+
end
101+
102+
function SPI_is_well_behaved(a, b, x, p, q, prev_prev_e, tol)
103+
return (q != 0 && a < x + p / q < b && abs(p / q) < abs(prev_prev_e) / 2 && abs(prev_prev_e) > tol)
104+
end
105+
106+
"""
107+
brents_method_minimize(f, a::Real, b::Real, transform, t::Real; ε::Real=sqrt(eps()))
108+
Brent's method for minimization.
109+
110+
Given a function f with a single local minimum in
111+
the interval (a,b), Brent's method returns an approximation
112+
of the x-value that minimizes f to an accuaracy between 2tol and 3tol,
113+
where tol is a combination of a relative and an absolute tolerance,
114+
tol := ε|x| + t. ε should be no smaller `2*eps`,
115+
and preferably not much less than `sqrt(eps)`, which is also the default value.
116+
eps is defined here as the machine epsilon in double precision.
117+
t should be positive.
118+
119+
The method combines the stability of a Golden Section Search and the superlinear convergence
120+
Successive Parabolic Interpolation has under certain conditions. The method never converges much slower
121+
than a Fibonacci search and for a sufficiently well-behaved f, convergence can be exptected to be superlinear,
122+
with an order that's usually atleast 1.3247...
123+
124+
# Examples
125+
126+
```jldoctest
127+
julia> f(x) = exp(-x) - cos(x)
128+
f (generic function with 1 method)
129+
130+
julia> m = brents_method_minimize(f, -1, 2, identity, 1e-7)
131+
0.5885327257940255
132+
```
133+
134+
From: Richard P. Brent, "Algorithms for Minimization without Derivatives" (1973). Chapter 5.
135+
"""
136+
function brents_method_minimize(f, a::Real, b::Real, transform, t::Real; ε::Real=sqrt(eps))
137+
a, b = min(a, b), max(a, b)
138+
v = w = x = a + invphi2 * (b - a) #x is our best approximation
139+
fv = fw = fx = f(transform(x)) #We must always have that fv >= fw >= fx (1)
140+
141+
e, prev_e = 0, 0 #e denotes the step we take in each cycle
142+
m = (a + b) / 2
143+
tol = ε * abs(x) + t
144+
145+
while abs(x - m) > 2*tol - (b - a) / 2
146+
prev_prev_e = prev_e
147+
prev_e = e
148+
p, q = brents_pq(x, w, v, fx, fw, fv)
149+
if SPI_is_well_behaved(a, b, x, p, q, prev_prev_e, tol)
150+
#Then we do a "parabolic interpolation" step
151+
e = p / q
152+
u = x + e
153+
if u - a < 2*tol || b - u < 2*tol #f must not be evaluated too close to a or b
154+
e = x < m ? tol : -tol
155+
end
156+
else #We fall back to a "golden section" step
157+
prev_e = x < m ? b - x : a - x #We want our prev_prev_e to inherit this value, since two GSS steps two iterations apart differ by a factor of invphi2
158+
e = invphi2 * prev_e
159+
end
160+
if abs(e) < tol #f must not be evaluated too close to x
161+
e = e > 0 ? tol : -tol
162+
end
163+
u = x + e
164+
fu = f(transform(u))
165+
#Update variables such that we satisfy (1) and discard the non-optimal interval
166+
if fu <= fx
167+
if u < x
168+
b = x
169+
else
170+
a = x
171+
end
172+
v, fv = w, fw
173+
w, fw = x, fx
174+
x, fx = u, fu
175+
else
176+
if u < x
177+
a = u
178+
else
179+
b = u
180+
end
181+
if fu <= fw || w == x
182+
v, fv = w, fw
183+
w, fw = u, fu
184+
elseif fu <= fv || v == x || v == w
185+
v, fv = u, fu
186+
end
187+
end
188+
m = (a + b) / 2
189+
tol = ε * abs(x) + t
190+
end
191+
return transform(x)
192+
end
193+
194+
"""
195+
univariate_maximize(f, a::Real, b::Real, transform, optimizer::BrentsMethodOpt, t::Real; ε::Real=sqrt(eps))
196+
Maximizes `f(x)` using Brent's method.
197+
See `?brents_method_minimize`.
198+
"""
199+
function univariate_maximize(f, a::Real, b::Real, transform, optimizer::BrentsMethodOpt, t::Real; ε::Real=sqrt(eps))
200+
return brents_method_minimize(x -> -f(x), a, b, transform, t, ε = ε)
201+
end
202+
70203

71204
#This is SGD on trees, sampling branches (using the stochastic_ll_diffs function).
72205
#Promising, but need a LOT of testing. See the FUBAR notebook for a use example.

test/partition_selection.jl

+2
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ begin
5353
felsenstein_down!(tree, x -> bm_models, partition_list = [2])
5454
felsenstein_down!(tree, x -> bm_models)
5555

56+
#TODO When we use BrentsMethodOpt, check if we gain a speed-up and that we're not catastrophically wrong
5657
branchlength_optim!(tree, bm_models, partition_list = [1])
5758
branchlength_optim!(tree, bm_models, partition_list = [2])
5859
branchlength_optim!(tree, bm_models)
60+
branchlength_optim!(tree, bm_models, bl_optimizer=BrentsMethodOpt())
5961
branchlength_optim!(tree, x -> bm_models, partition_list = [2])
6062
branchlength_optim!(tree, x -> bm_models)
6163

test/test_optim.jl

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ begin
22
f(x) = -(x - 2)^2
33
m = golden_section_maximize(f, 1, 5, identity, 1e-20)
44
@test m == 2.0
5+
m = brents_method_minimize(x -> -f(x), 1, 5, identity, 1e-20)
6+
@test m == 2.0
57
end
68

79
begin

0 commit comments

Comments
 (0)