I want to split a list into n groups in all possible combinations (allowing for variable group length).
Say, I have the following list:
```python
lst = [1, 2, 3, 4]
```
Following the helpful links from @friendly_dog, I'm attempting to answer my own question by tweaking the functions used in this post. I have a rough solution that works, although I fear it is not particularly efficient and could use some improvement. I end up generating many more sets of partitions than I need, and then filter out the ones that differ only by sort order.
First, I take these 3 functions from Set partitions in Python:
```python
import itertools
from copy import deepcopy


def slice_by_lengths(lengths, the_list):
    for length in lengths:
        new = []
        for i in range(length):
            new.append(the_list.pop(0))
        yield new


def partition(number):
    return {(x,) + y for x in range(1, number) for y in partition(number-x)} | {(number,)}


def subgrups(my_list):
    partitions = partition(len(my_list))
    permed = []
    for each_partition in partitions:
        permed.append(set(itertools.permutations(each_partition, len(each_partition))))
    for each_tuple in itertools.chain(*permed):
        yield list(slice_by_lengths(each_tuple, deepcopy(my_list)))
```
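Despite its name, `partition(number)` returns the compositions of `number`: every ordered tuple of positive integers that sums to it. Because all orderings are already present, permuting each tuple again in `subgrups` produces the same length tuples more than once, which adds to the duplicates that later need filtering. A standalone sketch of the same recursion (renamed `compositions` here purely for illustration) makes this visible:

```python
def compositions(number):
    """Return every ordered tuple of positive ints that sums to `number`.

    Same recursion as partition() above: choose a first part x, then
    recurse on what is left; (number,) is the one-part case.
    """
    return {(x,) + rest
            for x in range(1, number)
            for rest in compositions(number - x)} | {(number,)}


print(sorted(compositions(3)))  # [(1, 1, 1), (1, 2), (2, 1), (3,)]
print(len(compositions(4)))     # 8, i.e. 2**(4-1) compositions
```

Both `(1, 2)` and `(2, 1)` are already in the set, so `set(itertools.permutations(...))` over each of them yields the same pair of tuples twice.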
I then write a function that wraps the `subgrups` function and applies it to each permutation of my original list. I loop through these subgroup permutations and, if they are equal in length to the desired number of partitions, sort them in a way that allows me to identify duplicates. I'm not sure whether there is a better approach to this.
```python
def return_partition(my_list, num_groups):
    filtered = []
    for perm in itertools.permutations(my_list, len(my_list)):
        for sub_group_perm in subgrups(list(perm)):
            if len(sub_group_perm) == num_groups:
                # sort within each partition
                sort1 = [sorted(i) for i in sub_group_perm]
                # sort by first element of each partition
                sort2 = sorted(sort1, key=lambda t: t[0])
                # sort by the number of elements in each partition
                sort3 = sorted(sort2, key=lambda t: len(t))
                # if this new sorted set of partitions has not been added, add it
                if sort3 not in filtered:
                    filtered.append(sort3)
    return filtered
```
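The dedupe step above scans `filtered` (a list) once per candidate, and the three chained `sorted` calls amount to building a canonical form for each partition. A hedged alternative sketch, not the approach from the linked post: give each element a group label with `itertools.product`, keep only labelings that use all `num_groups` labels, and collapse each result to a hashable canonical key (a sorted tuple of sorted tuples) so duplicates disappear in a set. The names `canonical` and `k_partitions_bruteforce` are illustrative. It still enumerates `num_groups ** len(items)` labelings, so it is only meant for small inputs, and its output order differs from the question's.

```python
import itertools


def canonical(groups):
    """Hashable canonical form: sort inside each group, then sort the groups."""
    return tuple(sorted(tuple(sorted(g)) for g in groups))


def k_partitions_bruteforce(items, num_groups):
    """All ways to split `items` into `num_groups` non-empty groups (small inputs only)."""
    seen = set()
    # Give every item a group label 0..num_groups-1
    for labels in itertools.product(range(num_groups), repeat=len(items)):
        if len(set(labels)) != num_groups:
            continue  # at least one group would be empty
        groups = [[] for _ in range(num_groups)]
        for item, label in zip(items, labels):
            groups[label].append(item)
        seen.add(canonical(groups))  # relabelings collapse to the same key
    return sorted(seen)


for p in k_partitions_bruteforce([1, 2, 3, 4], 2):
    print(p)
```

Set membership is constant time on average, so the cost is dominated by the enumeration rather than by the duplicate check.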
Running it on my original example list with two and with three partitions produces the desired output:
```python
>>> for i in return_partition([1,2,3,4],2):
...     print(i)
...
[[1], [2, 3, 4]]
[[4], [1, 2, 3]]
[[1, 2], [3, 4]]
[[3], [1, 2, 4]]
[[1, 3], [2, 4]]
[[2], [1, 3, 4]]
[[1, 4], [2, 3]]
>>>
>>> for i in return_partition([1,2,3,4],3):
...     print(i)
...
[[1], [4], [2, 3]]
[[3], [4], [1, 2]]
[[1], [2], [3, 4]]
[[1], [3], [2, 4]]
[[2], [4], [1, 3]]
[[2], [3], [1, 4]]
>>>
```
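A quick sanity check on those listings: the number of ways to split n distinct items into k non-empty, unordered groups is the Stirling number of the second kind S(n, k), which satisfies S(n, k) = k * S(n-1, k) + S(n-1, k-1). A small sketch of that recurrence (the function name `stirling2` is just illustrative) confirms the 7 and 6 results above:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def stirling2(n, k):
    """Number of ways to split n distinct items into k non-empty groups."""
    if n == k:
        return 1  # includes S(0, 0) = 1; otherwise every item is alone
    if n == 0 or k == 0:
        return 0  # groups with no items to fill them, or items with no groups
    # the n-th item either joins one of the k groups of an (n-1, k) split,
    # or forms a new singleton group next to an (n-1, k-1) split
    return k * stirling2(n - 1, k) + stirling2(n - 1, k - 1)


print([stirling2(4, k) for k in range(1, 5)])  # [1, 7, 6, 1]
```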
We can use the basic recursive algorithm from this answer and modify it to produce partitions of a particular length without having to generate and filter out unwanted partitions.
```python
def sorted_k_partitions(seq, k):
    """Returns a list of all unique k-partitions of `seq`.

    Each partition is a list of parts, and each part is a tuple.

    The parts in each individual partition will be sorted in shortlex
    order (i.e., by length first, then lexicographically).

    The overall list of partitions will then be sorted by the length
    of their first part, the length of their second part, ...,
    the length of their last part, and then lexicographically.
    """
    n = len(seq)
    if k > n:
        return []  # n items cannot fill more than n non-empty parts
    groups = []  # a list of lists, currently empty

    def generate_partitions(i):
        if i >= n:
            yield list(map(tuple, groups))
        else:
            if n - i > k - len(groups):
                for group in groups:
                    group.append(seq[i])
                    yield from generate_partitions(i + 1)
                    group.pop()

            if len(groups) < k:
                groups.append([seq[i]])
                yield from generate_partitions(i + 1)
                groups.pop()

    result = generate_partitions(0)

    # Sort the parts in each partition in shortlex order
    result = [sorted(ps, key=lambda p: (len(p), p)) for ps in result]
    # Sort partitions by the length of each part, then lexicographically.
    result = sorted(result, key=lambda ps: (*map(len, ps), ps))

    return result
```
There's quite a lot going on here, so let me explain.
First, we start with a procedural, bottom-up (terminology?) implementation of the same aforementioned recursive algorithm:
```python
def partitions(seq):
    """-> a list of all unique partitions of `seq` in no particular order.

    Each partition is a list of parts, and each part is a tuple.
    """
    n = len(seq)
    groups = []  # a list of lists, currently empty

    def generate_partitions(i):
        if i >= n:
            yield list(map(tuple, groups))
        else:
            for group in groups:
                group.append(seq[i])
                yield from generate_partitions(i + 1)
                group.pop()

            groups.append([seq[i]])
            yield from generate_partitions(i + 1)
            groups.pop()

    if n > 0:
        return list(generate_partitions(0))
    else:
        return [[()]]
```
The main algorithm is in the nested `generate_partitions` function. Basically, it walks through the sequence, and for each item it: 1) puts the item into each of the current groups (a.k.a. parts) in the working set and recurses; 2) puts the item into its own, new group and recurses.
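The same two-way choice can also be written without the shared, mutated `groups` list, by passing the partial partition down as an argument. This is only a sketch for comparison (the name `partitions_functional` is mine): it builds new lists at every step, which costs extra copying but makes the recursion easier to trace, and it visits partitions in the same order.

```python
def partitions_functional(seq):
    """Yield every partition of `seq` as a list of tuples."""
    def go(i, groups):
        if i == len(seq):
            yield [tuple(g) for g in groups]
            return
        item = seq[i]
        # 1) put the item into each existing group in turn
        for j in range(len(groups)):
            yield from go(i + 1, groups[:j] + [groups[j] + [item]] + groups[j + 1:])
        # 2) or start a new group containing just this item
        yield from go(i + 1, groups + [[item]])

    yield from go(0, [])


for p in partitions_functional([1, 2, 3]):
    print(p)
# [(1, 2, 3)]
# [(1, 2), (3,)]
# [(1, 3), (2,)]
# [(1,), (2, 3)]
# [(1,), (2,), (3,)]
```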
When we reach the end of the sequence (`i == n`), we yield a (deep) copy of the working set that we've been building up.
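The copy matters because `groups` is mutated in place while the recursion unwinds. A minimal, hypothetical demonstration (`yield_alias` and `yield_copy` are made-up names, not part of the answer): yielding the working list itself hands the caller references to one list that ends up empty after every `pop()` has run, whereas `list(map(tuple, groups))` takes a snapshot of the current state.

```python
def yield_alias(seq):
    groups = []
    for x in seq:
        groups.append([x])
        yield groups  # the same list object every time
    while groups:
        groups.pop()  # this later mutation is visible through every alias


def yield_copy(seq):
    groups = []
    for x in seq:
        groups.append([x])
        yield list(map(tuple, groups))  # independent snapshot
    while groups:
        groups.pop()


print(list(yield_alias([1, 2])))  # [[], []]
print(list(yield_copy([1, 2])))   # [[(1,)], [(1,), (2,)]]
```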
Now, to get partitions of a particular length, we could simply filter or group the results for the ones we're looking for and be done with it, but if we only want partitions of some length `k`, this approach performs a lot of unnecessary work (i.e., recursive calls).
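For comparison, the filter-afterwards approach might look like the sketch below. It re-declares the unguarded generator so it runs on its own; `all_partitions` and `k_partitions_by_filter` are illustrative names. For four items it builds all 15 partitions only to keep the 7 with two groups, and the gap widens as n grows (203 partitions of six items, of which 31 have two groups).

```python
def all_partitions(seq):
    """Yield every partition of `seq` (same recursion as above, no guards)."""
    n = len(seq)
    groups = []

    def generate(i):
        if i >= n:
            yield list(map(tuple, groups))
            return
        for group in groups:
            group.append(seq[i])
            yield from generate(i + 1)
            group.pop()
        groups.append([seq[i]])
        yield from generate(i + 1)
        groups.pop()

    yield from generate(0)


def k_partitions_by_filter(seq, k):
    # Generates every partition, then throws away the wrong-sized ones
    return [p for p in all_partitions(seq) if len(p) == k]


total = sum(1 for _ in all_partitions([1, 2, 3, 4]))
kept = len(k_partitions_by_filter([1, 2, 3, 4], 2))
print(total, kept)  # 15 7
```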
Note that in the function above, the length of a partition (i.e., the number of groups) is increased whenever:
```python
# this adds a new group (or part) to the partition
groups.append([seq[i]])
yield from generate_partitions(i + 1)
groups.pop()
```
...is executed. Thus, we limit the size of a partition by simply putting a guard on that block, like so:
```python
def partitions(seq, k):
    ...
    def generate_partitions(i):
        ...
        # only add a new group if the total number would not exceed k
        if len(groups) < k:
            groups.append([seq[i]])
            yield from generate_partitions(i + 1)
            groups.pop()
```
Adding the new parameter and just that guard to the `partitions` function will now cause it to generate only partitions of length up to `k`. This is almost what we want. The problem is that the `for` loop still sometimes generates partitions of length less than `k`.
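To see the "up to k" behaviour concretely, here is a self-contained sketch with only the `len(groups) < k` guard (`partitions_at_most_k` is a made-up name). For `[1, 2, 3]` and `k=2` it yields the one-group partition as well as the three two-group ones:

```python
def partitions_at_most_k(seq, k):
    n = len(seq)
    groups = []

    def generate(i):
        if i >= n:
            yield list(map(tuple, groups))
            return
        for group in groups:  # unguarded: may end with too few groups
            group.append(seq[i])
            yield from generate(i + 1)
            group.pop()
        if len(groups) < k:   # never exceed k groups
            groups.append([seq[i]])
            yield from generate(i + 1)
            groups.pop()

    return list(generate(0))


for p in partitions_at_most_k([1, 2, 3], 2):
    print(len(p), p)
# 1 [(1, 2, 3)]
# 2 [(1, 2), (3,)]
# 2 [(1, 3), (2,)]
# 2 [(1,), (2, 3)]
```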
In order to prune those recursive branches, we need to execute the `for` loop only when we can be sure that there are enough remaining elements in our sequence to expand the working set to a total of `k` groups. The number of remaining elements (elements that haven't yet been placed into a group) is `n - i` (or `len(seq) - i`), and `k - len(groups)` is the number of new groups that we still need to add to produce a valid k-partition. If `n - i <= k - len(groups)`, then we cannot waste an item by adding it to one of the current groups; we must create a new group.
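A worked instance of that inequality, written as a small hypothetical helper (`may_join_existing` is not part of the answer): with n = 4 items and k = 3, suppose the first two items already share one group, so i = 2 and there is one group. Then 2 items remain and 2 groups are still missing, so the guard is false and each remaining item must start its own group.

```python
def may_join_existing(n, i, k, num_groups):
    """True if item i may join an existing group and k groups are still reachable.

    Remaining items (including item i): n - i.  Groups still missing: k - num_groups.
    Joining an existing group spends item i without creating a group, so there
    must be strictly more remaining items than missing groups.
    """
    return n - i > k - num_groups


print(may_join_existing(4, 2, 3, 1))  # False: 2 items left, 2 groups missing
print(may_join_existing(4, 1, 3, 1))  # True: 3 items left, 2 groups missing
```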
So we simply add another guard, this time to the other recursive branch:
```python
def generate_partitions(i):
    ...
    # only add to current groups if the number of remaining items
    # exceeds the number of required new groups.
    if n - i > k - len(groups):
        for group in groups:
            group.append(seq[i])
            yield from generate_partitions(i + 1)
            group.pop()

    # only add a new group if the total number would not exceed k
    if len(groups) < k:
        groups.append([seq[i]])
        yield from generate_partitions(i + 1)
        groups.pop()
```
And with that, you have a working k-partition generator. You could probably collapse some of the recursive calls even further (for example, if there are 3 remaining items and we need 3 more groups, then you already know that each item must go into its own group), but I wanted to show the function as a slight modification of the basic algorithm that generates all partitions.
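One way to do that collapsing, offered as a sketch rather than a benchmarked optimization (`k_partitions` and its `shortcut` flag are made-up names): when the number of remaining items equals the number of missing groups, emit them all as singletons in one step instead of recursing once per item. The `calls` counter exists only to compare the amount of recursion. The early `0 <= k <= n` check also covers a case the guards alone miss: with more groups requested than items, they would still yield partitions with fewer than `k` parts.

```python
def k_partitions(seq, k, shortcut=True):
    """Return (partitions, number_of_recursive_calls) for the k-partitions of seq."""
    n = len(seq)
    if not 0 <= k <= n:
        return [], 0  # no valid k-partition exists
    groups = []
    calls = 0

    def generate(i):
        nonlocal calls
        calls += 1
        if i >= n:
            yield list(map(tuple, groups))
            return
        if shortcut and n - i == k - len(groups):
            # each remaining item must become its own group: emit in one step
            yield list(map(tuple, groups)) + [(x,) for x in seq[i:]]
            return
        if n - i > k - len(groups):
            for group in groups:
                group.append(seq[i])
                yield from generate(i + 1)
                group.pop()
        if len(groups) < k:
            groups.append([seq[i]])
            yield from generate(i + 1)
            groups.pop()

    result = list(generate(0))
    return result, calls


fast, fast_calls = k_partitions([1, 2, 3, 4, 5, 6], 3)
slow, slow_calls = k_partitions([1, 2, 3, 4, 5, 6], 3, shortcut=False)
print(len(fast), len(slow), fast_calls < slow_calls)  # 90 90 True
```

Both variants produce the same partitions (S(6, 3) = 90 of them); the shortcut only skips the chain of single-choice calls at the end of each branch.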
The only thing left to do is sort the results. Unfortunately, rather than figuring out how to directly generate the partitions in the desired order (an exercise for a smarter dog), I cheated and just sorted post-generation.
```python
def sorted_k_partitions(seq, k):
    ...
    result = generate_partitions(0)

    # Sort the parts in each partition in shortlex order
    result = [sorted(ps, key=lambda p: (len(p), p)) for ps in result]
    # Sort partitions by the length of each part, then lexicographically.
    result = sorted(result, key=lambda ps: (*map(len, ps), ps))

    return result
```
Somewhat self-explanatory, except for the key functions. The first one:
```python
key=lambda p: (len(p), p)
```
says to sort a sequence by length, then by the sequence itself (Python compares sequences lexicographically by default). The `p` stands for "part". This is used to sort the parts/groups within a partition. This key means that, for example, `(4,)` precedes `(1, 2, 3)`, so that `[(1, 2, 3), (4,)]` is sorted as `[(4,), (1, 2, 3)]`.
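A quick check of that first key in isolation, written as a named function (`part_key` is only an illustrative name) and contrasted with plain lexicographic sorting:

```python
def part_key(p):
    """Shortlex key: shorter tuples first, equal lengths compared element by element."""
    return (len(p), p)


print(sorted([(1, 2, 3), (4,)]))                     # [(1, 2, 3), (4,)]  plain lexicographic
print(sorted([(1, 2, 3), (4,)], key=part_key))       # [(4,), (1, 2, 3)]  shortlex
print(sorted([(2, 3), (1,), (1, 4)], key=part_key))  # [(1,), (1, 4), (2, 3)]
```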
```python
key=lambda ps: (*map(len, ps), ps)
# or for Python versions <3.5: lambda ps: tuple(map(len, ps)) + (ps,)
```
The `ps` here stands for "parts", plural. This one says to sort a sequence by the lengths of each of its elements (which must be sequences themselves), then (lexicographically) by the sequence itself. This is used to sort the partitions with respect to each other, so that, for example, `[(4,), (1, 2, 3)]` precedes `[(1, 2), (3, 4)]`.
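The second key can be checked the same way. This sketch (the function names are illustrative) also confirms that the pre-3.5 spelling builds the same tuple, so either gives the same order:

```python
def partition_key(ps):
    """Sort by the length of each part in turn, then by the parts themselves."""
    return (*map(len, ps), ps)


def partition_key_legacy(ps):
    # the pre-3.5 spelling from the comment above; builds the same tuple
    return tuple(map(len, ps)) + (ps,)


parts = [[(1, 2), (3, 4)], [(4,), (1, 2, 3)], [(1,), (2, 3, 4)]]
print(sorted(parts, key=partition_key))
# [[(1,), (2, 3, 4)], [(4,), (1, 2, 3)], [(1, 2), (3, 4)]]
```

The first two lengths decide most comparisons (1, 3 sorts before 2, 2); only partitions with identical part lengths fall through to comparing the parts themselves.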
The following:
```python
seq = [1, 2, 3, 4]
for k in 1, 2, 3, 4:
    for groups in sorted_k_partitions(seq, k):
        print(k, groups)
```
produces:
```
1 [(1, 2, 3, 4)]
2 [(1,), (2, 3, 4)]
2 [(2,), (1, 3, 4)]
2 [(3,), (1, 2, 4)]
2 [(4,), (1, 2, 3)]
2 [(1, 2), (3, 4)]
2 [(1, 3), (2, 4)]
2 [(1, 4), (2, 3)]
3 [(1,), (2,), (3, 4)]
3 [(1,), (3,), (2, 4)]
3 [(1,), (4,), (2, 3)]
3 [(2,), (3,), (1, 4)]
3 [(2,), (4,), (1, 3)]
3 [(3,), (4,), (1, 2)]
4 [(1,), (2,), (3,), (4,)]
```