Given a list of n lists, write a program that returns all possible combinations of elements, taking one element from each list.
Example:
As you know, the usual solution is recursion. However, out of boredom I once wrote a Java method, multiNext, to do this without recursion. multiNext uses an array to keep track of a set of indices in an equivalent system of nested loops.
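// Required imports; the methods below belong inside a class.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Advances current to the next index combination.
// Returns false once every combination has been visited.
public static boolean multiNext(int[] current, int[] slotLengths) {
    for (int r = current.length - 1; r >= 0; r--) {
        if (current[r] < slotLengths[r] - 1) {
            current[r]++;
            return true;
        } else {
            // This slot rolled over: reset it and carry to the slot on its left.
            current[r] = 0;
        }
    }
    return false;
}

public static void cross(List<List<String>> lists) {
    int size = lists.size();
    int[] current = new int[size];
    int[] slotLengths = new int[size];
    for (int i = 0; i < size; i++)
        slotLengths[i] = lists.get(i).size();
    do {
        // Pick one element from each list according to the current indices.
        List<String> temp = new ArrayList<>();
        for (int i = 0; i < size; i++)
            temp.add(lists.get(i).get(current[i]));
        System.out.println(temp);
    } while (multiNext(current, slotLengths));
}

public static void main(String[] args) {
    cross(Arrays.asList(Arrays.asList("x", "z"), Arrays.asList("a", "b", "c"), Arrays.asList("o", "p")));
}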
Edit: I'm answering this in Python because, although it's currently tagged language-agnostic, Python is a good, executable pseudo-pseudocode.
If you can write the function in a form that is tail-recursive, i.e. in a form that looks like def f(x): return f(g(x)), it's easy to turn it into an iterative form. Unfortunately, you usually won't end up with a tail-recursive call, so you need to know a couple of tricks.
First of all, let's say we have a function that looks like this:
def my_map(func, my_list):
    if not my_list:
        return []
    return [func(my_list[0])] + my_map(func, my_list[1:])
Ok, so it's recursive, but not tail recursive: it's really
def my_map(func, my_list):
    if not my_list:
        return []
    result = [func(my_list[0])] + my_map(func, my_list[1:])
    return result
Instead, we need to adjust the function slightly, adding what is traditionally known as an accumulator:
def my_map(func, my_list, acc=[]):
    if not my_list: return acc
    return my_map(func, my_list[1:], acc + [func(my_list[0])])
Now, we have a truly tail-recursive function: we've gone from def f(x): return g(f(x)) to def f(x): return f(g(x)).
Now, it's quite simple to turn that function into a non-recursive form:
def my_map(func, my_list, acc=[]):
    while True: #added
        if not my_list: return acc
        #return my_map(func, my_list[1:], acc + [func(my_list[0])]) #deleted
        func, my_list, acc = func, my_list[1:], acc + [func(my_list[0])] #added
Now, we just tidy up a little bit:
def my_map(func, my_list):
    acc = []
    while my_list:
        acc.append(func(my_list[0]))
        my_list = my_list[1:]
    return acc
Note you can clean it up even further using a for loop or a list comprehension, but that's left as an exercise for the reader.
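For reference, here is one way those two cleaned-up forms could look. This is only a sketch: the names my_map_loop and my_map_comprehension are made up here to keep the variants apart, and are not from the answer itself.

```python
def my_map_loop(func, my_list):
    # Same accumulator idea, but iterating directly instead of slicing.
    acc = []
    for item in my_list:
        acc.append(func(item))
    return acc


def my_map_comprehension(func, my_list):
    # The accumulator loop collapses into a single list comprehension.
    return [func(item) for item in my_list]


print(my_map_loop(str.upper, ["a", "b"]))          # ['A', 'B']
print(my_map_comprehension(len, ["x", "yz", ""]))  # [1, 2, 0]
```

Iterating directly also avoids copying the tail of the list on every step, which the my_list[1:] slicing does.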
Ok, so this was a pathological example; hopefully you'd know that Python has a builtin map function. But the process is the same: transform into a tail-recursive form, replace the recursive call with argument reassignment, and tidy up.
So, if you have:
def make_products(list_of_lists):
    if not list_of_lists: return []
    first_list = list_of_lists[0]
    rest = list_of_lists[1:]
    return product_of(first_list, make_products(rest))
You can convert it into a tail recursive form
def make_products(list_of_lists, acc=[]):
    if not list_of_lists: return acc
    first_list = list_of_lists[0]
    rest = list_of_lists[1:]
    acc = product_of(acc, first_list)
    return make_products(rest, acc)
Then, that simplifies to:
def make_products(list_of_lists):
    acc = []
    while list_of_lists:
        first_list = list_of_lists[0]
        rest = list_of_lists[1:]
        acc = product_of(acc, first_list)
        list_of_lists = rest
    return acc
Again, this can be cleaned up further, into a for loop:
def make_products(list_of_lists):
    acc = []
    for lst in list_of_lists:
        acc = product_of(acc, lst)
    return acc
If you've looked at the builtin functions, you might notice this is somewhat familiar: it's essentially the reduce function (functools.reduce in Python 3):
def reduce(function, iterable, initializer):
    acc = initializer
    for x in iterable:
        acc = function(acc, x)
    return acc
So, the final form is something like
def make_products(list_of_lists):
    return reduce(product_of, list_of_lists, []) # the last argument is actually optional here
You then just have to worry about writing the product_of function.
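As a rough illustration, one possible product_of is sketched below. It is not the only way to write it, and it makes one assumption that differs from the code above: the accumulator is seeded with [[]] (a single empty combination) rather than [], because with this particular product_of an empty seed would always produce an empty result.

```python
from functools import reduce  # reduce is not a builtin in Python 3


def product_of(acc, lst):
    # acc holds the partial combinations built so far; extend each one
    # with every element of the next list.
    return [prev + [item] for prev in acc for item in lst]


def make_products(list_of_lists):
    # Assumption: seed with [[]] (one empty combination), not [].
    return reduce(product_of, list_of_lists, [[]])


for combo in make_products([["x", "z"], ["a", "b", "c"], ["o", "p"]]):
    print(" ".join(combo))
```

With this seed, an empty inner list makes the whole result empty, and an empty outer list yields the single empty combination.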
You do not need recursion. All you need to do is build up a set of intermediate solutions. Here is a non-recursive solution in Python:
# This does NOT use recursion!
def all_comb(list_of_lists):
    # We start with a list of just the empty set.
    answer = [[]]
    for lst in list_of_lists:
        # new_answer will be the list of combinations including this one.
        new_answer = []
        # Build up the new answer.
        for thing in lst:
            for prev_list in answer:
                new_answer.append(prev_list + [thing])
        # Replace the old answer with the new one.
        answer = new_answer
    # We now have all combinations of all lists.
    return answer

# Demonstration that it works.
for comb in all_comb([["x", "y"], ["a", "b", "c"], ["o", "p"]]):
    print(" ".join(comb))
Think of it like how you increment a number, e.g. a base 3 number would go:
000
001
002
010
011
...
222
Now think of each digit as the index into one of the nested lists. You will have as many digits as you have nested lists, i.e. the size of the outer list.
The "base" of each digit may differ, and is the size of the corresponding nested list. A "digit" can be a very large number if a nested list is large.
So, you start by creating a list of "digits", or index values, initializing them to 0. You then print the values of the elements at those indices. You then increment the last index value, rolling over as needed, like you would a normal number, stopping when the first index value rolls over.
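To see just the counting step with differing bases, here is a minimal Python sketch. The generator name index_combos and its slot_lengths parameter are hypothetical, chosen only for this illustration; it yields the index tuples and leaves the element lookup to the caller.

```python
def index_combos(slot_lengths):
    # Hypothetical helper: yields every index tuple for the given "bases".
    # No slots, or any empty slot, means there is nothing to yield.
    if not slot_lengths or any(n == 0 for n in slot_lengths):
        return
    idx = [0] * len(slot_lengths)
    while True:
        yield tuple(idx)
        # Increment the last digit, carrying leftwards on rollover.
        pos = len(idx) - 1
        while pos >= 0:
            idx[pos] += 1
            if idx[pos] < slot_lengths[pos]:
                break
            idx[pos] = 0
            pos -= 1
        if pos < 0:
            return  # the first digit rolled over: we are done


print(list(index_combos([2, 3])))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```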
Here is a Java implementation using arrays of arrays, i.e. String[][]. You can easily change to List<List<String>> or List<String[]> if needed.
@SafeVarargs
public static void printCombos(String[] ... lists) {
    if (lists.length == 0)
        throw new IllegalArgumentException("No lists given");
    for (String[] list : lists)
        if (list.length == 0)
            throw new IllegalArgumentException("List is empty");

    int[] idx = new int[lists.length];
    for (;;) {
        // Print combo
        for (int i = 0; i < lists.length; i++) {
            if (i != 0)
                System.out.print(' ');
            System.out.print(lists[i][idx[i]]);
        }
        System.out.println();

        // Advance to next combination
        for (int i = lists.length - 1; ++idx[i] == lists[i].length; ) {
            idx[i] = 0;
            if (--i < 0)
                return; // We're done
        }
    }
}

public static void main(String[] args) {
    String[][] data = { { "x", "z" }, { "a", "b", "c" }, { "o", "p" } };
    printCombos(data);
}
OUTPUT
x a o
x a p
x b o
x b p
x c o
x c p
z a o
z a p
z b o
z b p
z c o
z c p
If you use lists instead of arrays, then the code will use get(int), which may not always be good for performance, e.g. for LinkedList.
If that is the case, replace int[] idx with an Iterator[], initializing each array entry with an iterator for the corresponding list. Resetting a "digit" to 0 would then be done by retrieving a new Iterator from the list in question.
In this case, they don't even have to be lists, but can be any kind of collection, or more specifically Iterable objects.
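The iterator-based variant can be sketched in Python as follows. This is only an illustration under stated assumptions: the name iter_combos is made up, and each input is assumed to be re-iterable (a list, tuple, and so on), since a "digit" is reset by asking the collection for a fresh iterator.

```python
def iter_combos(iterables):
    # Assumption: each iterable can be iterated more than once,
    # so one-shot generators will not work here.
    iterables = list(iterables)
    iterators = [iter(it) for it in iterables]
    current = []
    for it in iterators:
        try:
            current.append(next(it))
        except StopIteration:
            return  # some iterable is empty: no combinations
    if not current:
        return  # no iterables given
    while True:
        yield tuple(current)
        pos = len(iterators) - 1
        while pos >= 0:
            try:
                # Advance this "digit" by one element.
                current[pos] = next(iterators[pos])
                break
            except StopIteration:
                # Reset the "digit" to 0 by fetching a fresh iterator,
                # then carry to the position on the left.
                iterators[pos] = iter(iterables[pos])
                current[pos] = next(iterators[pos])
                pos -= 1
        if pos < 0:
            return  # the first position rolled over: done


for combo in iter_combos([["x", "z"], ("a", "b", "c"), ["o", "p"]]):
    print(" ".join(combo))
```

No element is ever fetched by index, so the cost per step does not depend on how the underlying collection implements random access.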