Question:
I am trying to get all possible combinations of 11 values repeated 80 times (that is, every length-80 tuple drawn from the 11 values), but filter out the cases where the sum is above 1. The code below does what I want, but it takes days to run:
import numpy as np
import itertools

unique_values = np.linspace(0.0, 1.0, 11)
lst = []
for p in itertools.product(unique_values, repeat=80):
    if sum(p) <= 1:
        lst.append(p)
The solution above would work, but it takes far too long. I would also have to periodically save lst to disk and free the memory to avoid memory errors. That part is fine, but the code would need days (or maybe weeks) to complete.
Is there any alternative?
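A quick count shows the size of the problem. This is a sketch that assumes the values are exactly the multiples of 0.1 produced by np.linspace(0.0, 1.0, 11), and it uses only math.comb (Python 3.8+):

```python
from math import comb

N_VALUES = 11  # 0.0, 0.1, ..., 1.0 from np.linspace(0.0, 1.0, 11)
LENGTH = 80

# Number of tuples itertools.product generates before any filtering.
total = N_VALUES ** LENGTH

# In integer tenths (k stands for k/10), "sum <= 1" means
# k_1 + ... + k_80 <= 10. Adding one slack variable turns this into
# 81 non-negative integers summing to exactly 10, which stars and bars
# counts as C(10 + 80, 80).
kept = comb(10 + LENGTH, LENGTH)

print(f"generated by product: {total:.2e}")
print(f"kept by the filter:   {kept:,}")
```

The filtered result alone has 5,720,645,481,903 tuples. Even at one byte per value that is over 450 TB, so storing every combination is likely impractical however fast they are generated.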
Answer 1:
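This should be noticeably more efficient. A recursive generator only extends a prefix with values that still fit under the remaining budget, so it never builds tuples that would later be filtered out, and you can take values from it as needed: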
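import numpy as np

def get_solution(uniques, length, constraint):
    # Only try values that still fit in the remaining budget;
    # the 1e-8 slack absorbs floating-point rounding.
    if length == 1:
        for u in uniques[uniques <= constraint + 1e-8]:
            yield u
    else:
        for u in uniques[uniques <= constraint + 1e-8]:
            for s in get_solution(uniques, length - 1, constraint - u):
                yield np.hstack((u, s))

unique_values = np.linspace(0.0, 1.0, 11)
g = get_solution(unique_values, 4, 1)
for _ in range(5):
    print(next(g))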
This prints:
[0. 0. 0. 0.]
[0. 0. 0. 0.1]
[0. 0. 0. 0.2]
[0. 0. 0. 0.3]
[0. 0. 0. 0.4]
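A variant of this generator, sketched under the assumption that it is acceptable to work in integer tenths instead of floats. The name get_solution_int is hypothetical; after dividing by 10 its output follows the same order as get_solution, and no tolerance is needed because integer sums are exact:

```python
from itertools import islice
from typing import Iterator, Tuple


def get_solution_int(max_value: int, length: int, budget: int) -> Iterator[Tuple[int, ...]]:
    """Yield every tuple of `length` ints in 0..max_value whose sum is <= budget.

    Hypothetical integer version of get_solution: k stands for k/10.
    """
    if length == 0:
        yield ()
        return
    # Only values that still fit in the remaining budget are tried.
    for k in range(min(max_value, budget) + 1):
        for rest in get_solution_int(max_value, length - 1, budget - k):
            yield (k,) + rest


# First five results for length 4, converted back to the float values.
for t in islice(get_solution_int(10, 4, 10), 5):
    print([k / 10 for k in t])
```

Yielding plain tuples also avoids building a new NumPy array at every level of the recursion; convert to an array only at the end if one is needed.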
Timing against your approach, wrapped in a function (in IPython, using %timeit):
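from itertools import product

def get_solution_product(uniques, length, constraint):
    # Brute force: generate every tuple, then filter (same 1e-8 tolerance).
    return np.array([p for p in product(uniques, repeat=length)
                     if np.sum(p) <= constraint + 1e-8])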
%timeit np.vstack(list(get_solution(unique_values, 5, 1)))
346 ms ± 29.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit get_solution_product(unique_values, 5, 1)
2.94 s ± 256 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
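The gap in these timings comes from pruning: product builds every tuple and throws most of them away, while the generator never extends a prefix that is already over budget. A brute-force count for length 5 illustrates the ratio; it works in integer tenths, an assumption made here only to avoid float comparisons:

```python
from itertools import product
from math import comb

LENGTH = 5

generated = 0
kept = 0
# Integer tenths 0..10 stand for 0.0..1.0, so "sum <= 1" becomes "sum <= 10".
for p in product(range(11), repeat=LENGTH):
    generated += 1
    if sum(p) <= 10:
        kept += 1

print(generated, kept)  # 161051 3003
```

Only about 2% of the tuples survive at length 5, and the kept count matches comb(10 + LENGTH, LENGTH); the surviving fraction shrinks rapidly as the length grows.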
Answer 2:
The OP essentially needs the partitions of 10 (working in units of 0.1), but here is some general code I wrote in the meantime.
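def find_combinations(values, max_total, repeat):
    # Stop when no slots are left or the budget is used up;
    # any remaining slots are implicitly 0.
    if not (repeat and max_total > 0):
        yield ()
        return
    for v in values:
        if v <= max_total:
            for sub_comb in find_combinations(values, max_total - v, repeat - 1):
                yield (v,) + sub_comb

def main():
    # Work in tenths: 1..10 stand for 0.1..1.0, so the budget of 1 becomes 10.
    all_combinations = find_combinations(range(1, 11), 10, 80)
    # Sorting each tuple collapses reorderings of the same multiset.
    unique_combinations = {
        tuple(sorted(t))
        for t in all_combinations
    }
    for comb in sorted(unique_combinations):
        print(comb)

main()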
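With repeat=80, find_combinations keeps adding values until the budget is exactly used up, so it only yields combinations that sum to exactly 10 (the 42 partitions of 10) and skips totals below 1, such as all zeros. If those are wanted too, the sketch below (hypothetical helpers partitions_up_to and arrangements) generates every multiset with total at most 10 directly in non-increasing order, so no deduplication set is needed, and counts how many ordered 80-tuples each multiset stands for. The grand total equals the number of tuples the question's filter keeps:

```python
from collections import Counter
from math import factorial
from typing import Iterator, Tuple


def partitions_up_to(max_total: int, max_part: int, max_len: int) -> Iterator[Tuple[int, ...]]:
    """Yield each multiset of parts in 1..max_part with sum <= max_total and
    at most max_len parts, as a non-increasing tuple (zeros are implicit)."""
    # Any prefix is already a valid answer because the remaining slots can
    # stay 0; this is what brings in totals below max_total.
    yield ()
    if max_len == 0:
        return
    for first in range(min(max_part, max_total), 0, -1):
        # Later parts may not exceed `first`, so each multiset appears once.
        for rest in partitions_up_to(max_total - first, first, max_len - 1):
            yield (first,) + rest


def arrangements(partition: Tuple[int, ...], length: int) -> int:
    """Number of distinct orderings of `partition` padded with zeros to `length`."""
    counts = Counter(partition)
    counts[0] = length - len(partition)
    n = factorial(length)
    for c in counts.values():
        n //= factorial(c)  # multinomial coefficient, exact in integers
    return n


parts = list(partitions_up_to(10, 10, 80))
print(len(parts))  # 139 multisets with total 0..10
print(sum(arrangements(p, 80) for p in parts))  # 5720645481903 ordered tuples
```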
Source: https://stackoverflow.com/questions/59748726/efficient-cartesian-product-excluding-items