The question is to find the sum of distances between every pair of nodes in a BinarySearchTree, given that every parent-child pair is separated by unit distance. It is to be calculated
First, add four variables to every node: the sum of distances to the nodes in its left subtree, the sum of distances to the nodes in its right subtree, the number of nodes in its left subtree, and the number of nodes in its right subtree. Denote them l, r, nl and nr.
Second, add a total variable to the root node to record the sum after each insertion.
The idea is that if you have the total for the current tree, the new total after inserting a node is (old total + sum of distances from the new node to all other nodes). So all you need to calculate on each insertion is the sum of distances from the new node to all other nodes.
1- Insert the new node with its four variables set to zero.
2- Create two temporary counters, "node travel" and "subtotal", both initially zero.
3- Trace the route back from the new node to the root. At each step:
   a- go up to the parent node
   b- add one to node travel
   c- add node travel to subtotal
   d- add (nr * node travel) + r to subtotal if the new node is in the left subtree
   e- add node travel to l
   f- add one to nl
   (If the new node is in the right subtree instead, use nl and l in step d, and update r and nr in steps e and f.)
4- Add subtotal to total.
Time complexity of each step:
1 - O(n) in the worst case
2 - O(1)
3 - O(log n) for a balanced tree (O(n) in the worst case); a to f take O(1) each
4 - O(1)
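The insertion procedure above can be sketched in Java roughly as follows. This is a minimal sketch, not the answerer's own code: the class name IncrementalDistanceSum, the Node type with a parent pointer, and the assumption that keys are distinct are all illustrative choices. The fields l, r, nl and nr mirror the four variables described above.

```java
class IncrementalDistanceSum {
    // Hypothetical node type; l, r, nl and nr follow the names used in the text.
    static final class Node {
        final int key;
        Node left, right, parent;
        long l, r;  // sum of distances from this node to the nodes in its left / right subtree
        int nl, nr; // number of nodes in its left / right subtree
        Node(int key) { this.key = key; }
    }

    private Node root;
    private long total; // sum of distances over all unordered pairs of nodes

    /** Inserts key (assumed not to be present yet) and returns the updated total. */
    long insert(int key) {
        Node fresh = new Node(key);
        if (root == null) {
            root = fresh;
            return total;
        }
        // Step 1: ordinary BST insertion.
        Node cur = root;
        while (true) {
            if (key < cur.key) {
                if (cur.left == null) { cur.left = fresh; break; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = fresh; break; }
                cur = cur.right;
            }
        }
        fresh.parent = cur;

        // Steps 2 and 3: walk back up to the root, accumulating distances.
        long travel = 0;
        long subtotal = 0;
        Node child = fresh;
        for (Node anc = fresh.parent; anc != null; child = anc, anc = anc.parent) {
            travel++;               // b
            subtotal += travel;     // c: distance to the ancestor itself
            if (anc.left == child) {
                subtotal += anc.nr * travel + anc.r; // d: nodes on the other side
                anc.l += travel;                     // e
                anc.nl++;                            // f
            } else {
                subtotal += anc.nl * travel + anc.l; // mirrored d
                anc.r += travel;                     // mirrored e
                anc.nr++;                            // mirrored f
            }
        }
        total += subtotal; // step 4
        return total;
    }

    public static void main(String[] args) {
        IncrementalDistanceSum tree = new IncrementalDistanceSum();
        long last = 0;
        for (int key : new int[] {4, 2, 5, 1, 3}) {
            last = tree.insert(key);
        }
        System.out.println(last); // prints 18
    }
}
```

Inserting 4, 2, 5, 1, 3 builds a root with two children whose left child has two children of its own, so the running totals are 0, 1, 4, 10 and 18.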
We can do this by traversing the tree twice.
First, we need three arrays:
int[] left, which stores the sum of the distances of the left subtree.
int[] right, which stores the sum of the distances of the right subtree.
int[] up, which stores the sum of the distances of the parent tree (without the current subtree).
In the first traversal, for each node, we calculate the left and right distances. If the node is a leaf, simply return 0; otherwise, we use this formula:
int cal(Node node) {
    if (node == null) {
        return 0; // an empty subtree contributes nothing
    }
    // Named leftSum/rightSum so they do not shadow the left[] and right[] arrays.
    int leftSum = cal(node.left);
    int rightSum = cal(node.right);
    left[node.index] = leftSum;
    right[node.index] = rightSum;
    // Depending on whether the current node has a left and/or right child, add 0, 1 or 2 to the result.
    int add = (node.left != null ? 1 : 0) + (node.right != null ? 1 : 0);
    return leftSum + rightSum + add;
}
Then, in the second traversal, we add to each node the total distance from its parent.
      1
     / \
    2   3
   / \
  4   5
For example, for node 1 (root), the total distance is left[1] + right[1] + 2, and up[1] = 0. We add 2 because the root has both a left and a right subtree. The exact formula for it is:
int add = 0;
if (node.left != null)
    add++;
if (node.right != null)
    add++;
For node 2, the total distance is left[2] + right[2] + add + up[1] + right[1] + 1 + addRight, and up[2] = up[1] + right[1] + addRight. The 1 at the end of the formula is there because there is an edge from the current node to its parent. Here add denotes the additional distance for the current node, addLeft the additional distance if the parent node has a left subtree, and similarly addRight for the right subtree.
For node 3, the total distance is up[1] + left[1] + 1 + addLeft, and up[3] = up[1] + left[1] + addLeft.
For node 4, the total distance is up[2] + right[2] + 1 + addRight, and up[4] = up[2] + right[2] + addRight.
So, depending on whether the current node is a left or a right child, we update up accordingly.
The time complexity is O(n)
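A minimal sketch of the two-traversal idea in Java follows. It is an illustration under stated assumptions rather than this answer's exact formulas: it additionally tracks subtree sizes, because every node in a child's subtree is one edge further from the parent than from the child. The class name RerootingDistanceSum, the arrays size, down and all, and the convention that node indices run from 1 to n are all hypothetical. Here all[i] plays the role of the per-node total that the text builds from left, right and up.

```java
class RerootingDistanceSum {
    // Hypothetical node type; index is assumed to lie in 1..n as in the example tree.
    static final class Node {
        final int index;
        final Node left, right;
        Node(int index, Node left, Node right) {
            this.index = index;
            this.left = left;
            this.right = right;
        }
    }

    private final int n;
    private final int[] size;  // size[i]: number of nodes in the subtree rooted at i
    private final long[] down; // down[i]: sum of distances from i to the nodes in its subtree
    private final long[] all;  // all[i]: sum of distances from i to every node in the tree

    private RerootingDistanceSum(int n) {
        this.n = n;
        size = new int[n + 1];
        down = new long[n + 1];
        all = new long[n + 1];
    }

    // First traversal (post-order): fill size[] and down[] from the children.
    private void firstPass(Node node) {
        if (node == null) return;
        firstPass(node.left);
        firstPass(node.right);
        size[node.index] = 1;
        for (Node child : new Node[] {node.left, node.right}) {
            if (child == null) continue;
            size[node.index] += size[child.index];
            // Each node below the child is one edge further away from this node.
            down[node.index] += down[child.index] + size[child.index];
        }
    }

    // Second traversal (pre-order): derive all[] for a child from its parent.
    private void secondPass(Node node, Node parent) {
        if (node == null) return;
        if (parent == null) {
            all[node.index] = down[node.index];
        } else {
            // Moving from parent to node brings size[node] nodes one edge closer
            // and pushes the other n - size[node] nodes one edge further away.
            all[node.index] = all[parent.index] - size[node.index] + (n - size[node.index]);
        }
        secondPass(node.left, node);
        secondPass(node.right, node);
    }

    /** Sum of distances over all unordered pairs; each pair is counted twice in all[]. */
    static long sumOverAllPairs(Node root, int n) {
        RerootingDistanceSum s = new RerootingDistanceSum(n);
        s.firstPass(root);
        s.secondPass(root, null);
        long sum = 0;
        for (int i = 1; i <= n; i++) sum += s.all[i];
        return sum / 2;
    }

    public static void main(String[] args) {
        Node two = new Node(2, new Node(4, null, null), new Node(5, null, null));
        Node one = new Node(1, two, new Node(3, null, null));
        System.out.println(sumOverAllPairs(one, 5)); // prints 18
    }
}
```

For the five-node example tree, the per-node totals come out as 6, 5, 9, 8 and 8 for nodes 1 to 5, which sum to 36, or 18 once each pair is counted only once.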
Yes, you can find the sum of distances between every two nodes of the whole tree by DP in O(n). Briefly, you need to know three things:
cnt[i] is the node count of the i-th node's subtree
dis[i] is the sum of distances from every node in the i-th node's subtree to the i-th node
ret[i] is the sum of distances between every two nodes in the i-th node's subtree
Notice that ret[root] is the answer to the problem, so just calculate ret[i] correctly and the problem is solved.
How do we calculate ret[i]? With the help of cnt[i] and dis[i], solving it recursively.
The key problem is:
Given ret[left], ret[right], dis[left], dis[right], cnt[left] and cnt[right], calculate ret[node], dis[node] and cnt[node].
                 (node)
                /      \
    (left-subtree)    (right subtree)
       /                    \
...(node x_i)...      ...(node y_i)...
Important: x_i is any node in the left subtree (not only a leaf!), and y_i is any node in the right subtree (not only a leaf either!).
cnt[node] is easy: it just equals cnt[left] + cnt[right] + 1.
dis[node] is not so hard: it equals dis[left] + dis[right] + cnt[left] + cnt[right]. Reason: sigma(x_i -> left) is dis[left], so sigma(x_i -> node) is dis[left] + cnt[left].
ret[node] is the sum of three parts:
1. ret[left] + ret[right].
2. dis[node].
3. sigma(x_i -> node -> y_j). For a fixed x_i, this is cnt[right]*distance(x_i,node) + sigma(node -> y_j), that is cnt[right]*distance(x_i,node) + sigma(node -> right -> y_j), which is cnt[right]*distance(x_i,node) + cnt[right] + dis[right].
Summing over x_i gives cnt[left]*(cnt[right] + dis[right]) + cnt[right]*(cnt[left] + dis[left]), which is 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].
Sum these three parts to get ret[i]. Doing this recursively yields ret[root].
My code:
import java.util.Arrays;

public class BSTDistance {
    int[] left;  // left[i]: index of the left child of node i, or -1 if none
    int[] right; // right[i]: index of the right child of node i, or -1 if none
    int[] cnt;
    int[] ret;
    int[] dis;
    int nNode;

    public BSTDistance(int n) { // n is the number of nodes
        left = new int[n];
        right = new int[n];
        cnt = new int[n];
        ret = new int[n];
        dis = new int[n];
        Arrays.fill(left, -1);
        Arrays.fill(right, -1);
        nNode = n;
    }

    // Attach node a as a child of node b: the left slot is filled first, then the right.
    void add(int a, int b)
    {
        if (left[b] == -1)
        {
            left[b] = a;
        }
        else
        {
            right[b] = a;
        }
    }

    int cal()
    {
        _cal(0); // assume root's idx is 0
        return ret[0];
    }

    // Post-order recursion: fill cnt, dis and ret for idx from its children.
    void _cal(int idx)
    {
        if (left[idx] == -1 && right[idx] == -1)
        {
            cnt[idx] = 1;
            dis[idx] = 0;
            ret[idx] = 0;
        }
        else if (left[idx] != -1 && right[idx] == -1)
        {
            _cal(left[idx]);
            cnt[idx] = cnt[left[idx]] + 1;
            dis[idx] = dis[left[idx]] + cnt[left[idx]];
            ret[idx] = ret[left[idx]] + dis[idx];
        } // left[idx] == -1 and right[idx] != -1 is impossible, guaranteed by add(int,int)
        else
        {
            _cal(left[idx]);
            _cal(right[idx]);
            cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
            dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
            ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]]
                    + 2*cnt[left[idx]]*cnt[right[idx]]
                    + dis[left[idx]]*cnt[right[idx]]
                    + dis[right[idx]]*cnt[left[idx]];
        }
    }

    public static void main(String[] args)
    {
        BSTDistance bst1 = new BSTDistance(3);
        bst1.add(1, 0);
        bst1.add(2, 0);
        //   (0)
        //   / \
        // (1) (2)
        System.out.println(bst1.cal());

        BSTDistance bst2 = new BSTDistance(5);
        bst2.add(1, 0);
        bst2.add(2, 0);
        bst2.add(3, 1);
        bst2.add(4, 1);
        //       (0)
        //       / \
        //     (1) (2)
        //     / \
        //   (3) (4)
        // 0 -> 1: 1
        // 0 -> 2: 1
        // 0 -> 3: 2
        // 0 -> 4: 2
        // 1 -> 2: 2
        // 1 -> 3: 1
        // 1 -> 4: 1
        // 2 -> 3: 3
        // 2 -> 4: 3
        // 3 -> 4: 2
        // 2*4 + 3*2 + 1*4 = 18
        System.out.println(bst2.cal());
    }
}
output:
4
18
For the convenience of readers trying to understand my solution, here are the values of cnt[], dis[] and ret[] after bst2.cal() is called:
cnt[]   5  3  1  1  1
dis[]   6  2  0  0  0
ret[]  18  4  0  0  0
PS: This is the solution from UESTC_elfness; it's a simple problem for him. I'm sayakiss, and the problem is not so hard for me either.
So you can trust us...
If you mean O(n) per insertion, then it can be done, assuming you do it after each and every insertion, starting with the root.
0- Record the current sum of the distances. Call it s1: O(1).
1- Insert the new node: O(n).
2- Perform a BFS, starting at this new node.
   For each new node you discover, record its distance to the start (new) node, as BFS always does: O(n).
   This gives you an array of the distances from the start node to all other nodes.
3- Sum these distances up. Call this s2: O(n).
4- New_sum = s1 + s2: O(1).
This algorithm is thus O(n) per insertion.
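The steps above might look like this in Java. This is only a sketch under assumptions not stated in the answer: the class name BfsDistanceSum, the Node type with a parent pointer (so the BFS can move upwards as well as downwards), and distinct keys are all hypothetical choices made for illustration.

```java
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Map;

class BfsDistanceSum {
    // Hypothetical node type with a parent pointer so BFS can walk in every direction.
    static final class Node {
        final int key;
        Node left, right, parent;
        Node(int key) { this.key = key; }
    }

    private Node root;
    private long total; // current sum of distances over all unordered pairs

    /** Inserts key (assumed not to be present yet) and returns the updated sum. */
    long insert(int key) {
        long s1 = total;                     // step 0
        Node fresh = insertNode(key);        // step 1
        long s2 = sumOfDistancesFrom(fresh); // steps 2 and 3
        total = s1 + s2;                     // step 4
        return total;
    }

    // Plain BST insertion that also records the parent of the new node.
    private Node insertNode(int key) {
        Node fresh = new Node(key);
        if (root == null) {
            root = fresh;
            return fresh;
        }
        Node cur = root;
        while (true) {
            if (key < cur.key) {
                if (cur.left == null) { cur.left = fresh; break; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = fresh; break; }
                cur = cur.right;
            }
        }
        fresh.parent = cur;
        return fresh;
    }

    // BFS over parent and child links; returns the sum of distances from start.
    private static long sumOfDistancesFrom(Node start) {
        Map<Node, Integer> dist = new HashMap<>();
        ArrayDeque<Node> queue = new ArrayDeque<>();
        dist.put(start, 0);
        queue.add(start);
        long sum = 0;
        while (!queue.isEmpty()) {
            Node cur = queue.poll();
            int d = dist.get(cur);
            sum += d;
            for (Node next : new Node[] {cur.parent, cur.left, cur.right}) {
                if (next != null && !dist.containsKey(next)) {
                    dist.put(next, d + 1);
                    queue.add(next);
                }
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        BfsDistanceSum tree = new BfsDistanceSum();
        long last = 0;
        for (int key : new int[] {4, 2, 5, 1, 3}) {
            last = tree.insert(key);
        }
        System.out.println(last); // prints 18
    }
}
```

Each call to insert visits every node once in the BFS, so the per-insertion cost stays linear in the current size of the tree.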