Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stack based refactor #22

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/core/algorithms/generative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th
partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over.
"""
function sample_down!(node::FelNode, models, partition_list)
model_list = models(node)
for part in partition_list
if isroot(node)
forward!(node.message[part], node.parent_message[part], model_list[part], node)
else
forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part]
stack = [node]
while !isempty(stack)
node = pop!(stack)
model_list = models(node)
for part in partition_list
if isroot(node)
forward!(node.message[part], node.parent_message[part], model_list[part], node)
else
forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part]
end
sample_partition!(node.message[part])
end
sample_partition!(node.message[part])
end
if !isleafnode(node)
for child in node.children
sample_down!(child, models, partition_list)
if !isleafnode(node)
for child in reverse(node.children) #We push! in reverse order because of LazyPartition, so that lazysort! is optimal for both felsenstein! and sample_down!
push!(stack, child)
end
end
end
end
Expand Down
87 changes: 66 additions & 21 deletions src/core/nodes/AbstractTreeNode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,15 @@ function prettyprintstring(node::T, spaces::Int = 0) where {T<:AbstractTreeNode}
end

export getnodelist
function getnodelist(node::T, nodelist::Array{T,1} = T[]) where {T<:AbstractTreeNode}
push!(nodelist, node)
for childnode in node.children #Fixing this to avoid implementing iterate(::FelNode)
getnodelist(childnode, nodelist)
function getnodelist(node::T) where {T<:AbstractTreeNode}
nodelist = []
nodes = [node]
while nodes != []
node = pop!(nodes)
push!(nodelist, node)
for child in node.children
push!(nodes, child)
end
end
return nodelist
end
Expand Down Expand Up @@ -385,23 +390,34 @@ function treedepth(node::T) where {T<:AbstractTreeNode}
end

export getnonleaflist
function getnonleaflist(node::T, nonleaflist::Array{T,1} = T[]) where {T<:AbstractTreeNode}
if !isleafnode(node)
push!(nonleaflist, node)
end
for childnode in node.children
getnonleaflist(childnode, nonleaflist)
function getnonleaflist(node::T) where {T<:AbstractTreeNode}
nonleaflist = []
nodes = [node]
while nodes != []
node = pop!(nodes)
if node.children != []
push!(nonleaflist, node)
for child in node.children
push!(nodes, child)
end
end
end
return nonleaflist
end

export getleaflist
function getleaflist(node::T, leaflist::Array{T,1} = T[]) where {T<:AbstractTreeNode}
if isleafnode(node)
push!(leaflist, node)
end
for childnode in node.children
getleaflist(childnode, leaflist)
function getleaflist(node::T) where {T<:AbstractTreeNode}
leaflist = []
nodes = [node]
while nodes != []
node = pop!(nodes)
if node.children == []
push!(leaflist, node)
else
for child in node.children
push!(nodes, child)
end
end
end
return leaflist
end
Expand Down Expand Up @@ -432,16 +448,45 @@ function ladderize(tree::T) where {T<:AbstractTreeNode}
end

function ladderize!(tree::T) where {T<:AbstractTreeNode}
child_counts = countchildren(tree)
for node in getnodelist(tree)
if length(node.children) != 0
sort!(
node.children,
lt = (x, y) -> length(getnodelist(x)) < length(getnodelist(y)),
)
if !isempty(node.children)
sort!(node.children, lt = (x, y) -> child_counts[x] < child_counts[y])
end
end
end

# Creates a dictionary of all the child counts (including the node itself) which can then be used by ladderize to sort the nodes
function countchildren(tree::T) where {T<:AbstractTreeNode}
# Initialize the dictionary to store the number of children for each node
children_count = Dict{T, Int}()

# Initialize the stack for DFS
stack = [tree]

# Initialize a list to keep track of the post-order traversal
post_order = []

# First pass: Perform DFS and store the nodes in post-order
while !isempty(stack)
node = pop!(stack)
push!(post_order, node)
for child in node.children
push!(stack, child)
end
end

# Second pass: Calculate the number of children for each node in post-order
for node in reverse(post_order)
count = 0
for child in node.children
count += 1 + children_count[child]
end
children_count[node] = count
end

return children_count
end

function getorder(tree::T) where {T<:AbstractTreeNode}
return [node.seqindex for node in getleaflist(tree)]
Expand Down
Loading