Update pattern #49

Merged
merged 13 commits on Feb 26, 2025
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,6 +1,7 @@
[deps]
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FASTX = "c2308a5c-f048-11e8-3e8a-31650f418d12"
Fontconfig = "186bb1d3-e1f7-5a2c-a377-96d770f13627"
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -1,6 +1,6 @@
using MolecularEvolution
using Documenter, Literate
using Phylo
using Phylo, Distributions
using Plots
using Compose, Cairo, Fontconfig
using FASTX
@@ -48,6 +48,7 @@ makedocs(;
"optimization.md",
"ancestors.md",
"generated/viz.md",
"generated/update.md",
],
)

115 changes: 115 additions & 0 deletions examples/update.jl
@@ -0,0 +1,115 @@
# # Updating a phylogenetic tree
#=
## Interface

```@docs; canonical=false
AbstractUpdate
StandardUpdate
```
=#
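# In the example below, `MaxLikUpdate` and `BayesUpdate` are used to construct updates for
# maximum-likelihood optimization and Bayesian sampling, respectively.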

# ## Example

using MolecularEvolution, Plots, Distributions
# Simulate a tree
tree = sim_tree(n = 50)
initial_message = GaussianPartition()
models = BrownianMotion(0.0, 1.0)
internal_message_init!(tree, initial_message)
sample_down!(tree, models)
log_likelihood!(tree, models)
# Add some noise to the branch lengths
for n in getnodelist(tree)
    n.branchlength += 100 * rand()
end
log_likelihood!(tree, models)
# Optimize under the Brownian motion model
update = MaxLikUpdate(branchlength = 1, nni = 0, root = 1)
tree, models = update(tree, models)
@show log_likelihood!(tree, models)
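# Because the pre-noise branch lengths are one feasible configuration, the optimized
# log-likelihood should be at least as high as the value computed before the noise was
# added (up to optimizer convergence).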
# ### Set up a Bayesian model sampler
#=
Let's assume the target of inference is not the tree itself, but rather the models.
Assume further that you want to sample the variance of the Brownian motion model,
for a fixed mean drift, using the Metropolis algorithm.
=#
# We begin by simulating a fresh tree, and then define a struct that specifies how the model is updated
tree = sim_tree(n = 200)
internal_message_init!(tree, GaussianPartition())
# Simulate Brownian motion over the tree
models = BrownianMotion(0.0, 2.0)
sample_down!(tree, models)
mutable struct MyModelSampler{
    T1<:ContinuousUnivariateDistribution,
    T2<:ContinuousUnivariateDistribution,
} <: ModelsUpdate
    acc_ratio::Vector{Int}
    log_var_drift_proposal::T1
    log_var_drift_prior::T2
    mean_drift::Float64
    function MyModelSampler(
        log_var_drift_proposal::T1,
        log_var_drift_prior::T2,
        mean_drift::Float64,
    ) where {T1<:ContinuousUnivariateDistribution, T2<:ContinuousUnivariateDistribution}
        new{T1, T2}([0, 0], log_var_drift_proposal, log_var_drift_prior, mean_drift)
    end
end
# Then we let this struct implement our [`metropolis_step`](@ref) interface
MolecularEvolution.tr(::MyModelSampler, x::BrownianMotion) = log(x.var_drift)
MolecularEvolution.invtr(modifier::MyModelSampler, x::Float64) =
    BrownianMotion(modifier.mean_drift, exp(x))

MolecularEvolution.proposal(modifier::MyModelSampler, curr_value::Float64) =
    curr_value + rand(modifier.log_var_drift_proposal)
MolecularEvolution.log_prior(modifier::MyModelSampler, x::Float64) =
    logpdf(modifier.log_var_drift_prior, x)
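# As a quick sanity check of the transform pair: `tr` maps a model to the log of its variance
# parameter and `invtr` maps back, so a round trip should reproduce the model (up to floating
# point). The sampler constructed here is just for illustration.
check_sampler = MyModelSampler(Normal(0.0, 1.0), Normal(-10.0, 1.0), 0.0)
MolecularEvolution.invtr(check_sampler, MolecularEvolution.tr(check_sampler, BrownianMotion(0.0, 2.0)))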
# Now we define what a model update is: a single Metropolis step on the model, with the tree held fixed
function (update::MyModelSampler)(
    tree::FelNode,
    models::BranchModel;
    partition_list = 1:length(tree.message),
)
    metropolis_step(update, models) do x::BrownianMotion
        log_likelihood!(tree, x)
    end
end
# Now we define how the model is collapsed to its parameter
function MolecularEvolution.collapse_models(::MyModelSampler, models::BranchModel)
    return models.var_drift
end
# Now we define a Bayesian sampler
update = BayesUpdate(
    nni = 0,
    branchlength = 0,
    models = 1,
    models_sampler = MyModelSampler(Normal(0.0, 1.0), Normal(-10.0, 1.0), 0.0),
)
trees, models_samples = metropolis_sample(
    update,
    tree,
    BrownianMotion(0.0, 7.67),
    1000,
    burn_in = 1000,
    collect_models = true,
)
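# Since `collect_models = true`, `models_samples` holds the collapsed model parameter from each
# retained sample, i.e. the sampled `var_drift` values.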

ll(x) = log_likelihood!(tree, BrownianMotion(0.0, x))
prior(x) = logpdf(update.models_update.log_var_drift_prior, log(x)) - log(x)
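# The `- log(x)` term in `prior` is the Jacobian of the log transform: the prior is placed on
# log(variance), so `prior(x)` gives the implied log-density on the variance scale.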
x_range = 0.1:0.1:5

p1 = histogram(
    models_samples,
    normalize = :pdf,
    alpha = 0.5,
    label = "Posterior samples",
    xlims = (minimum(x_range), maximum(x_range)),
    xlabel = "variance per unit time",
    ylabel = "probability density",
)
p2 = plot(x_range, ll, label = "Tree log-likelihood")

p3 = plot(x_range, prior, label = "Log prior")
plot(p1, p2, p3, layout = (1, 3), size = (1100, 400))
#-
plot(models_samples)
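# If `metropolis_step` records accept/reject counts in the `acc_ratio` field we declared above
# (an assumption about the interface, not something this example relies on), the empirical
# acceptance behaviour can be inspected directly:
@show update.models_update.acc_ratio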
26 changes: 26 additions & 0 deletions src/MolecularEvolution.jl
@@ -34,6 +34,13 @@ abstract type UnivariateModifier end
abstract type UnivariateOpt <: UnivariateModifier end
abstract type UnivariateSampler <: UnivariateModifier end

abstract type RootUpdate <: Function end
abstract type RootOpt <: RootUpdate end
abstract type RootSample <: RootUpdate end
abstract type UniformRootPositionSample <: RootSample end
abstract type ModelsUpdate <: Function end


abstract type LazyDirection end

#include("core/core.jl")
@@ -102,6 +109,7 @@ export
combine!,
felsenstein!,
felsenstein_down!,
felsenstein_roundtrip!,
sample_down!,
#endpoint_conditioned_sample_down!,
log_likelihood!,
@@ -122,10 +130,28 @@ export
nni_optim!,
branchlength_update!,
branchlength_optim!,
root_optim!,
root_position_sample!,
metropolis_sample,
metropolis_step,
copy_tree,

#update
AbstractUpdate,
StandardUpdate,
BayesUpdate,
MaxLikUpdate,
RootUpdate,
RootOpt,
RootSample,
UniformRootPositionSample,
StandardRootOpt,
StandardRootSample,
ModelsUpdate,
StandardModelsUpdate,
collapse_models,


#Tree simulation functions
sim_tree,
standard_tree_sim,
44 changes: 30 additions & 14 deletions src/bayes/sampling.jl
@@ -1,26 +1,30 @@
"""
metropolis_sample(
update!::Function,
update!::AbstractUpdate,
initial_tree::FelNode,
models::Vector{<:BranchModel},
models,
num_of_samples;
partition_list = 1:length(initial_tree.message),
burn_in = 1000,
sample_interval = 10,
collect_LLs = false,
collect_models = false,
midpoint_rooting = false,
ladderize = false,
)

Samples tree topologies from a posterior distribution using a custom `update!` function.

# Arguments
- `update!`: A function that takes (tree::FelNode, models::Vector{<:BranchModel}) and updates `tree`. One call to `update!` corresponds to one iteration of the Metropolis algorithm.
- `update!`: A callable that takes (tree::FelNode, models) and updates `tree` and `models`. One call to `update!` corresponds to one iteration of the Metropolis algorithm.
- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation.
- `models`: A list of branch models.
- `num_of_samples`: The number of tree samples drawn from the posterior.
- `partition_list`: (e.g. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to sample with all partitions, the default option).
- `burn_in`: The number of samples discarded at the start of the Markov Chain.
- `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation).
- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees.
- `collect_models`: Specifies if the function should return the models.
- `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium).

!!! note
@@ -29,39 +33,49 @@ Samples tree topologies from a posterior distribution using a custom `update!` function.
# Returns
- `samples`: The trees drawn from the posterior. Returns shallow tree copies, which need to be repopulated before running `felsenstein!` etc.
- `sample_LLs`: The associated log-likelihoods of the trees (optional).
- `sample_models`: The models drawn from the posterior (optional). The models can be collapsed into their parameters with `collapse_models`.
"""
function metropolis_sample(
    update!::Function,
    update!::AbstractUpdate,
    initial_tree::FelNode,
    models::Vector{<:BranchModel},
    models,#::Vector{<:BranchModel},
    num_of_samples;
    partition_list = 1:length(initial_tree.message),
    burn_in = 1000,
    sample_interval = 10,
    collect_LLs = false,
    midpoint_rooting = false,
    ladderize = false,
    collect_models = false,
)

    # The prior over the (log) of the branchlengths should be specified in bl_sampler.
    # Furthermore, a non-informative/uniform prior is assumed over the tree topologies (excluding the branchlengths).

    sample_LLs = []
    sample_LLs = Float64[]
    samples = FelNode[]
    tree = deepcopy(initial_tree)
    sample_models = []
    tree = initial_tree#deepcopy(initial_tree)
    iterations = burn_in + num_of_samples * sample_interval

    for i = 1:iterations
        # Updates the tree topology and branchlengths.
        update!(tree, models)
        tree, models = update!(tree, models, partition_list = partition_list)
        if isnothing(tree)
            break
        end

        if (i - burn_in) % sample_interval == 0 && i > burn_in

            push!(samples, copy_tree(tree, true))

            if collect_LLs
                push!(sample_LLs, log_likelihood!(tree, models))
                push!(sample_LLs, log_likelihood!(tree, models, partition_list = partition_list))
            end

            if collect_models
                push!(sample_models, collapse_models(update!, models))
            end
        end

    end
@@ -79,10 +93,15 @@ function metropolis_sample(
end
end

    if collect_LLs
    if collect_LLs && collect_models
        return samples, sample_LLs, sample_models
    elseif collect_LLs && !collect_models
        return samples, sample_LLs
    elseif !collect_LLs && collect_models
        return samples, sample_models
    end


    return samples
end

@@ -110,10 +129,7 @@ function metropolis_sample(
    bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0, 2), Normal(-1, 1)),
    kwargs...,
)
    metropolis_sample(initial_tree, models, num_of_samples; kwargs...) do tree, models
        nni_update!(softmax_sampler, tree, x -> models)
        branchlength_update!(bl_sampler, tree, x -> models)
    end
    metropolis_sample(BayesUpdate(; branchlength_sampler = bl_sampler), initial_tree, models, num_of_samples; kwargs...)
end

# Below are some functions that help to assess the mixing by looking at the distance between leaf nodes.