-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path09-VirtualTree.java
128 lines (96 loc) · 3.71 KB
/
09-VirtualTree.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package graphs.advanced;
import graphs.Graph;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Virtual (auxiliary) tree over a fixed tree rooted at vertex 0, answering queries of the type: <br>
 * Given a list of K vertices from the tree, compute the sum of distances over all unordered pairs
 * of distinct vertices in that list.<br>
 * Each query sorts O(K) vertices by DFS entry time (O(K log K)) and performs O(K) LCA / depth
 * lookups on the underlying {@link LowestCommonAncestor}.
 *
 * <p>Distances are differences of {@code lca.getDepth(...)} values. NOTE(review): whether those are
 * edge counts or weighted depths depends on LowestCommonAncestor, and it is assumed to root the tree
 * at vertex 0 as well — confirm.
 *
 * <p>Not safe for concurrent queries: {@link #sumDistance} uses per-instance scratch arrays, which
 * it resets before returning.
 */
@SuppressWarnings({"unused", "unchecked"})
public class VirtualTree {
    private final Graph tree;
    /** DFS entry timestamp of every vertex (preorder position among entry/exit events). */
    private final int[] timeIn;
    /** DFS exit timestamp of every vertex. */
    private final int[] timeOut;
    /** Scratch: number of query vertices in each virtual-tree subtree during a query. */
    private final int[] subtreeSize;
    /** Scratch: marks the vertices of the current query. */
    private final boolean[] importantNodes;
    /** Scratch: children lists of the virtual tree built for the current query. */
    private final List<Integer>[] virtualTreeIndex;
    private final LowestCommonAncestor lca;

    /**
     * Preprocesses {@code t} for virtual-tree queries.
     *
     * @param t the tree; it is rooted at vertex 0 and assumed connected (vertices unreachable from
     *     0 never receive visit times)
     */
    public VirtualTree(Graph t) {
        tree = t;
        timeIn = new int[tree.vertices];
        timeOut = new int[tree.vertices];
        subtreeSize = new int[tree.vertices];
        importantNodes = new boolean[tree.vertices];
        lca = new LowestCommonAncestor(tree);
        virtualTreeIndex = (ArrayList<Integer>[]) new ArrayList[tree.vertices];
        for (int i = 0; i < tree.vertices; i++) {
            virtualTreeIndex[i] = new ArrayList<>();
        }
        if (tree.vertices > 0) {
            // The returned next-free timestamp is not needed once the whole tree is numbered.
            assignVisitTimes(0, -1, 0);
        }
    }

    /**
     * Returns the sum of distances between all unordered pairs of distinct vertices in
     * {@code nodes}.
     *
     * <p>Duplicate entries are counted once. The caller's list is not modified, so unmodifiable
     * lists are accepted.
     *
     * @param nodes vertices of the tree (indices in {@code [0, tree.vertices)})
     * @return the pairwise distance sum; 0 when fewer than two distinct vertices are given
     */
    public long sumDistance(List<Integer> nodes) {
        if (nodes.isEmpty()) {
            return 0;
        }
        // Mark the query vertices and count the distinct ones: this is the K used for edge weights.
        int importantCount = 0;
        for (int node : nodes) {
            if (!importantNodes[node]) {
                importantNodes[node] = true;
                importantCount++;
            }
        }
        List<Integer> virtualNodes = build(nodes);
        long result = findSum(virtualNodes.get(0), -1, importantCount);
        // Reset scratch state for every virtual-tree vertex (query vertices and the added LCAs).
        for (int node : virtualNodes) {
            importantNodes[node] = false;
            virtualTreeIndex[node].clear();
            subtreeSize[node] = 0;
        }
        return result;
    }

    /**
     * Builds the virtual tree of {@code nodes} into {@code virtualTreeIndex}.
     *
     * @param nodes non-empty list of query vertices; not modified
     * @return the distinct virtual-tree vertices (the inputs plus the LCAs of DFS-adjacent inputs),
     *     sorted by entry time; element 0 is the virtual-tree root
     */
    private List<Integer> build(List<Integer> nodes) {
        Comparator<Integer> byEntryTime = Comparator.comparingInt(node -> timeIn[node]);
        List<Integer> candidates = new ArrayList<>(nodes);
        candidates.sort(byEntryTime);
        int size = candidates.size();
        // Adding the LCA of every DFS-adjacent pair makes the vertex set closed under LCA.
        for (int i = 1; i < size; i++) {
            candidates.add(lca.getLowestCommonAncestor(candidates.get(i - 1), candidates.get(i)));
        }
        candidates.sort(byEntryTime);
        List<Integer> virtualNodes = candidates.stream().distinct().collect(Collectors.toList());

        // stack[0..top] is a chain in which each vertex is an ancestor of the next; a vertex that
        // is popped gets linked as a child of the vertex directly below it.
        int[] stack = new int[virtualNodes.size()];
        int top = 0;
        stack[0] = virtualNodes.get(0);
        for (int i = 1; i < virtualNodes.size(); i++) {
            int node = virtualNodes.get(i);
            while (top > 0 && !upper(stack[top], node)) {
                virtualTreeIndex[stack[top - 1]].add(stack[top]);
                top--;
            }
            stack[++top] = node;
        }
        while (top > 0) {
            virtualTreeIndex[stack[top - 1]].add(stack[top]);
            top--;
        }
        return virtualNodes;
    }

    /**
     * Fills {@code subtreeSize} for the virtual subtree of {@code node} and returns the summed
     * contribution of every virtual edge inside it, plus the edge to {@code parent}.
     *
     * <p>An edge of length d with s marked vertices below it lies on the path of s * (total - s)
     * pairs, so it contributes d * s * (total - s). Recursive over the virtual tree.
     *
     * @param node current virtual-tree vertex
     * @param parent its virtual-tree parent, or -1 for the root
     * @param total number of distinct marked vertices in the query
     * @return the sum of the edge contributions described above
     */
    private long findSum(int node, int parent, int total) {
        long result = 0;
        subtreeSize[node] = importantNodes[node] ? 1 : 0;
        for (int child : virtualTreeIndex[node]) {
            result += findSum(child, node, total);
            subtreeSize[node] += subtreeSize[child];
        }
        if (parent != -1) {
            // Widen to long before multiplying: the product overflows int for large K or depths.
            long distance = lca.getDepth(node) - lca.getDepth(parent);
            result += distance * (total - subtreeSize[node]) * subtreeSize[node];
        }
        return result;
    }

    /**
     * Assigns DFS entry/exit timestamps to every vertex in the subtree of {@code node}.
     *
     * <p>Recursive: call depth equals the tree height, so very deep trees may exhaust the call stack.
     *
     * @param node vertex to number
     * @param parent vertex we arrived from, or -1 at the root
     * @param time next unused timestamp
     * @return next unused timestamp after the whole subtree has been numbered
     */
    private int assignVisitTimes(int node, int parent, int time) {
        timeIn[node] = time++;
        for (Graph.ReachableVertex reachableVertex : tree.adjacencyList.get(node)) {
            if (reachableVertex.to != parent) {
                // Thread the counter through the recursion so siblings get disjoint intervals.
                time = assignVisitTimes(reachableVertex.to, node, time);
            }
        }
        timeOut[node] = time++;
        return time;
    }

    /**
     * Returns whether {@code node1} is a proper ancestor of {@code node2}, i.e. its DFS interval
     * strictly contains that of {@code node2}.
     */
    private boolean upper(int node1, int node2) {
        return timeIn[node1] < timeIn[node2] && timeOut[node1] > timeOut[node2];
    }
}