Algorithm: Sum of distances between every two nodes of a Binary Search Tree in O(n)?

面向向阳花 2021-02-14 09:05

The question is to find the sum of distances between every two nodes of a binary search tree, given that every parent-child pair is separated by unit distance. It is to be calculated

4 answers
  •  无人共我
    2021-02-14 09:35

    Yes, you can compute the sum of distances between every two nodes of the whole tree with dynamic programming (DP) in O(n). Briefly, you need three quantities for every node i:

    cnt[i] is the number of nodes in the subtree rooted at node i
    dis[i] is the sum of the distances from every node of i's subtree to node i
    ret[i] is the sum of the distances between every two nodes of i's subtree
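    To pin down what ret[root] means, here is a brute-force reference sketch. It is not part of the original answer. It runs a BFS from every node, adds up all distances, and halves the total because each pair is counted from both ends. It is O(n^2), so it is only useful for checking the DP on small trees. The class BruteForceDistance is hypothetical. It assumes the same child-array convention as the code further down: left[i]/right[i] hold child indices, with -1 when a child is absent.

    ```java
    import java.util.ArrayDeque;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    class BruteForceDistance {
        // Hypothetical helper: the tree is given as child arrays (-1 = no child).
        static long sumAllPairs(int[] left, int[] right) {
            int n = left.length;
            // Build undirected adjacency lists from the parent -> child links.
            List<List<Integer>> adj = new ArrayList<>();
            for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
            for (int i = 0; i < n; i++) {
                for (int c : new int[] {left[i], right[i]}) {
                    if (c != -1) {
                        adj.get(i).add(c);
                        adj.get(c).add(i);
                    }
                }
            }
            long total = 0;
            int[] dist = new int[n];
            for (int s = 0; s < n; s++) {
                // BFS from s: every edge has unit length.
                Arrays.fill(dist, -1);
                ArrayDeque<Integer> queue = new ArrayDeque<>();
                dist[s] = 0;
                queue.add(s);
                while (!queue.isEmpty()) {
                    int u = queue.poll();
                    total += dist[u];
                    for (int v : adj.get(u)) {
                        if (dist[v] == -1) {
                            dist[v] = dist[u] + 1;
                            queue.add(v);
                        }
                    }
                }
            }
            return total / 2; // each unordered pair was counted once from each end
        }
    }
    ```

    For the five-node tree bst2 from the code below (left = {1, 3, -1, -1, -1}, right = {2, 4, -1, -1, -1}), it should return 18.
    
    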
    

    Notice that ret[root] is the answer to the problem, so it is enough to compute every ret[i] correctly. How do we compute ret[i]? Recursively, with the help of cnt[i] and dis[i]. The key step is:

    Given ret[left], ret[right], dis[left], dis[right], cnt[left] and cnt[right], compute ret[node], dis[node] and cnt[node].

                  (node)
              /             \
        (left-subtree) (right subtree)
          /                   \
    ...(node x_i) ...   ...(node y_i)...
    Important: x_i is any node in the left subtree (not only a leaf!),
    and y_i is any node in the right subtree (not only a leaf either!).
    

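    cnt[node] is easy: it equals cnt[left] + cnt[right] + 1.

    dis[node] is not much harder: it equals dis[left] + dis[right] + cnt[left] + cnt[right]. Reason: sigma(x_i -> left) is dis[left], and every x_i is one edge farther from node than from left, so sigma(x_i -> node) is dis[left] + cnt[left]. The right subtree works the same way.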
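    The two recurrences for cnt and dis can be checked on their own. Below is a minimal sketch that fills cnt[] and dis[] bottom-up. The class SubtreeStats is hypothetical and uses the same child-array convention (-1 for a missing child, root at index 0). For the five-node example tree it should reproduce cnt[] = 5 3 1 1 1 and dis[] = 6 2 0 0 0, the values the answer lists after its code.

    ```java
    class SubtreeStats {
        final int[] cnt;   // cnt[v]: number of nodes in v's subtree
        final long[] dis;  // dis[v]: sum of distances from v's subtree nodes to v

        // Hypothetical helper: left[i]/right[i] are child indices, -1 when absent; root is 0.
        SubtreeStats(int[] left, int[] right) {
            int n = left.length;
            cnt = new int[n];
            dis = new long[n];
            fill(0, left, right);
        }

        private void fill(int v, int[] left, int[] right) {
            cnt[v] = 1; // the node itself
            dis[v] = 0; // distance from v to itself
            for (int c : new int[] {left[v], right[v]}) {
                if (c == -1) continue;
                fill(c, left, right);
                cnt[v] += cnt[c];
                // every node of c's subtree is one edge farther from v than from c
                dis[v] += dis[c] + cnt[c];
            }
        }
    }
    ```
    
    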

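    ret[node] is the sum of three parts:

    1. x_i -> x_j and y_i -> y_j (both endpoints in the same subtree): ret[left] + ret[right].
    2. x_i -> node and y_i -> node (one endpoint is node itself): dis[node].
    3. x_i -> y_j (one endpoint in each subtree):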

    Every such path passes through node, so this equals sigma(x_i -> node -> y_j). Fix x_i and sum over all y_j. That gives cnt[right]*distance(x_i,node) + sigma(node->y_j), which is cnt[right]*distance(x_i,node) + sigma(node->right->y_j),

    that is, cnt[right]*distance(x_i,node) + cnt[right] + dis[right].

    Now sum over x_i, using sigma(distance(x_i,node)) = dis[left] + cnt[left]: cnt[left]*(cnt[right]+dis[right]) + cnt[right]*(cnt[left] + dis[left]), which is 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].
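    A quick numeric check of part 3 uses tree bst2 from the code below, at node 0. The left subtree is {1, 3, 4}: cnt[left] = 3, dis[left] = 2, and those nodes are 1, 2 and 2 edges from node 0. The right subtree is {2}: cnt[right] = 1, dis[right] = 0. The cross distances listed in the code comments are 1 -> 2: 2, 3 -> 2: 3 and 4 -> 2: 3, which sum to 8. This sketch evaluates the per-x_i form and the closed form; both should give 8. The class CrossTerm and its parameters are hypothetical names chosen for illustration.

    ```java
    class CrossTerm {
        // Closed form of part 3: sum of d(x, y) over x in the left subtree, y in the right one.
        static long closedForm(long cntL, long disL, long cntR, long disR) {
            return 2 * cntL * cntR + disL * cntR + disR * cntL;
        }

        // The same quantity summed one x_i at a time, as in the derivation:
        // for a fixed x_i the contribution is cntR * d(x_i, node) + cntR + disR.
        // depthsL[k] is the distance from the k-th left-subtree node to `node`.
        static long perNode(int[] depthsL, long cntR, long disR) {
            long sum = 0;
            for (int d : depthsL) {
                sum += cntR * d + cntR + disR;
            }
            return sum;
        }
    }
    ```
    
    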

    Summing the three parts gives ret[node]. If children are computed before their parents (post-order), each node is combined exactly once, so ret[root] is obtained in O(n) time.
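    The whole combination step can be written as one pure function on per-subtree summaries. This sketch assumes a hypothetical immutable class Summary holding (cnt, dis, ret). If a missing child is passed as the empty summary (0, 0, 0), the leaf and one-child branches of the code below become special cases of the same formula. For example, combining a leaf with an empty child gives cnt = 2, dis = 1, ret = 1.

    ```java
    final class Summary {
        final long cnt; // nodes in the subtree
        final long dis; // sum of distances from the subtree's nodes to its root
        final long ret; // sum of distances over all pairs inside the subtree

        Summary(long cnt, long dis, long ret) {
            this.cnt = cnt;
            this.dis = dis;
            this.ret = ret;
        }

        // Summary of a missing child; leaves and one-child nodes need no special case.
        static final Summary EMPTY = new Summary(0, 0, 0);

        static Summary combine(Summary l, Summary r) {
            long cnt = l.cnt + r.cnt + 1;
            long dis = l.dis + r.dis + l.cnt + r.cnt;
            long ret = l.ret + r.ret                    // part 1: pairs inside one subtree
                     + dis                              // part 2: pairs (x, node)
                     + 2 * l.cnt * r.cnt                // part 3: cross pairs ...
                     + l.dis * r.cnt + r.dis * l.cnt;   // ... continued
            return new Summary(cnt, dis, ret);
        }
    }
    ```
    
    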

    My code:

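    import java.util.Arrays;
    
    public class BSTDistance {
        int[] left;   // left[i]: index of node i's left child, or -1
        int[] right;  // right[i]: index of node i's right child, or -1
        int[] cnt;    // cnt[i]: number of nodes in i's subtree
        long[] ret;   // ret[i]: sum of distances over all pairs of nodes in i's subtree
        long[] dis;   // dis[i]: sum of distances from every node in i's subtree to i
        int nNode;
        // dis/ret are long: on a path-shaped (degenerate) BST the total grows like n^3/6
        // and overflows int once n reaches a few thousand nodes
        public BSTDistance(int n) { // n is the number of nodes
            left = new int[n];
            right = new int[n];
            cnt = new int[n];
            ret = new long[n];
            dis = new long[n];
            Arrays.fill(left, -1);
            Arrays.fill(right, -1);
            nNode = n;
        }
        // Make node a a child of node b: the first call for b sets its left child,
        // the second sets its right child.
        void add(int a, int b)
        {
            if (left[b] == -1)
            {
                left[b] = a;
            }
            else
            {
                right[b] = a;
            }
        }
        long cal()
        {
            _cal(0); // assume the root's index is 0
            return ret[0];
        }
        // Post-order DP: compute the children first, then combine them.
        // The recursion is as deep as the tree is tall, so a very skewed BST
        // with many nodes can throw StackOverflowError.
        void _cal(int idx)
        {
            if (left[idx] == -1 && right[idx] == -1)
            {
                // leaf: a single node, no pairs
                cnt[idx] = 1;
                dis[idx] = 0;
                ret[idx] = 0;
            }
            else if (left[idx] != -1 && right[idx] == -1)
            {
                // only a left child: the cross-subtree term (part 3) is empty
                _cal(left[idx]);
                cnt[idx] = cnt[left[idx]] + 1;
                dis[idx] = dis[left[idx]] + cnt[left[idx]];
                ret[idx] = ret[left[idx]] + dis[idx];
            } // left[idx] == -1 && right[idx] != -1 is impossible; add(int, int) guarantees it
            else
            {
                _cal(left[idx]);
                _cal(right[idx]);
                int l = left[idx], r = right[idx];
                cnt[idx] = cnt[l] + 1 + cnt[r];
                dis[idx] = dis[l] + dis[r] + cnt[l] + cnt[r];
                // part 2 + part 1 + part 3 (2L forces long arithmetic before the product can overflow)
                ret[idx] = dis[idx] + ret[l] + ret[r]
                        + 2L * cnt[l] * cnt[r] + dis[l] * cnt[r] + dis[r] * cnt[l];
            }
        }
        public static void main(String[] args)
        {
            BSTDistance bst1 = new BSTDistance(3);
            bst1.add(1, 0);
            bst1.add(2, 0);
            //   (0)
            //  /   \
            //(1)   (2)
            System.out.println(bst1.cal());
            BSTDistance bst2 = new BSTDistance(5);
            bst2.add(1, 0);
            bst2.add(2, 0);
            bst2.add(3, 1);
            bst2.add(4, 1);
            //       (0)
            //      /   \
            //    (1)   (2)
            //   /   \
            // (3)   (4)
            //0 -> 1:1
            //0 -> 2:1
            //0 -> 3:2
            //0 -> 4:2
            //1 -> 2:2
            //1 -> 3:1
            //1 -> 4:1
            //2 -> 3:3
            //2 -> 4:3
            //3 -> 4:2
            //2*4+3*2+1*4=18
            System.out.println(bst2.cal());
        }
    }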
    

    output:

    4
    18
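    One practical caveat: a binary search tree built from sorted keys is a path, so the recursive _cal() goes n levels deep and can throw StackOverflowError for large n. Below is a sketch of the same DP with an explicit stack instead of recursion. The class IterativeBSTDistance is hypothetical. It assumes the same left[]/right[] child-array convention, root at index 0 and every node reachable from the root. It uses long throughout to avoid overflow.

    ```java
    import java.util.ArrayDeque;

    class IterativeBSTDistance {
        // left[i]/right[i]: child indices or -1; root is node 0 and all nodes are reachable from it.
        static long sumOfDistances(int[] left, int[] right) {
            int n = left.length;
            if (n == 0) return 0;
            // Pre-order with an explicit stack; read backwards, it lists every child before its parent.
            int[] order = new int[n];
            int k = 0;
            ArrayDeque<Integer> stack = new ArrayDeque<>();
            stack.push(0);
            while (!stack.isEmpty()) {
                int v = stack.pop();
                order[k++] = v;
                if (left[v] != -1) stack.push(left[v]);
                if (right[v] != -1) stack.push(right[v]);
            }
            long[] cnt = new long[n], dis = new long[n], ret = new long[n];
            for (int i = n - 1; i >= 0; i--) {
                int v = order[i];
                int l = left[v], r = right[v];
                // a missing child contributes cnt = dis = ret = 0
                long cntL = l == -1 ? 0 : cnt[l], disL = l == -1 ? 0 : dis[l], retL = l == -1 ? 0 : ret[l];
                long cntR = r == -1 ? 0 : cnt[r], disR = r == -1 ? 0 : dis[r], retR = r == -1 ? 0 : ret[r];
                cnt[v] = cntL + cntR + 1;
                dis[v] = disL + disR + cntL + cntR;
                ret[v] = dis[v] + retL + retR + 2 * cntL * cntR + disL * cntR + disR * cntL;
            }
            return ret[0];
        }
    }
    ```

    On a path of n nodes the answer is n(n^2 - 1)/6, which exceeds the int range for a few thousand nodes. That is why long is used.
    
    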
    

    To help readers follow the solution, here are the values of cnt[], dis[] and ret[] after bst2.cal() is called:

    node   0  1  2  3  4
    cnt[]  5  3  1  1  1
    dis[]  6  2  0  0  0
    ret[] 18  4  0  0  0
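    The cnt[] row also allows an independent cross-check through a different, well-known counting argument, not the one used above. The edge between a non-root node v and its parent lies on the path of exactly cnt[v] * (n - cnt[v]) pairs. Summing that product over all non-root nodes therefore gives the same total. For bst2 this is 3*2 + 1*4 + 1*4 + 1*4 = 18. The class EdgeContribution is a hypothetical sketch that assumes the root is index 0, so cnt[0] = n.

    ```java
    class EdgeContribution {
        // cnt[v] = subtree size of v; the root (index 0) has no parent edge and is skipped.
        static long sumByEdges(int[] cnt) {
            int n = cnt[0];
            long total = 0;
            for (int v = 1; v < cnt.length; v++) {
                // pairs with one endpoint inside v's subtree and one outside cross the edge (v, parent)
                total += (long) cnt[v] * (n - cnt[v]);
            }
            return total;
        }
    }
    ```
    
    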
    

    PS: This solution is from UESTC_elfness, for whom it is a simple problem. I'm sayakiss, and it is not very hard for me either.

    So you can trust us...
