The question is to find the sum of the distances between every pair of nodes of a binary search tree, given that every parent-child pair is separated by unit distance. It is to be calculated
Yes, you can find the sum of the distances between every pair of nodes in the whole tree by DP in O(n). Briefly, you need to track three things for each node i:
cnt[i] is the number of nodes in the subtree rooted at node i
dis[i] is the sum of the distances from every node in node i's subtree to node i
ret[i] is the sum of the distances between every pair of nodes in node i's subtree
Notice that ret[root] is the answer to the problem, so just calculate ret[i] correctly and the problem is solved.
How to calculate ret[i]? Use cnt[i] and dis[i], and solve it recursively.
The key problem is:
Given ret[left], ret[right], dis[left], dis[right], cnt[left] and cnt[right], calculate ret[node], dis[node] and cnt[node].
              (node)
             /      \
(left-subtree)      (right subtree)
      /                    \
...(node x_i)...     ...(node y_i)...
Important: x_i is any node in the left subtree (not only a leaf!) and y_i is any node in the right subtree (not only a leaf either!).
cnt[node] is easy: it just equals cnt[left] + cnt[right] + 1.
dis[node] is not so hard: it equals dis[left] + dis[right] + cnt[left] + cnt[right]. Reason: sigma(x_i -> left) is dis[left], and every x_i is one edge farther from node than from left, so sigma(x_i -> node) is dis[left] + cnt[left]. The same holds for the right subtree.
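The two recurrences above can be sketched on their own, before ret[] enters the picture. This is only an illustration: the class name SubtreeStats and the method fill are made up here, the arrays are long rather than int as a precaution against overflow, and the tree is the five-node example 0 -> (1, 2), 1 -> (3, 4).

```java
// Hypothetical helper for illustration; not part of the original answer's code.
class SubtreeStats {
    // left[i] / right[i] hold child indices, or -1 when the child is absent.
    static void fill(int idx, int[] left, int[] right, long[] cnt, long[] dis) {
        cnt[idx] = 1; // the node itself
        dis[idx] = 0;
        for (int child : new int[] {left[idx], right[idx]}) {
            if (child == -1) continue;
            fill(child, left, right, cnt, dis);
            cnt[idx] += cnt[child];
            // every node in the child's subtree is one edge farther from idx
            dis[idx] += dis[child] + cnt[child];
        }
    }

    public static void main(String[] args) {
        int[] left  = {1, 3, -1, -1, -1};
        int[] right = {2, 4, -1, -1, -1};
        long[] cnt = new long[5];
        long[] dis = new long[5];
        fill(0, left, right, cnt, dis);
        System.out.println(cnt[0] + " " + dis[0]); // prints "5 6"
        System.out.println(cnt[1] + " " + dis[1]); // prints "3 2"
    }
}
```

Unlike the add(int,int) convention in the full code, this sketch does not assume that a lone child is always the left one.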
ret[node] is the sum of three parts:
First, ret[left] + ret[right]: the pairs that lie entirely inside one subtree.
Second, dis[node]: the pairs that have node itself as one endpoint.
Third, sigma(x_i -> node -> y_j): the pairs with one endpoint in each subtree. Fix x_i; then we get cnt[right]*distance(x_i,node) + sigma(node -> y_j), which is cnt[right]*distance(x_i,node) + sigma(node -> right -> y_j), and that is cnt[right]*distance(x_i,node) + cnt[right] + dis[right].
Summing over x_i: cnt[right]*(cnt[left] + dis[left]) + cnt[left]*(cnt[right] + dis[right]), which is 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].
Sum these three parts and we get ret[i]. Do it recursively and we get ret[root].
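To make the three parts concrete, here is a minimal sketch of the combine step as a single function. The class Combine, the method combine and the {cnt, dis, ret} array layout are assumptions made for this illustration, and an absent child is represented as {0, 0, 0}, which is not how the full code handles it.

```java
// Hypothetical helper for illustration; names and layout are not from the original code.
class Combine {
    // Each stats array is {cnt, dis, ret}; an absent child is {0, 0, 0}.
    static long[] combine(long[] l, long[] r) {
        long cnt = l[0] + r[0] + 1;
        long dis = l[1] + r[1] + l[0] + r[0];
        // pairs (x_i, y_j) that cross node:
        // cnt[right]*(cnt[left] + dis[left]) + cnt[left]*(cnt[right] + dis[right])
        long cross = r[0] * (l[0] + l[1]) + l[0] * (r[0] + r[1]);
        long ret = l[2] + r[2] + dis + cross;
        return new long[] {cnt, dis, ret};
    }

    public static void main(String[] args) {
        long[] none = {0, 0, 0};
        long[] leaf = combine(none, none);  // {1, 0, 0}
        long[] node1 = combine(leaf, leaf); // {3, 2, 4}
        long[] root = combine(node1, leaf); // {5, 6, 18}
        System.out.println(java.util.Arrays.toString(root)); // prints [5, 6, 18]
    }
}
```

With an empty right child the cross term vanishes and ret reduces to ret[left] + dis[node], which is the one-child case.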
My code:
import java.util.Arrays;

public class BSTDistance {
    int[] left;   // left[i]: index of node i's left child, -1 if none
    int[] right;  // right[i]: index of node i's right child, -1 if none
    int[] cnt;    // cnt[i]: number of nodes in node i's subtree
    int[] ret;    // ret[i]: sum of distances between every pair of nodes in node i's subtree
    int[] dis;    // dis[i]: sum of distances from every node in node i's subtree to node i
    int nNode;

    public BSTDistance(int n) {// n is the number of nodes
        left = new int[n];
        right = new int[n];
        cnt = new int[n];
        ret = new int[n];
        dis = new int[n];
        Arrays.fill(left, -1);
        Arrays.fill(right, -1);
        nNode = n;
    }

    // Attach node a as a child of node b; the left slot is filled first.
    void add(int a, int b)
    {
        if (left[b] == -1)
        {
            left[b] = a;
        }
        else
        {
            right[b] = a;
        }
    }

    int cal()
    {
        _cal(0);//assume root's idx is 0
        return ret[0];
    }

    void _cal(int idx)
    {
        if (left[idx] == -1 && right[idx] == -1)
        {
            // leaf
            cnt[idx] = 1;
            dis[idx] = 0;
            ret[idx] = 0;
        }
        else if (left[idx] != -1 && right[idx] == -1)
        {
            // only a left child
            _cal(left[idx]);
            cnt[idx] = cnt[left[idx]] + 1;
            dis[idx] = dis[left[idx]] + cnt[left[idx]];
            ret[idx] = ret[left[idx]] + dis[idx];
        }//left[idx] == -1 and right[idx] != -1 is impossible, guaranteed by add(int,int)
        else
        {
            // both children
            _cal(left[idx]);
            _cal(right[idx]);
            cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
            dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
            ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]]
                    + 2*cnt[left[idx]]*cnt[right[idx]]
                    + dis[left[idx]]*cnt[right[idx]]
                    + dis[right[idx]]*cnt[left[idx]];
        }
    }

    public static void main(String[] args)
    {
        BSTDistance bst1 = new BSTDistance(3);
        bst1.add(1, 0);
        bst1.add(2, 0);
        //   (0)
        //   / \
        // (1) (2)
        System.out.println(bst1.cal());

        BSTDistance bst2 = new BSTDistance(5);
        bst2.add(1, 0);
        bst2.add(2, 0);
        bst2.add(3, 1);
        bst2.add(4, 1);
        //       (0)
        //       / \
        //     (1) (2)
        //     / \
        //   (3) (4)
        //0 -> 1:1
        //0 -> 2:1
        //0 -> 3:2
        //0 -> 4:2
        //1 -> 2:2
        //1 -> 3:1
        //1 -> 4:1
        //2 -> 3:3
        //2 -> 4:3
        //3 -> 4:2
        //2*4+3*2+1*4=18
        System.out.println(bst2.cal());
    }
}
output:
4
18
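As a sanity check on those two outputs, a brute-force count can be run against the same trees. The sketch below is not the DP solution: it runs a BFS from every node, which is O(n^2), and the class name BruteForceCheck and the parent-array input format are assumptions made for this illustration.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical O(n^2) cross-check; not part of the original answer's code.
class BruteForceCheck {
    // parent[i] is the parent index of node i, or -1 for the root.
    static long sumOfDistances(int[] parent) {
        int n = parent.length;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < n; i++) {
            if (parent[i] != -1) {
                adj.get(i).add(parent[i]);
                adj.get(parent[i]).add(i);
            }
        }
        long total = 0;
        for (int src = 0; src < n; src++) {
            // plain BFS from src over the undirected tree edges
            int[] dist = new int[n];
            Arrays.fill(dist, -1);
            dist[src] = 0;
            ArrayDeque<Integer> queue = new ArrayDeque<>();
            queue.add(src);
            while (!queue.isEmpty()) {
                int u = queue.poll();
                for (int v : adj.get(u)) {
                    if (dist[v] == -1) {
                        dist[v] = dist[u] + 1;
                        queue.add(v);
                    }
                }
            }
            for (int d : dist) total += d;
        }
        return total / 2; // each unordered pair was counted from both ends
    }

    public static void main(String[] args) {
        System.out.println(sumOfDistances(new int[] {-1, 0, 0}));       // prints 4
        System.out.println(sumOfDistances(new int[] {-1, 0, 0, 1, 1})); // prints 18
    }
}
```

The two parent arrays describe the same shapes as bst1 and bst2, so matching results suggest the recurrences are at least consistent on these inputs.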
For the convenience of readers trying to understand my solution, here are the values of cnt[], dis[] and ret[] after bst2.cal() is called:
idx     0  1  2  3  4
cnt[]   5  3  1  1  1
dis[]   6  2  0  0  0
ret[]  18  4  0  0  0
PS: This is the solution from UESTC_elfness; it's a simple problem for him. I'm sayakiss, and the problem is not so hard for me.
So you can trust us...