Skip to content

Commit 56e2c0b

Browse files
committed
Avoid deepcopy
1 parent 53f75c0 commit 56e2c0b

File tree

5 files changed

+73
-47
lines changed

5 files changed

+73
-47
lines changed

src/MolecularEvolution.jl

+1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ export
129129

130130
#things the user might overload
131131
copy_partition_to!,
132+
copy_partition,
132133
equilibrium_message,
133134
sample_partition!,
134135
obs2partition!,

src/core/nodes/FelNode.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ end
6060
function internal_message_init!(tree::FelNode, empty_message::Vector{<:Partition})
6161
for node in getnodelist(tree)
6262
if !isleafnode(node)
63-
node.child_messages = [deepcopy(empty_message) for i in node.children]
63+
node.child_messages = [copy_message(empty_message) for i in node.children]
6464
end
65-
node.message = deepcopy(empty_message)
66-
node.parent_message = deepcopy(empty_message)
65+
node.message = copy_message(empty_message)
66+
node.parent_message = copy_message(empty_message)
6767
end
6868
end
6969

src/models/continuous_models/gaussian_partition.jl

+5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ mutable struct GaussianPartition <: ContinuousPartition
1515
end
1616
end
1717

18+
#Overloading the copy_partition to avoid deepcopy.
19+
function copy_partition(src::GaussianPartition)
20+
return GaussianPartition(src.mean, src.var, src.norm_const)
21+
end
22+
1823
#From the first section of http://www.tina-vision.net/docs/memos/2003-003.pdf
1924
function merge_two_gaussians(g1::GaussianPartition, g2::GaussianPartition)
2025
#Handling some edge cases. These aren't mathematically sensible. A gaussian with "Inf" variance will behave like a 1,1,1,1 vector in discrete felsenstein.

src/models/discrete_models/discrete_partitions.jl

+50-44
Original file line numberDiff line numberDiff line change
@@ -10,84 +10,90 @@ function copy_partition_to!(dest::T, src::T) where {T<:DiscretePartition}
1010
dest.scaling .= src.scaling
1111
end
1212

13+
#Overloading the copy_partition to avoid deepcopy.
14+
function copy_partition(src::T) where {T<:DiscretePartition}
15+
return T(ntuple(i -> copy(getfield(src, i)), fieldcount(T))...)
16+
end
17+
1318
#I should add a constructor that constructs a DiscretePartition from an existing array.
1419
mutable struct CustomDiscretePartition <: DiscretePartition
1520
state::Array{Float64,2}
1621
states::Int
1722
sites::Int
1823
scaling::Array{Float64,1}
19-
function CustomDiscretePartition(states, sites)
20-
new(zeros(states, sites), states, sites, zeros(sites))
21-
end
22-
function CustomDiscretePartition(freq_vec::Vector{Float64}, sites::Int64) #Add this constructor to all partition types
23-
state_arr = zeros(length(freq_vec), sites)
24-
state_arr .= freq_vec
25-
new(state_arr, length(freq_vec), sites, zeros(sites))
26-
end
24+
end
25+
26+
CustomDiscretePartition(states, sites) =
27+
CustomDiscretePartition(zeros(states, sites), states, sites, zeros(sites))
28+
29+
function CustomDiscretePartition(freq_vec::Vector{Float64}, sites::Int64) #Add this constructor to all partition types
30+
state_arr = zeros(length(freq_vec), sites)
31+
state_arr .= freq_vec
32+
return CustomDiscretePartition(state_arr, length(freq_vec), sites, zeros(sites))
2733
end
2834

2935
mutable struct NucleotidePartition <: DiscretePartition
3036
state::Array{Float64,2}
3137
states::Int
3238
sites::Int
3339
scaling::Array{Float64,1}
34-
function NucleotidePartition(sites)
35-
new(zeros(4, sites), 4, sites, zeros(sites))
36-
end
37-
function NucleotidePartition(freq_vec::Vector{Float64}, sites::Int64)
38-
@assert length(freq_vec) == 4
39-
state_arr = zeros(4, sites)
40-
state_arr .= freq_vec
41-
new(state_arr, 4, sites, zeros(sites))
42-
end
40+
end
41+
42+
NucleotidePartition(sites) = NucleotidePartition(zeros(4, sites), 4, sites, zeros(sites))
43+
44+
function NucleotidePartition(freq_vec::Vector{Float64}, sites::Int64)
45+
@assert length(freq_vec) == 4
46+
state_arr = zeros(4, sites)
47+
state_arr .= freq_vec
48+
return NucleotidePartition(state_arr, 4, sites, zeros(sites))
4349
end
4450

