diff --git a/src/core/algorithms/generative.jl b/src/core/algorithms/generative.jl index a96a8f8..5c6c643 100644 --- a/src/core/algorithms/generative.jl +++ b/src/core/algorithms/generative.jl @@ -20,18 +20,22 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over. """ function sample_down!(node::FelNode, models, partition_list) - model_list = models(node) - for part in partition_list - if isroot(node) - forward!(node.message[part], node.parent_message[part], model_list[part], node) - else - forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part] + stack = [node] + while !isempty(stack) + node = pop!(stack) + model_list = models(node) + for part in partition_list + if isroot(node) + forward!(node.message[part], node.parent_message[part], model_list[part], node) + else + forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part] + end + sample_partition!(node.message[part]) end - sample_partition!(node.message[part]) - end - if !isleafnode(node) - for child in node.children - sample_down!(child, models, partition_list) + if !isleafnode(node) + for child in reverse(node.children) #We push! in reverse order because of LazyPartition, so that lazysort! is optimal for both felsenstein! and sample_down! + push!(stack, child) + end end end end diff --git a/src/core/nodes/AbstractTreeNode.jl b/src/core/nodes/AbstractTreeNode.jl index 82f29f3..072916e 100644 --- a/src/core/nodes/AbstractTreeNode.jl +++ b/src/core/nodes/AbstractTreeNode.jl @@ -347,10 +347,15 @@ function prettyprintstring(node::T, spaces::Int = 0) where {T<:AbstractTreeNode} end export getnodelist -function getnodelist(node::T, nodelist::Array{T,1} = T[]) where {T<:AbstractTreeNode} - push!(nodelist, node) - for childnode in node.children #Fixing this to avoid implementing iterate(::FelNode) - getnodelist(childnode, nodelist) +function getnodelist(node::T) where {T<:AbstractTreeNode} + nodelist = [] + nodes = [node] + while nodes != [] + node = pop!(nodes) + push!(nodelist, node) + for child in node.children + push!(nodes, child) + end end return nodelist end @@ -385,23 +390,34 @@ function treedepth(node::T) where {T<:AbstractTreeNode} end export getnonleaflist -function getnonleaflist(node::T, nonleaflist::Array{T,1} = T[]) where {T<:AbstractTreeNode} - if !isleafnode(node) - push!(nonleaflist, node) - end - for childnode in node.children - getnonleaflist(childnode, nonleaflist) +function getnonleaflist(node::T) where {T<:AbstractTreeNode} + nonleaflist = [] + nodes = [node] + while nodes != [] + node = pop!(nodes) + if node.children != [] + push!(nonleaflist, node) + for child in node.children + push!(nodes, child) + end + end end return nonleaflist end export getleaflist -function getleaflist(node::T, leaflist::Array{T,1} = T[]) where {T<:AbstractTreeNode} - if isleafnode(node) - push!(leaflist, node) - end - for childnode in node.children - getleaflist(childnode, leaflist) +function getleaflist(node::T) where {T<:AbstractTreeNode} + leaflist = [] + nodes = [node] + while nodes != [] + node = pop!(nodes) + if node.children == [] + push!(leaflist, node) + else + for child in node.children + push!(nodes, child) + end + end end return leaflist end @@ -432,16 +448,45 @@ function ladderize(tree::T) where {T<:AbstractTreeNode} end function ladderize!(tree::T) where {T<:AbstractTreeNode} + child_counts = countchildren(tree) for node in getnodelist(tree) - if length(node.children) != 0 - sort!( - node.children, - lt = (x, y) -> length(getnodelist(x)) < length(getnodelist(y)), - ) + if !isempty(node.children) + sort!(node.children, lt = (x, y) -> child_counts[x] < child_counts[y]) end end end +# Creates a dictionary of all the child counts (including the node itself) which can then be used by ladderize to sort the nodes +function countchildren(tree::T) where {T<:AbstractTreeNode} + # Initialize the dictionary to store the number of children for each node + children_count = Dict{T, Int}() + + # Initialize the stack for DFS + stack = [tree] + + # Initialize a list to keep track of the post-order traversal + post_order = [] + + # First pass: Perform DFS and store the nodes in post-order + while !isempty(stack) + node = pop!(stack) + push!(post_order, node) + for child in node.children + push!(stack, child) + end + end + + # Second pass: Calculate the number of children for each node in post-order + for node in reverse(post_order) + count = 0 + for child in node.children + count += 1 + children_count[child] + end + children_count[node] = count + end + + return children_count +end function getorder(tree::T) where {T<:AbstractTreeNode} return [node.seqindex for node in getleaflist(tree)]