Algorithm to print all paths with a given sum in a binary tree

后端 未结 18 1302
既然无缘
既然无缘 2020-12-24 07:04

The following is an interview question.

You are given a binary tree (not necessarily a BST) in which each node contains a value. Design an algorithm to print all paths whose node values sum to a given value.

相关标签:
18条回答
  • 2020-12-24 07:17
    /* print() is defined after printpath(); without this prototype C would
       implicitly declare it as returning int, which conflicts with the later
       void definition (and C++ rejects the call outright). */
    void print(int arr[], int end, int start);

    /*
     * Prints every downward path that ends at `root` and whose values add up to
     * `sum`, then recurses into both children.
     *
     * sum   - target path sum
     * arr   - scratch buffer holding the values from the tree root down to the
     *         current node; must have room for at least one slot per tree level
     *         (not checked here)
     * level - depth of `root` (0 for the tree root)
     * root  - current node; NULL ends the recursion
     */
    void printpath(int sum,int arr[],int level,struct node * root)
    {
      int remaining = sum;
      int i;
      if(root == NULL)
        return;
      arr[level]=root->data;
      /* Walk back toward the tree root, subtracting arr[level], arr[level-1], ...
         Each time the remainder hits zero, arr[i..level] is a matching path.
         The original only tested after the loop had finished, so it reported
         root-to-node paths only and missed every path starting lower down. */
      for(i=level;i>=0;i--)
      {
        remaining-=arr[i];
        if(remaining == 0)
          print(arr,level,i);
      }
      printpath(sum,arr,level+1,root->left);
      printpath(sum,arr,level+1,root->right);
    }
     /* Prints arr[start..end] (inclusive) on a single line, each value followed
        by a space, then terminates the line. Note the argument order: the end
        index comes before the start index. */
    void print(int arr[],int end,int start)
    {
        int k = start;
        while (k <= end)
            printf("%d ", arr[k++]);
        printf("\n");
    }
    

    Time complexity: O(n log n) for a balanced tree. Space complexity: O(n).

    0 讨论(0)
  • 2020-12-24 07:20

    Well, this is a tree, not a graph. So, you can do something like this:

    Pseudocode:

    global ResultList
    
    function ProcessNode(CurrentNode, CurrentSum)
        CurrentSum+=CurrentNode->Value
        if (CurrentSum==SumYouAreLookingFor) AddNodeTo ResultList
        for all Children of CurrentNode
              ProcessNode(Child,CurrentSum)
    

    Well, this gives you the paths that start at the root. However, you can just make a tiny change:

        for all Children of CurrentNode
              ProcessNode(Child,CurrentSum)
              ProcessNode(Child,0)
    

    You might need to think about it for a second (I'm busy with other things), but this should basically run the same algorithm rooted at every node in the tree

    EDIT: this actually gives the "end node" only. However, as this is a tree, you can just start at those end nodes and walk back up until you get the required sum.

    EDIT 2: and, of course, if all values are positive then you can abort the descent if your current sum is >= the required one

    0 讨论(0)
  • 2020-12-24 07:20
    # include<stdio.h>
    # include <stdlib.h>
    struct Node
    {
        int data;
        struct Node *left, *right;
    };
    
    /* Allocates a leaf node holding `item`, with both child links set to NULL.
       The returned node comes from malloc(), so the caller releases it with free().
       If the allocation fails, the program prints a message and exits instead
       of writing through a NULL pointer. */
    struct Node * newNode(int item)
    {
        struct Node *temp =  (struct Node *)malloc(sizeof(struct Node));
        if (temp == NULL) {
            /* The original dereferenced the result unconditionally: undefined
               behaviour on allocation failure. */
            fprintf(stderr, "newNode: out of memory\n");
            exit(EXIT_FAILURE);
        }
        temp->data = item;
        temp->left =  NULL;
        temp->right = NULL;
        return temp;
    }
    /* Prints p[t..level] (inclusive). Each value is preceded by a newline, so
       every element lands on its own line and no trailing newline is written. */
    void print(int p[], int level, int t){
        for (int idx = t; idx <= level; ++idx)
            printf("\n%d", p[idx]);
    }
    /* Prints every downward path that ends at `root` and sums to `da`, then
       recurses into both subtrees (pre-order).
       path  - values from the tree root down to the current node, indexed by depth
       level - depth of `root` (0 for the tree root)
       NOTE(review): `level` is never checked against the buffer size; the demo
       caller passes a 100-element array, so trees deeper than 100 levels would
       overflow it. */
    void check_paths_with_given_sum(struct Node * root, int da, int path[100], int level){
        if (root == NULL)
            return;

        path[level] = root->data;

        /* Grow the suffix path[start..level] one ancestor at a time and report
           each suffix whose sum equals the target. */
        int running = 0;
        for (int start = level; start >= 0; --start) {
            running += path[start];
            if (running == da)
                print(path, level, start);
        }

        check_paths_with_given_sum(root->left, da, path, level + 1);
        check_paths_with_given_sum(root->right, da, path, level + 1);
    }
    /* Demo driver. Builds a five-node tree with 10 at the root:
         - 2 is its left child, and 2 has a left child 1;
         - 4 is its right child, and 4 has a right child 5.
       It then prints every downward path whose values sum to 9; the only one
       is 4 -> 5. The nodes are not freed before exit. */
    int main(){
        struct Node *root = newNode(10);
        root->left = newNode(2);
        root->left->left = newNode(1);
        root->right = newNode(4);
        root->right->right = newNode(5);

        int path_buf[100]; /* one slot per depth level */
        check_paths_with_given_sum(root, 9, path_buf, 0);
        return 0;
    }
    

    This works.....

    0 讨论(0)
  • 2020-12-24 07:21

    Below is a recursive solution. We perform a pre-order traversal of the binary tree. As we move down a level, we sum up the path weight by adding the weight of the current level to the weights of the previous levels of the tree, and whenever we hit our target sum we print out the path. This solution handles cases where there may be more than one solution along any given path.

    Assume you have a binary tree rooted at root.

    #include <algorithm>
    #include <iostream>
    #include <vector>
    using namespace std;
    
    // Binary tree node holding an int value and non-owning pointers to its two
    // children (NULL when absent). The class never allocates or frees children;
    // whoever builds the tree owns the nodes.
    class Node
    {
    public:
        Node(const int value)
            : left(NULL), right(NULL), value(value)
        {
        }

        void setLeft(Node* left) { this->left = left; }
        void setRight(Node* right) { this->right = right; }

        Node* getLeft() const { return left; }
        Node* getRight() const { return right; }
        const int& getValue() const { return value; }

    private:
        Node* left;
        Node* right;
        int value;
    };
    
    // Returns the number of nodes on the longest root-to-leaf path (0 for an
    // empty tree). Because generatePaths stores one value per depth level,
    // this is the size its path buffer needs.
    int getMaxHeight(Node* root)
    {
        if (!root)
            return 0;

        return 1 + max(getMaxHeight(root->getLeft()),
                       getMaxHeight(root->getRight()));
    }
    
    // Found a matching path: writes paths[start..end] (inclusive) to stderr on
    // one line, each value followed by a space.
    // The buffer is only read, so it is taken by const reference; existing
    // callers that pass a mutable vector still bind to it.
    void printPaths(const vector<int>& paths, int start, int end)
    {
        for(int i = start; i<=end; i++)
            cerr << paths[i] << ' ';

        // std::cerr is unit-buffered, so a plain newline is written out
        // immediately; std::endl would only add a redundant flush.
        cerr << '\n';
    }
    
    // Pre-order walk. Records root's value at paths[depth], prints every downward
    // path that ends at root and sums to `sum`, then descends into both children.
    // `paths` must have at least as many slots as the tree has levels; the index
    // is not bounds-checked.
    void generatePaths(Node* root, vector<int>& paths, int depth, const int sum)
    {
        if (!root)
            return;  // empty subtree: nothing to record or report

        paths[depth] = root->getValue();

        // Extend the suffix paths[from..depth] one ancestor at a time and
        // report each suffix whose weights add up to the target.
        int suffixSum = 0;
        for (int from = depth; from >= 0; --from)
        {
            suffixSum += paths[from];
            if (suffixSum == sum)
                printPaths(paths, from, depth);
        }

        const int childDepth = depth + 1;
        generatePaths(root->getLeft(), paths, childDepth, sum);
        generatePaths(root->getRight(), paths, childDepth, sum);
    }
    
    int main(void)
    {
        vector<int> paths (getMaxHeight(&root));
        generatePaths(&root, paths, 0,0);
    }
    

    Space complexity depends on the height of the tree. Assuming a balanced tree, the space complexity is O(log n), based on the depth of the recursion stack (the path vector has the same size). Time complexity is O(n log n): each of the n nodes sums back over at most its O(log n) ancestors (summing the paths), because the height of a balanced binary tree is bounded by O(log n). So the total work is O(n log n).

    0 讨论(0)
  • 2020-12-24 07:21

    This is an O(N) solution

    // Post-order walk. Returns the sum of every value in `root`'s subtree and
    // increments *count once for each subtree (including root's own) whose
    // values add up to `sum`. An empty subtree returns 0 and is not counted.
    // NOTE(review): this counts whole *subtrees*, not downward paths, so it
    // does not answer the path question as asked; it is kept with that meaning.
    // Fixes relative to the original snippet:
    //  - parameters are typed (works with any node type exposing `data`,
    //    `left` and `right`);
    //  - the second recursive call used root->left instead of root->right;
    //  - `*count++` advanced the pointer instead of incrementing the counter;
    //  - the extra `left == sum` / `right == sum` checks counted every child
    //    subtree a second time (the child's own call already counts it).
    template <typename NodeT>
    int fn(NodeT *root, int sum, int *count)
    {
        if (root == nullptr)
            return 0;

        const int left = fn(root->left, sum, count);
        const int right = fn(root->right, sum, count);
        const int total = root->data + left + right;

        if (total == sum)
            ++*count;

        return total;
    }
    
    0 讨论(0)
  • 2020-12-24 07:24

    Here's an O(n + numResults) answer (essentially the same as @Somebody's answer, but with all issues resolved):

    1. Do a pre-order, in-order, or post-order traversal of the binary tree.
    2. As you do the traversal, maintain the cumulative sum of node values from the root node to the node above the current node. Let's call this value cumulativeSumBeforeNode.
    3. When you visit a node in the traversal, add it to a hashtable at key cumulativeSumBeforeNode (the value at that key will be a list of nodes).
    4. Compute the difference between the cumulative sum including the current node (cumulativeSumBeforeNode plus the node's value) and the target sum. Look up this difference in the hash table.
    5. If the hash table lookup succeeds, it should produce a list of nodes. Each one of those nodes represents the start node of a solution. The current node represents the end node for each corresponding start node. Add each [start node, end node] combination to your list of answers. If the hash table lookup fails, do nothing.
    6. When you've finished visiting a node in the traversal, remove the node from the list stored at key cumulativeSumBeforeNode in the hash table.

    Code:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    
    /**
     * Finds every downward path in a binary tree whose node values add up to a
     * target sum, using prefix sums over the current root-to-node chain.
     *
     * <p>During the depth-first walk, the map associates each prefix sum (the
     * sum of the values from the root down to, but excluding, a node) with the
     * nodes on the current chain that begin right after that prefix. A node
     * {@code end} closes a matching path starting at {@code start} exactly when
     * prefixBefore(start) == prefixIncluding(end) - target. With a HashMap this
     * takes expected O(n + number of results) time.
     */
    public class BinaryTreePathsWithSum {
        public static void main(String[] args) {
            BinaryTreeNode root = new BinaryTreeNode(5);
            BinaryTreeNode left16 = new BinaryTreeNode(16);
            BinaryTreeNode right16 = new BinaryTreeNode(16);
            BinaryTreeNode four = new BinaryTreeNode(4);
            BinaryTreeNode nineteen = new BinaryTreeNode(19);
            BinaryTreeNode two = new BinaryTreeNode(2);
            BinaryTreeNode fifteen = new BinaryTreeNode(15);
            BinaryTreeNode ninetyOne = new BinaryTreeNode(91);
            BinaryTreeNode eight = new BinaryTreeNode(8);

            root.left = left16;
            root.right = right16;
            left16.right = nineteen;
            right16.right = four;
            nineteen.left = two;
            two.left = fifteen;
            two.right = ninetyOne;
            ninetyOne.right = eight;

            /*
                    5
                  /   \
                16     16
                  \     \
                  19     4
                  /
                 2
                / \
               15  91
                    \
                     8
            */

            List<BinaryTreePath> found = getBinaryTreePathsWithSum(root, 112); // 19 => 2 => 91
            System.out.println(Arrays.toString(found.toArray()));
        }

        /**
         * Returns every downward path (start node, end node) in the tree rooted at
         * {@code root} whose values sum to {@code sum}, in the order their end
         * nodes are reached by a pre-order walk.
         *
         * @throws IllegalArgumentException if {@code root} is null
         */
        public static List<BinaryTreePath> getBinaryTreePathsWithSum(BinaryTreeNode root, int sum) {
            if (root == null) {
                throw new IllegalArgumentException("Must pass non-null binary tree!");
            }

            List<BinaryTreePath> result = new ArrayList<BinaryTreePath>();
            collectPaths(root, 0, new HashMap<Integer, List<BinaryTreeNode>>(), sum, result);
            return result;
        }

        /** Walks {@code node}'s subtree depth-first, appending each matching path that ends inside it. */
        private static void collectPaths(BinaryTreeNode node, int prefixBefore,
                Map<Integer, List<BinaryTreeNode>> startsByPrefix, int targetSum, List<BinaryTreePath> out) {
            if (node == null) {
                return;
            }

            // Register the node as a start candidate before the lookup, so that a
            // single node whose value equals the target is reported too.
            pushStart(startsByPrefix, prefixBefore, node);

            int prefixIncluding = prefixBefore + node.value;
            List<BinaryTreeNode> starts = startsByPrefix.get(prefixIncluding - targetSum);
            if (starts != null) {
                for (BinaryTreeNode start : starts) {
                    out.add(new BinaryTreePath(start, node));
                }
            }

            collectPaths(node.left, prefixIncluding, startsByPrefix, targetSum, out);
            collectPaths(node.right, prefixIncluding, startsByPrefix, targetSum, out);

            // Leaving the node: it is no longer on the current chain.
            popStart(startsByPrefix, prefixBefore);
        }

        /** Appends {@code node} to the start candidates stored under {@code prefix}. */
        private static void pushStart(Map<Integer, List<BinaryTreeNode>> startsByPrefix, int prefix, BinaryTreeNode node) {
            List<BinaryTreeNode> bucket = startsByPrefix.get(prefix);
            if (bucket == null) {
                bucket = new ArrayList<BinaryTreeNode>();
                startsByPrefix.put(prefix, bucket);
            }
            bucket.add(node);
        }

        /**
         * Drops the most recently pushed candidate under {@code prefix}. By
         * depth-first order, this is the node currently being left.
         */
        private static void popStart(Map<Integer, List<BinaryTreeNode>> startsByPrefix, int prefix) {
            List<BinaryTreeNode> bucket = startsByPrefix.get(prefix);
            bucket.remove(bucket.size() - 1);
        }

        /** Tree node. Equality is identity, so nodes with equal values stay distinct. */
        private static class BinaryTreeNode {
            public int value;
            public BinaryTreeNode left;
            public BinaryTreeNode right;

            public BinaryTreeNode(int value) {
                this.value = value;
            }

            @Override
            public String toString() {
                return String.valueOf(value);
            }

            @Override
            public int hashCode() {
                return value; // identical to Integer.valueOf(value).hashCode()
            }

            @Override
            public boolean equals(Object other) {
                return this == other;
            }
        }

        /** A matching path, identified by its upper (start) and lower (end) node. */
        private static class BinaryTreePath {
            public BinaryTreeNode start;
            public BinaryTreeNode end;

            public BinaryTreePath(BinaryTreeNode start, BinaryTreeNode end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public String toString() {
                return start + " to " + end;
            }
        }
    }
    
    0 讨论(0)
提交回复
热议问题