题目:有n个句子,每个句子的长度都小于等于m。现在需要将相邻的较短句子拼接在一起,要求拼接后每个句子的长度仍不大于m,并且句子的数量最少;在此前提下,使拼接后各句子长度的方差最小。求拼接方式。
解题(题目为自拟;以下三种启发式方法中,sent_comb3的效果最好):
# -*- coding: utf-8 -*-
"""Heuristics for merging adjacent sentences into few, evenly sized groups.

Sentence lengths are assumed to be non-negative integers, each at most
``max_len``. Every strategy returns a list of groups; each group is a list of
consecutive lengths taken from the input, in order.
"""
import numpy as np


def _as_checked_list(sent_lens, max_len):
    """Return ``sent_lens`` as a new list, rejecting lengths above ``max_len``.

    Accepts any iterable of lengths (list, tuple, numpy array). A sentence
    longer than ``max_len`` can never be placed in a valid group, and the
    merging loops below rely on every single sentence fitting on its own.

    Raises:
        ValueError: If any length exceeds ``max_len``.
    """
    lens = list(sent_lens)
    too_long = [length for length in lens if length > max_len]
    if too_long:
        raise ValueError(
            "sentence lengths {} exceed max_len={}".format(too_long, max_len))
    return lens


def _ceil_div(a, b):
    """Return ``ceil(a / b)`` for positive ``b`` using floor division only."""
    return -(-a // b)


def _freeze(groups):
    """Return a hashable snapshot of a list of groups."""
    return tuple(tuple(group) for group in groups)


def _greedy_pack(lens, max_len, target=float("inf")):
    """Pack ``lens`` left to right into consecutive groups.

    A length joins the current group only while that group's sum is still
    below ``target`` and the length fits within ``max_len``; otherwise it
    starts a new group. With the default ``target`` groups are filled as far
    as ``max_len`` allows.
    """
    groups = []
    group_sum = 0
    for length in lens:
        if groups and group_sum < target and group_sum + length <= max_len:
            groups[-1].append(length)
            group_sum += length
        else:
            groups.append([length])
            group_sum = length
    return groups


def _prefix_end(lens, start, target, max_len):
    """Return the end index of a greedy prefix of ``lens[start:]``.

    Lengths are taken while the running sum is below ``target`` and the next
    length still fits within ``max_len``.
    """
    end, total = start, 0
    while end < len(lens) and total < target and total + lens[end] <= max_len:
        total += lens[end]
        end += 1
    return end


def sent_comb1(sent_lens, max_len=35):
    """Split using the overall average group length as the target.

    ``ceil(total / max_len)`` is a lower bound on the number of groups; a new
    group is started once the current one reaches ``total // that bound`` or
    when the next sentence would not fit.

    Args:
        sent_lens: Sentence lengths, each at most ``max_len``.
        max_len: Maximum total length of a merged group.

    Returns:
        List of groups of consecutive lengths; ``[]`` for empty input.

    Raises:
        ValueError: If any length exceeds ``max_len``.
    """
    lens = _as_checked_list(sent_lens, max_len)
    if not lens:
        return []
    total = sum(lens)
    if total <= max_len:
        return [lens]
    target = total // _ceil_div(total, max_len)
    return _greedy_pack(lens, max_len, target)


def sent_comb2(sent_lens, max_len=35):
    """Cut one group at a time, balancing it against the text that follows.

    At each step take the longest prefix of the remaining lengths whose sum is
    at most ``2 * max_len``, use its per-group average as the target, and cut
    the next group greedily at that target. Stop once the rest fits in a
    single group.

    Args:
        sent_lens: Sentence lengths, each at most ``max_len``.
        max_len: Maximum total length of a merged group.

    Returns:
        List of groups of consecutive lengths; ``[]`` for empty input.

    Raises:
        ValueError: If any length exceeds ``max_len``.
    """
    lens = _as_checked_list(sent_lens, max_len)
    groups = []
    start = 0
    while start < len(lens):
        rest = lens[start:]
        if sum(rest) <= max_len:
            groups.append(rest)
            break
        # Longest prefix of the remainder that could fill at most two groups.
        window_sum = 0
        for length in rest:
            if window_sum + length > 2 * max_len:
                break
            window_sum += length
        target = window_sum // _ceil_div(window_sum, max_len)
        # Defensive: always consume at least one sentence so the loop ends.
        end = max(_prefix_end(lens, start, target, max_len), start + 1)
        groups.append(lens[start:end])
        start = end
    return groups


def sent_comb3(sent_lens, max_len=35):
    """Greedy split, then repeatedly rebalance neighbouring groups.

    Start from the plain greedy packing. Then sweep left to right: join each
    group with the (already rebalanced) group before it and re-cut the pair
    greedily at half of their combined length. Sweeps repeat until the
    grouping stops changing.

    Args:
        sent_lens: Sentence lengths, each at most ``max_len``.
        max_len: Maximum total length of a merged group.

    Returns:
        List of groups of consecutive lengths; ``[]`` for empty input.

    Raises:
        ValueError: If any length exceeds ``max_len``.
    """
    lens = _as_checked_list(sent_lens, max_len)
    if not lens:
        return []
    if sum(lens) <= max_len:
        return [lens]
    groups = _greedy_pack(lens, max_len)
    # A sweep is a deterministic function of the grouping, so reaching any
    # earlier grouping again means a cycle that would never settle. When the
    # sweeps do settle, the first repeat is the immediately preceding one.
    seen = {_freeze(groups)}
    while True:
        swept = [groups[0]]
        for right in groups[1:]:
            pair = swept[-1] + right
            cut = _prefix_end(pair, 0, sum(pair) // 2, max_len)
            swept[-1] = pair[:cut]
            swept.append(pair[cut:])
        key = _freeze(swept)
        if key in seen:
            return swept
        seen.add(key)
        groups = swept


def main():
    """Run all three strategies on a sample input, printing timings and groups."""
    import timeit

    max_len = 35
    # Alternative inputs:
    # sent_lens = np.random.randint(1, max_len, size=20)
    # sent_lens = [34, 1, 34, 1, 30, 1]
    sent_lens = [7, 11, 12, 31, 1, 1, 26, 2, 7, 22, 1, 14, 28, 1, 1, 34, 24,
                 32, 10, 31]
    print(sent_lens)
    for comb in (sent_comb1, sent_comb2, sent_comb3):
        start_time = timeit.default_timer()
        comb_rlt = comb(sent_lens, max_len=max_len)
        print(timeit.default_timer() - start_time)
        print(comb_rlt)
        # print([sum(group) for group in comb_rlt])


if __name__ == "__main__":
    main()
来源:https://www.cnblogs.com/jacen789/p/12070493.html