Skip to content

Commit b631b49

Browse files
authored
Merge pull request #22 from Glowster/stack-based-refactor
Stack based refactor
2 parents 27f1539 + 4a39c6d commit b631b49

File tree

2 files changed

+81
-32
lines changed

2 files changed

+81
-32
lines changed

src/core/algorithms/generative.jl

+15-11
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th
2020
partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over.
2121
"""
2222
function sample_down!(node::FelNode, models, partition_list)
23-
model_list = models(node)
24-
for part in partition_list
25-
if isroot(node)
26-
forward!(node.message[part], node.parent_message[part], model_list[part], node)
27-
else
28-
forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part]
23+
stack = [node]
24+
while !isempty(stack)
25+
node = pop!(stack)
26+
model_list = models(node)
27+
for part in partition_list
28+
if isroot(node)
29+
forward!(node.message[part], node.parent_message[part], model_list[part], node)
30+
else
31+
forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part]
32+
end
33+
sample_partition!(node.message[part])
2934
end
30-
sample_partition!(node.message[part])
31-
end
32-
if !isleafnode(node)
33-
for child in node.children
34-
sample_down!(child, models, partition_list)
35+
if !isleafnode(node)
36+
for child in reverse(node.children) #We push! in reverse order because of LazyPartition, so that lazysort! is optimal for both felsenstein! and sample_down!
37+
push!(stack, child)
38+
end
3539
end
3640
end
3741
end

src/core/nodes/AbstractTreeNode.jl

+66-21
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,15 @@ function prettyprintstring(node::T, spaces::Int = 0) where {T<:AbstractTreeNode}
347347
end
348348

349349
export getnodelist
350-
function getnodelist(node::T, nodelist::Array{T,1} = T[]) where {T<:AbstractTreeNode}
351-
push!(nodelist, node)
352-
for childnode in node.children #Fixing this to avoid implementing iterate(::FelNode)
353-
getnodelist(childnode, nodelist)
350+
function getnodelist(node::T) where {T<:AbstractTreeNode}
351+
nodelist = []
352+
nodes = [node]
353+
while nodes != []
354+
node = pop!(nodes)
355+
push!(nodelist, node)
356+
for child in node.children
357+
push!(nodes, child)
358+
end
354359
end
355360
return nodelist
356361
end
@@ -385,23 +390,34 @@ function treedepth(node::T) where {T<:AbstractTreeNode}
385390
end
386391

387392
export getnonleaflist
388-
function getnonleaflist(node::T, nonleaflist::Array{T,1} = T[]) where {T<:AbstractTreeNode}
389-
if !isleafnode(node)
390-
push!(nonleaflist, node)
391-
end
392-
for childnode in node.children
393-
getnonleaflist(childnode, nonleaflist)
393+
function getnonleaflist(node::T) where {T<:AbstractTreeNode}
394+
nonleaflist = []
395+
nodes = [node]
396+
while nodes != []
397+
node = pop!(nodes)
398+
if node.children != []
399+
push!(nonleaflist, node)
400+
for child in node.children
401+
push!(nodes, child)
402+
end
403+
end
394404
end
395405
return nonleaflist
396406
end
397407

398408
export getleaflist
399-
function getleaflist(node::T, leaflist::Array{T,1} = T[]) where {T<:AbstractTreeNode}
400-
if isleafnode(node)
401-
push!(leaflist, node)
402-
end
403-
for childnode in node.children
404-
getleaflist(childnode, leaflist)
409+
function getleaflist(node::T) where {T<:AbstractTreeNode}
410+
leaflist = []
411+
nodes = [node]
412+
while nodes != []
413+
node = pop!(nodes)
414+
if node.children == []
415+
push!(leaflist, node)
416+
else
417+
for child in node.children
418+
push!(nodes, child)
419+
end
420+
end
405421
end
406422
return leaflist
407423
end
@@ -432,16 +448,45 @@ function ladderize(tree::T) where {T<:AbstractTreeNode}
432448
end
433449

434450
function ladderize!(tree::T) where {T<:AbstractTreeNode}
451+
child_counts = countchildren(tree)
435452
for node in getnodelist(tree)
436-
if length(node.children) != 0
437-
sort!(
438-
node.children,
439-
lt = (x, y) -> length(getnodelist(x)) < length(getnodelist(y)),
440-
)
453+
if !isempty(node.children)
454+
sort!(node.children, lt = (x, y) -> child_counts[x] < child_counts[y])
441455
end
442456
end
443457
end
444458

459+
# Creates a dictionary of all the child counts (including the node itself) which can then be used by ladderize to sort the nodes
460+
function countchildren(tree::T) where {T<:AbstractTreeNode}
461+
# Initialize the dictionary to store the number of children for each node
462+
children_count = Dict{T, Int}()
463+
464+
# Initialize the stack for DFS
465+
stack = [tree]
466+
467+
# Initialize a list to keep track of the post-order traversal
468+
post_order = []
469+
470+
# First pass: Perform DFS and store the nodes in post-order
471+
while !isempty(stack)
472+
node = pop!(stack)
473+
push!(post_order, node)
474+
for child in node.children
475+
push!(stack, child)
476+
end
477+
end
478+
479+
# Second pass: Calculate the number of children for each node in post-order
480+
for node in reverse(post_order)
481+
count = 0
482+
for child in node.children
483+
count += 1 + children_count[child]
484+
end
485+
children_count[node] = count
486+
end
487+
488+
return children_count
489+
end
445490

446491
function getorder(tree::T) where {T<:AbstractTreeNode}
447492
return [node.seqindex for node in getleaflist(tree)]

0 commit comments

Comments
 (0)