数组切分(句子拼接)

左心房为你撑大大i 提交于 2019-12-24 01:33:07

题目:有n个句子,每个句子的长度都小于等于m,现在需要将相邻较短的句子拼接再一起,使得句子的数量最少,并且长度仍然不大于m,而且拼接完之后句子的长度的方差最小。求拼接方式。

 解题(自己给自己出题,sent_comb3最优):

# -*- coding: utf-8 -*-

import numpy as np


def sent_comb1(sent_lens, max_len=35):
    """先求总长的平均长度,再切分"""
    if sum(sent_lens) <= max_len:
        return [list(sent_lens)]
    avg_len = sum(sent_lens) // np.ceil(sum(sent_lens) / max_len)
    rlts = []
    for sent_len in sent_lens:
        if (not rlts) or (rlts and sum(rlts[-1]) >= avg_len):
            rlts.append([sent_len])
        elif sum(rlts[-1]) + sent_len <= max_len:
            rlts[-1].append(sent_len)
        else:
            rlts.append([sent_len])
    return rlts


def sent_comb2(sent_lens, max_len=35):
    """先求得到每两段的平均长度,再切分"""
    if sum(sent_lens) <= max_len:
        return [list(sent_lens)]
    small_sent_lens = []
    for sent_len in sent_lens:
        if sum(small_sent_lens) + sent_len > 2 * max_len:
            break
        small_sent_lens.append(sent_len)
    small_sent_lens_sum = sum(small_sent_lens)
    avg_len = small_sent_lens_sum // np.ceil(small_sent_lens_sum / max_len)
    rlts = []
    rlt = []
    while len(sent_lens):
        sent_len = sent_lens[0]
        if sum(rlt) >= avg_len:
            rlts.append(rlt)
            rlts.extend(sent_comb2(sent_lens))
            break
        elif sum(rlt) + sent_len <= max_len:
            rlt.append(sent_len)
            sent_lens = sent_lens[1:]
        else:
            rlts.append(rlt)
            rlts.extend(sent_comb2(sent_lens))
            break
    return rlts


def sent_comb3(sent_lens, max_len=35):
    """先切分,再求得每两段的平均长度,再切分,重复以上步骤,直到两次操作的结果一致停止"""
    if sum(sent_lens) <= max_len:
        return [list(sent_lens)]
    rlts = []
    for sent_len in sent_lens:
        if not rlts:
            rlts.append([sent_len])
        elif sum(rlts[-1]) + sent_len <= max_len:
            rlts[-1].append(sent_len)
        else:
            rlts.append([sent_len])
    while True:
        new_rlts = [rlts[0]]
        i = 1
        while i < len(rlts):
            small_sent_lens = new_rlts[-1] + rlts[i]
            avg_len = sum(small_sent_lens) // 2
            rlt = []
            while len(small_sent_lens):
                sent_len = small_sent_lens[0]
                if sum(rlt) >= avg_len:
                    break
                elif sum(rlt) + sent_len <= max_len:
                    rlt.append(sent_len)
                    small_sent_lens = small_sent_lens[1:]
                else:
                    break
            new_rlts[-1] = rlt
            new_rlts.append(small_sent_lens)
            i += 1
        if new_rlts == rlts:
            break
        rlts = new_rlts
    return rlts


def main():
    max_len = 35
    # sent_lens = np.random.randint(1, max_len, size=20)
    sent_lens = [7, 11, 12, 31,  1,  1, 26,  2,  7, 22,  1, 14, 28,  1,  1, 34, 24, 32, 10, 31]
    # sent_lens = [34, 1, 34, 1, 30, 1]
    print(sent_lens)

    import timeit

    start_time = timeit.default_timer()
    comb_rlt1 = sent_comb1(sent_lens, max_len=max_len)
    print(timeit.default_timer() - start_time)
    print(comb_rlt1)
    # print([sum(comb) for comb in comb_rlt1])

    start_time = timeit.default_timer()
    comb_rlt2 = sent_comb2(sent_lens, max_len=max_len)
    print(timeit.default_timer() - start_time)
    print(comb_rlt2)
    # print([sum(comb) for comb in comb_rlt2])

    start_time = timeit.default_timer()
    comb_rlt3 = sent_comb3(sent_lens, max_len=max_len)
    print(timeit.default_timer() - start_time)
    print(comb_rlt3)
    # print([sum(comb) for comb in comb_rlt3])


if __name__ == "__main__":
    main()

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!