Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid deepcopy pt. 2 #16

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MolecularEvolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ export

#things the user might overload
copy_partition_to!,
copy_partition,
equilibrium_message,
sample_partition!,
obs2partition!,
Expand Down
6 changes: 3 additions & 3 deletions src/core/nodes/FelNode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ end
function internal_message_init!(tree::FelNode, empty_message::Vector{<:Partition})
for node in getnodelist(tree)
if !isleafnode(node)
node.child_messages = [deepcopy(empty_message) for i in node.children]
node.child_messages = [copy_message(empty_message) for i in node.children]
end
node.message = deepcopy(empty_message)
node.parent_message = deepcopy(empty_message)
node.message = copy_message(empty_message)
node.parent_message = copy_message(empty_message)
end
end

Expand Down
5 changes: 5 additions & 0 deletions src/models/continuous_models/gaussian_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ mutable struct GaussianPartition <: ContinuousPartition
end
end

#Overloading the copy_partition to avoid deepcopy.
function copy_partition(src::GaussianPartition)
return GaussianPartition(src.mean, src.var, src.norm_const)
end

#From the first section of http://www.tina-vision.net/docs/memos/2003-003.pdf
function merge_two_gaussians(g1::GaussianPartition, g2::GaussianPartition)
#Handling some edge cases. These aren't mathematically sensible. A gaussian with "Inf" variance will behave like a 1,1,1,1 vector in discrete felsenstein.
Expand Down
4 changes: 4 additions & 0 deletions src/models/discrete_models/codon_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ mutable struct CodonPartition <: DiscretePartition
zeros(sites),
)
end
function CodonPartition(state, states, sites, scaling; code = universal_code)
@assert size(state) == (states, sites) && states == length(code.sense_codons)
new(state, states, sites, scaling)
end
end

#Make this handle IUPAC ambigs sensible. Any codon compatible with the ambig should get a 1.0
Expand Down
25 changes: 25 additions & 0 deletions src/models/discrete_models/discrete_partitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ function copy_partition_to!(dest::T, src::T) where {T<:DiscretePartition}
dest.scaling .= src.scaling
end

#Overloading the copy_partition to avoid deepcopy.
function copy_partition(src::T) where {T<:DiscretePartition}
return T(copy(src.state), src.states, src.sites, copy(src.scaling))
end

#I should add a constructor that constructs a DiscretePartition from an existing array.
mutable struct CustomDiscretePartition <: DiscretePartition
state::Array{Float64,2}
Expand All @@ -24,6 +29,10 @@ mutable struct CustomDiscretePartition <: DiscretePartition
state_arr .= freq_vec
new(state_arr, length(freq_vec), sites, zeros(sites))
end
function CustomDiscretePartition(state, states, sites, scaling)
@assert size(state) == (states, sites)
new(state, states, sites, scaling)
end
end

mutable struct NucleotidePartition <: DiscretePartition
Expand All @@ -40,6 +49,10 @@ mutable struct NucleotidePartition <: DiscretePartition
state_arr .= freq_vec
new(state_arr, 4, sites, zeros(sites))
end
function NucleotidePartition(state, states, sites, scaling)
@assert size(state) == (states, sites) && states == 4
new(state, states, sites, scaling)
end
end

mutable struct GappyNucleotidePartition <: DiscretePartition
Expand All @@ -56,6 +69,10 @@ mutable struct GappyNucleotidePartition <: DiscretePartition
state_arr .= freq_vec
new(state_arr, 5, sites, zeros(sites))
end
function GappyNucleotidePartitionPartition(state, states, sites, scaling)
@assert size(state) == (states, sites) && states == 5
new(state, states, sites, scaling)
end
end

mutable struct AminoAcidPartition <: DiscretePartition
Expand All @@ -72,6 +89,10 @@ mutable struct AminoAcidPartition <: DiscretePartition
state_arr .= freq_vec
new(state_arr, 20, sites, zeros(sites))
end
function AminoAcidPartition(state, states, sites, scaling)
@assert size(state) == (states, sites) && states == 20
new(state, states, sites, scaling)
end
end

mutable struct GappyAminoAcidPartition <: DiscretePartition
Expand All @@ -88,6 +109,10 @@ mutable struct GappyAminoAcidPartition <: DiscretePartition
state_arr .= freq_vec
new(state_arr, 21, sites, zeros(sites))
end
function GappyAminoAcidPartition(state, states, sites, scaling)
@assert size(state) == (states, sites) && states == 21
new(state, states, sites, scaling)
end
end

"""
Expand Down
14 changes: 14 additions & 0 deletions src/models/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ function copy_partition_to!(dest::GaussianPartition,src::GaussianPartition)
end
=#

#Fallback method. This should be overloaded to avoid deepcopy wherever performance requires it
function copy_partition(src::T) where {T <: Partition}
return deepcopy(src)
end

#Example overloading for GaussianPartition:
#=
function copy_partition(src::GaussianPartition)
return GaussianPartition(src.mean, src.var, src.norm_const)
end
=#

copy_message(msg::Vector{<:Partition}) = [copy_partition(x) for x in msg]

#This is a function shared for all models - perhaps move this elsewhere
function combine!(dest::T, source_arr::Vector{<:T}, wipe::Bool) where {T<:Partition}
#Init to be equal to 1, then multiply everything on.
Expand Down
Loading