The following is an interview question.
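You are given a binary tree (not necessarily a BST) in which each node contains a value. Design an algorithm to print all paths that sum to a given value. A path does not need to start or end at the root or at a leaf, but it must run straight down the tree (from parent to child).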
/* Print every downward path that ends at `root` and whose values add up to
 * `sum`, then recurse into both children.
 *
 * sum   - target path sum
 * arr   - scratch buffer holding the values on the current root-to-node path;
 *         it must have room for at least (tree height) entries (not checked)
 * level - depth of `root` (0 for the tree root); arr[0..level-1] already holds
 *         the values of root's ancestors
 * root  - current subtree; NULL is a no-op
 */
void printpath(int sum,int arr[],int level,struct node * root)
{
    /* print() is defined after this function; declare it so the call is not
     * an implicit declaration (invalid in C99 and C++). */
    void print(int arr[], int end, int start);
    int tmp = sum, i;

    if (root == NULL)
        return;
    arr[level] = root->data;

    /* Walk back up the path, one ancestor at a time, testing each suffix
     * arr[i..level]. The test must be inside the loop: testing once after
     * the loop only ever checks the whole root-to-node path. */
    for (i = level; i >= 0; i--) {
        tmp -= arr[i];
        if (tmp == 0)
            print(arr, level, i);
    }

    printpath(sum, arr, level + 1, root->left);
    printpath(sum, arr, level + 1, root->right);
}
/* Print arr[start..end] (both ends inclusive) on a single line. Each value is
 * followed by a space, and a newline ends the line. */
void print(int arr[],int end,int start)
{
    int idx = start;

    while (idx <= end) {
        printf("%d ", arr[idx]);
        ++idx;
    }
    printf("\n");
}
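Time complexity: O(n log n) for a balanced tree. Each node re-sums the path above it, which is O(depth) work, so a degenerate (list-shaped) tree makes this O(n^2). Space complexity: O(n) in the worst case, because the path buffer and the recursion stack grow with the tree height.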
Well, this is a tree, not a graph. So, you can do something like this:
Pseudocode:
global ResultList
function ProcessNode(CurrentNode, CurrentSum)
CurrentSum+=CurrentNode->Value
if (CurrentSum==SumYouAreLookingFor) AddNodeTo ResultList
for all Children of CurrentNode
ProcessNode(Child,CurrentSum)
Well, this gives you the paths that start at the root. However, you can just make a tiny change:
for all Children of CurrentNode
ProcessNode(Child,CurrentSum)
ProcessNode(Child,0)
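You might need to think about it for a second (I'm busy with other things), but this should basically run the same algorithm rooted at every node in the tree.

Caveat: as written, the restart `ProcessNode(Child,0)` is issued from every invocation on the parent. The parent is itself invoked once for every starting point above it, so the same restarts, and their results, are repeated many times. To start exactly once per node, keep only `ProcessNode(Child,CurrentSum)` inside ProcessNode. Then call `ProcessNode(Node,0)` once for each node from a separate traversal.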
EDIT: this actually gives the "end node" only. However, as this is a tree, you can just start at those end nodes and walk back up until you get the required sum.
EDIT 2: and, of course, if all values are positive then you can abort the descent if your current sum is >= the required one
# include<stdio.h>
# include <stdlib.h>
/* Binary tree node: an integer payload plus links to the two children.
 * newNode() initialises both links to NULL. */
struct Node
{
    int data;            /* value stored in this node */
    struct Node *left;   /* left child, or NULL when absent */
    struct Node *right;  /* right child, or NULL when absent */
};
/* Allocate a new leaf node holding `item`, with both child links set to NULL.
 * The node comes from malloc(), so the caller owns it and must free() it.
 * If the allocation fails, print a message to stderr and exit. Callers in this
 * example dereference the result immediately, so returning NULL would only
 * defer the crash. */
struct Node * newNode(int item)
{
    struct Node *temp = (struct Node *)malloc(sizeof(struct Node));
    if (temp == NULL) {
        fprintf(stderr, "newNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    temp->data = item;
    temp->left = NULL;
    temp->right = NULL;
    return temp;
}
/* Print one matching path: the values p[t..level] (both ends inclusive), in
 * top-down (ancestor-first) order. Values are space-separated on a single line
 * so that several matching paths printed in a row stay distinguishable. */
void print(int p[], int level, int t){
    int i;
    for(i=t;i<=level;i++){
        printf("%d ", p[i]);
    }
    printf("\n");
}
/* For every node, print each downward path that ends at that node and whose
 * values add up to `da`. Nodes are visited in pre-order.
 *
 * root  - subtree to search; NULL is a no-op
 * da    - target sum
 * path  - scratch buffer of the values on the current root-to-node path.
 *         The "[100]" is not enforced (C passes it as a plain pointer), so the
 *         caller's buffer must be at least as long as the tree is deep.
 * level - depth of `root`; path[0..level-1] holds its ancestors' values
 */
void check_paths_with_given_sum(struct Node * root, int da, int path[100], int level){
    int start;
    int running = 0;

    if (!root)
        return;

    path[level] = root->data;

    /* Grow the suffix path[start..level] upward one ancestor at a time and
     * report each suffix whose total equals the target. */
    for (start = level; start >= 0; --start) {
        running += path[start];
        if (running == da)
            print(path, level, start);
    }

    check_paths_with_given_sum(root->left, da, path, level + 1);
    check_paths_with_given_sum(root->right, da, path, level + 1);
}
/* Demo: build the tree below and print every downward path summing to 9.
 *
 *          10
 *         /  \
 *        2    4
 *       /      \
 *      1        5
 */
int main(){
    int buffer[100];
    struct Node *root = newNode(10);

    root->left = newNode(2);
    root->right = newNode(4);
    root->left->left = newNode(1);
    root->right->right = newNode(5);

    check_paths_with_given_sum(root, 9, buffer, 0);
    return 0;
}
This works.....
Below is a recursive solution. We perform a pre-order traversal of the binary tree, handling each node before its children. As we move down a level, we record the current node's weight in a path buffer. We then sum the weights from the current level back up toward the root, and whenever the running total hits the target sum we print that part of the path. This solution handles cases where more than one solution lies along any given path.
Assume you have a binary tree rooted at root.
#include <iostream>
#include <vector>
using namespace std;
// A binary tree node holding an int value and plain pointers to its children
// (null when a child is absent). The node never allocates or frees its
// children; whoever links them in keeps them alive.
class Node
{
private:
    Node* left;   // left child, or null
    Node* right;  // right child, or null
    int value;    // payload

public:
    // Creates a leaf node carrying `value`.
    Node(const int value) : left(nullptr), right(nullptr), value(value) {}

    void setLeft(Node* left) { this->left = left; }
    void setRight(Node* right) { this->right = right; }

    Node* getLeft() const { return left; }
    Node* getRight() const { return right; }

    // Reference to the stored value; valid for as long as this node lives.
    const int& getValue() const { return value; }
};
// Height of the tree counted in nodes: 0 for an empty tree, 1 for a single
// node. Used to size the path buffer, because no root-to-leaf path holds more
// values than this.
int getMaxHeight(Node* root)
{
    if (!root)
        return 0;
    const int leftHeight = getMaxHeight(root->getLeft());
    const int rightHeight = getMaxHeight(root->getRight());
    return 1 + (leftHeight < rightHeight ? rightHeight : leftHeight);
}
// Found our target sum: write paths[start..end] (both ends inclusive) to
// std::cerr as space-separated values followed by a newline.
// The buffer is only read, so it is taken by const reference. This also
// accepts const vectors and temporaries; existing non-const callers are
// unaffected.
void printPaths(const std::vector<int>& paths, int start, int end)
{
    for (int i = start; i <= end; ++i)
        std::cerr << paths[i] << ' ';
    std::cerr << '\n';  // std::cerr is unit-buffered, so no explicit flush is needed
}
void generatePaths(Node* root, vector<int>& paths, int depth, const int sum)
{
//base case, empty tree, no path
if( root == NULL)
return;
paths[depth] = root->getValue();
int total =0;
//sum up the weights of the nodes in the path traversed
//so far, if we hit our target, output the path
for(int i = depth; i>=0; i--)
{
total += paths[i];
if(total == sum)
printPaths(paths, i, depth);
}
//go down 1 level where we will then sum up from that level
//back up the tree to see if any sub path hits our target sum
generatePaths(root->getLeft(), paths, depth+1, sum);
generatePaths(root->getRight(), paths, depth+1, sum);
}
int main(void)
{
vector<int> paths (getMaxHeight(&root));
generatePaths(&root, paths, 0,0);
}
Space complexity depends on the height of the tree: the path vector and the recursion stack both grow with the height. For a balanced tree the space complexity is therefore O(log n), and O(n) for a degenerate, list-shaped tree.

Time complexity is O(n log n) for a balanced tree. Each node does work proportional to its depth (summing the path back toward the root), so each level of the tree costs at most O(n) in total. The height of a balanced binary tree is bounded by O(log n), so O(n) work per level over O(log n) levels gives a running time of O(n log n). For a degenerate tree the height is n, and the same argument gives O(n^2). Printing the matching paths adds time proportional to the output size.
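This is an O(N) solution. Note, however, that it counts subtrees whose values add up to the target sum, not downward paths, so it does not solve the question above.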
int fn(root, sum, *count)
{
if(root == NULL)
return 0;
int left = fn(root->left, sum, count);
int right = fn(root->left, sum, count);
if(left == sum)
*count++;
if(right == sum)
*count++;
if((root->data + left + right) == sum)
*count++;
return (root->data + left + right);
}
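Here's an O(n + numResults) answer (essentially the same as @Somebody's answer, but with all issues resolved):

1. Traverse the tree depth-first, carrying `cumulativeSumBeforeNode`: the sum of the values of all ancestors of the current node (0 for the root).
2. Add the current node to a hash table under the key `cumulativeSumBeforeNode` (the value at that key will be a list of nodes).
3. Compute the difference between the node's cumulative sum including itself (`cumulativeSumBeforeNode` plus the node's value) and the target sum. Look up this difference in the hash table. Every node listed under that key starts a path that ends at the current node and sums to the target.
4. Recurse into the children, then remove the current node from the list stored under `cumulativeSumBeforeNode` in the hash table.

Code: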
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Finds every downward path in a binary tree whose node values add up to a
 * target sum.
 *
 * <p>It uses prefix sums: a path from {@code start} down to {@code end} sums to
 * the target exactly when (sum of values from the root down to and including
 * {@code end}) minus the target equals (sum of values strictly above
 * {@code start}). A hash map from prefix sum to nodes on the current
 * root-to-node path makes each lookup expected O(1).
 */
public class BinaryTreePathsWithSum {
    public static void main(String[] args) {
        BinaryTreeNode a = new BinaryTreeNode(5);
        BinaryTreeNode b = new BinaryTreeNode(16);
        BinaryTreeNode c = new BinaryTreeNode(16);
        BinaryTreeNode d = new BinaryTreeNode(4);
        BinaryTreeNode e = new BinaryTreeNode(19);
        BinaryTreeNode f = new BinaryTreeNode(2);
        BinaryTreeNode g = new BinaryTreeNode(15);
        BinaryTreeNode h = new BinaryTreeNode(91);
        BinaryTreeNode i = new BinaryTreeNode(8);
        BinaryTreeNode root = a;
        a.left = b;
        a.right = c;
        b.right = e;
        c.right = d;
        e.left = f;
        f.left = g;
        f.right = h;
        h.right = i;
        /*
                 5
               /   \
             16     16
               \      \
                19     4
               /
              2
             / \
           15   91
                  \
                   8
        */
        List<BinaryTreePath> pathsWithSum = getBinaryTreePathsWithSum(root, 112); // 19 => 2 => 91
        System.out.println(Arrays.toString(pathsWithSum.toArray()));
    }

    /**
     * Returns every downward path, reported as (start node, end node), whose
     * values sum to {@code sum}.
     *
     * @param root tree to search; must not be null
     * @param sum  target sum
     * @return the matching paths, in the pre-order of their end nodes
     * @throws IllegalArgumentException if {@code root} is null
     */
    public static List<BinaryTreePath> getBinaryTreePathsWithSum(BinaryTreeNode root, int sum) {
        if (root == null) {
            throw new IllegalArgumentException("Must pass non-null binary tree!");
        }
        List<BinaryTreePath> paths = new ArrayList<BinaryTreePath>();
        Map<Integer, List<BinaryTreeNode>> cumulativeSumMap = new HashMap<Integer, List<BinaryTreeNode>>();
        populateBinaryTreePathsWithSum(root, 0, cumulativeSumMap, sum, paths);
        return paths;
    }

    /**
     * Pre-order DFS step. {@code cumulativeSumMap} maps a prefix sum (sum of
     * the values strictly above a node) to the nodes on the current
     * root-to-node path that have that prefix sum. Each node registers itself
     * before its subtree is explored and unregisters itself afterwards.
     */
    private static void populateBinaryTreePathsWithSum(BinaryTreeNode node, int cumulativeSumBeforeNode, Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int targetSum, List<BinaryTreePath> paths) {
        if (node == null) {
            return;
        }
        // Register before the lookup so that a single node equal to the
        // target is reported as a one-node path.
        addToMap(cumulativeSumMap, cumulativeSumBeforeNode, node);
        int cumulativeSumIncludingNode = cumulativeSumBeforeNode + node.value;
        int sumToFind = cumulativeSumIncludingNode - targetSum;
        // Single lookup: a null result means no path ending here has the target sum.
        List<BinaryTreeNode> candidatePathStartNodes = cumulativeSumMap.get(sumToFind);
        if (candidatePathStartNodes != null) {
            for (BinaryTreeNode pathStartNode : candidatePathStartNodes) {
                paths.add(new BinaryTreePath(pathStartNode, node));
            }
        }
        populateBinaryTreePathsWithSum(node.left, cumulativeSumIncludingNode, cumulativeSumMap, targetSum, paths);
        populateBinaryTreePathsWithSum(node.right, cumulativeSumIncludingNode, cumulativeSumMap, targetSum, paths);
        removeFromMap(cumulativeSumMap, cumulativeSumBeforeNode);
    }

    /** Appends {@code node} to the list stored under {@code cumulativeSumBeforeNode}, creating it if needed. */
    private static void addToMap(Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int cumulativeSumBeforeNode, BinaryTreeNode node) {
        List<BinaryTreeNode> nodes = cumulativeSumMap.get(cumulativeSumBeforeNode);
        if (nodes == null) {
            nodes = new ArrayList<BinaryTreeNode>();
            cumulativeSumMap.put(cumulativeSumBeforeNode, nodes);
        }
        nodes.add(node);
    }

    /**
     * Undoes the matching {@link #addToMap} call. The node being removed is
     * always last in its list, because every descendant added afterwards has
     * already been removed. The key is dropped once its list is empty. The map
     * therefore only ever holds prefix sums of the current path (O(height)
     * entries) instead of accumulating an empty list for every prefix sum
     * seen during the traversal.
     */
    private static void removeFromMap(Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int cumulativeSumBeforeNode) {
        List<BinaryTreeNode> nodes = cumulativeSumMap.get(cumulativeSumBeforeNode);
        nodes.remove(nodes.size() - 1);
        if (nodes.isEmpty()) {
            cumulativeSumMap.remove(cumulativeSumBeforeNode);
        }
    }

    /** Binary tree node with an int value and optional (null) children. */
    private static class BinaryTreeNode {
        public int value;
        public BinaryTreeNode left;
        public BinaryTreeNode right;

        public BinaryTreeNode(int value) {
            this.value = value;
        }

        @Override
        public String toString() {
            return this.value + "";
        }

        // Consistent with identity-based equals: equal (identical) objects
        // always have the same value and hence the same hash.
        @Override
        public int hashCode() {
            return Integer.valueOf(this.value).hashCode();
        }

        @Override
        public boolean equals(Object other) {
            return this == other;
        }
    }

    /** A downward path, identified by its first and last node. */
    private static class BinaryTreePath {
        public BinaryTreeNode start;
        public BinaryTreeNode end;

        public BinaryTreePath(BinaryTreeNode start, BinaryTreeNode end) {
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return this.start + " to " + this.end;
        }
    }
}