4551
mutable struct GappyNucleotidePartition <: DiscretePartition
4652
state::Array{Float64,2}
4753
states::Int
4854
sites::Int
4955
scaling::Array{Float64,1}
50-
function GappyNucleotidePartition(sites)
51-
new(zeros(5, sites), 5, sites, zeros(sites))
52-
end
53-
function GappyNucleotidePartition(freq_vec::Vector{Float64}, sites::Int64)
54-
@assert length(freq_vec) == 5
55-
state_arr = zeros(5, sites)
56-
state_arr .= freq_vec
57-
new(state_arr, 5, sites, zeros(sites))
58-
end
56+
end
57+
58+
GappyNucleotidePartition(sites) = GappyNucleotidePartition(zeros(5, sites), 5, sites, zeros(sites))
59+
60+
function GappyNucleotidePartition(freq_vec::Vector{Float64}, sites::Int64)
61+
@assert length(freq_vec) == 5
62+
state_arr = zeros(5, sites)
63+
state_arr .= freq_vec
64+
return GappyNucleotidePartition(state_arr, 5, sites, zeros(sites))
5965
end
6066

6167
mutable struct AminoAcidPartition <: DiscretePartition
6268
state::Array{Float64,2}
6369
states::Int
6470
sites::Int
6571
scaling::Array{Float64,1}
66-
function AminoAcidPartition(sites)
67-
new(zeros(20, sites), 20, sites, zeros(sites))
68-
end
69-
function AminoAcidPartition(freq_vec::Vector{Float64}, sites::Int64)
70-
@assert length(freq_vec) == 20
71-
state_arr = zeros(20, sites)
72-
state_arr .= freq_vec
73-
new(state_arr, 20, sites, zeros(sites))
74-
end
72+
end
73+
74+
AminoAcidPartition(sites) = AminoAcidPartition(zeros(20, sites), 20, sites, zeros(sites))
75+
76+
function AminoAcidPartition(freq_vec::Vector{Float64}, sites::Int64)
77+
@assert length(freq_vec) == 20
78+
state_arr = zeros(20, sites)
79+
state_arr .= freq_vec
80+
return AminoAcidPartition(state_arr, 20, sites, zeros(sites))
7581
end
7682

7783
mutable struct GappyAminoAcidPartition <: DiscretePartition
7884
state::Array{Float64,2}
7985
states::Int
8086
sites::Int
8187
scaling::Array{Float64,1}
82-
function GappyAminoAcidPartition(sites)
83-
new(zeros(21, sites), 21, sites, zeros(sites))
84-
end
85-
function GappyAminoAcidPartition(freq_vec::Vector{Float64}, sites::Int64)
86-
@assert length(freq_vec) == 21
87-
state_arr = zeros(21, sites)
88-
state_arr .= freq_vec
89-
new(state_arr, 21, sites, zeros(sites))
90-
end
88+
end
89+
90+
GappyAminoAcidPartition(sites) = GappyAminoAcidPartition(zeros(21, sites), 21, sites, zeros(sites))
91+
92+
function GappyAminoAcidPartition(freq_vec::Vector{Float64}, sites::Int64)
93+
@assert length(freq_vec) == 21
94+
state_arr = zeros(21, sites)
95+
state_arr .= freq_vec
96+
return GappyAminoAcidPartition(state_arr, 21, sites, zeros(sites))
9197
end
9298

9399
"""

src/models/models.jl

+14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ function copy_partition_to!(dest::GaussianPartition,src::GaussianPartition)
2121
end
2222
=#
2323

24+
#Fallback method. This should be overloaded to avoid deepcopy wherever performance requires it
25+
function copy_partition(src::T) where {T <: Partition}
26+
return deepcopy(src)
27+
end
28+
29+
#Example overloading for GaussianPartition:
30+
#=
31+
function copy_partition(src::GaussianPartition)
32+
return GaussianPartition(src.mean, src.var, src.norm_const)
33+
end
34+
=#
35+
36+
copy_message(msg::Vector{<:Partition}) = [copy_partition(x) for x in msg]
37+
2438
#This is a function shared for all models - perhaps move this elsewhere
2539
function combine!(dest::T, source_arr::Vector{<:T}, wipe::Bool) where {T<:Partition}
2640
#Init to be equal to 1, then multiply everything on.

0 commit comments

Comments
 (0)