Efficient algorithm to calculate the sum of all k-products

青春惊慌失措 2021-02-06 04:49

Suppose you are given a list L of n numbers and an integer k. Is there an efficient way to calculate the sum of all products of k elements of L (the k-products)?

6 Answers
  •  执念已碎
    2021-02-06 05:47

    Yes, there is a way. Consider the polynomial

    (X + a[0]) * (X + a[1]) * ... * (X + a[n-1])
    

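    Its coefficients are exactly the sums of the k-products: the coefficient of X^(n-k) is the sum of all k-products (the coefficient of X^n is 1, the empty product). Its degree is n, so by multiplying it out one factor at a time you can calculate the sum of all k-products for all k simultaneously in O(n^2) steps.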
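    To make the claim concrete, write e_k for the sum of all k-products (this is the standard elementary symmetric polynomial of degree k) and a_j for a[j]. A small case expanded by hand, plus a numeric check with L = (1, 2, 3), where 1+2+3 = 6, 1*2+1*3+2*3 = 11 and 1*2*3 = 6:

```latex
\prod_{j=0}^{n-1} (X + a_j) = \sum_{k=0}^{n} e_k \, X^{n-k},
\qquad e_0 = 1,
\qquad e_k = \sum_{0 \le i_1 < \dots < i_k \le n-1} a_{i_1} a_{i_2} \cdots a_{i_k}

% n = 3:
(X + a_0)(X + a_1)(X + a_2) = X^3 + (a_0 + a_1 + a_2)\,X^2
    + (a_0 a_1 + a_0 a_2 + a_1 a_2)\,X + a_0 a_1 a_2

% a = (1, 2, 3):
(X + 1)(X + 2)(X + 3) = X^3 + 6X^2 + 11X + 6
```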

    After s steps, the coefficient of X^(s-k) is the sum of the k-products of the first s array elements. The k-products of the first s+1 elements fall into two classes: those involving the (s+1)st element, which have the form a[s]*((k-1)-product of the first s elements), and those not involving it, which are exactly the k-products of the first s elements. Writing S(s, k) for the sum of the k-products of the first s elements, this gives S(s+1, k) = S(s, k) + a[s]*S(s, k-1), with S(s, 0) = 1 and S(s, k) = 0 for k > s.
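    The two-class split can be checked against the definition itself on small inputs. A minimal sketch, not part of the original answer: k_products_brute is a hypothetical helper that sums the product over every k-element set of indices by enumerating bitmasks. It takes O(2^n * n) time, so it is only meant as a reference for testing the polynomial method with small n (it assumes n <= 31 so that every mask fits in an unsigned long).

```c
#include <stddef.h>

/* Hypothetical reference implementation: sum of all k-products of
   a[0..n-1], computed straight from the definition.
   Assumes 0 <= n <= 31 (masks are unsigned long). Exponential time. */
long long k_products_brute(const int *a, int n, int k)
{
    long long total = 0;
    /* Each mask selects a subset of indices; bit i set means a[i] is used. */
    for (unsigned long mask = 0; mask < (1UL << n); ++mask) {
        int count = 0;
        long long prod = 1;              /* empty product is 1 */
        for (int i = 0; i < n; ++i) {
            if (mask & (1UL << i)) {
                ++count;
                prod *= a[i];
            }
        }
        if (count == k)                  /* keep only k-element subsets */
            total += prod;
    }
    return total;                        /* 0 when k > n or k < 0 */
}
```

    For L = (1, 2, 3) this returns 1, 6, 11, 6 for k = 0, 1, 2, 3, matching the coefficients of (X + 1)(X + 2)(X + 3).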

    Code such that result[i] is the coefficient of X^i (the sum of the (n-i)-products):

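    #include <stdlib.h>

    /* result[i] = coefficient of X^i in (X + a[0]) * ... * (X + a[n-1]),
       i.e. the sum of the (n-i)-products. The caller must free() it.
       Returns NULL if allocation fails. int products can overflow. */
    int *k_products_1(int *a, int n){
        int *new, *old = calloc(n+1, sizeof(int));
        int d, i;
        if(!old) return NULL;
        old[0] = 1;                       /* start from the polynomial 1 */
        for(d = 1; d <= n; ++d){
            /* multiply the current polynomial by (X + a[d-1]) */
            new = calloc(n+1, sizeof(int));
            if(!new){ free(old); return NULL; }
            new[0] = a[d-1]*old[0];
            for(i = 1; i <= d; ++i){
                /* X*old[i-1] moves up one degree, a[d-1]*old[i] stays;
                   old[d] is still 0 from calloc, so i == d is safe */
                new[i] = old[i-1] + a[d-1]*old[i];
            }
            free(old);
            old = new;
        }
        return old;
    }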
    

    If you only want the sum of the k-products for one k, you can stop the calculation at index n-k, giving an O(n*(n-k)) algorithm; that's good if k >= n/2. To get an O(n*k) algorithm for k <= n/2, you have to organise the coefficient array the other way round, so that result[k] is the coefficient of X^(n-k), and stop the calculation at index k if you want only one sum:

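    #include <stdlib.h>

    /* result[k] = coefficient of X^(n-k) in (X + a[0]) * ... * (X + a[n-1]),
       i.e. the sum of the k-products. The caller must free() it.
       Returns NULL if allocation fails. int products can overflow. */
    int *k_products_2(int *a, int n){
        int *new, *old = calloc(n+1, sizeof(int));
        int d, i;
        if(!old) return NULL;
        old[0] = 1;
        for(d = 1; d <= n; ++d){
            new = calloc(n+1, sizeof(int));
            if(!new){ free(old); return NULL; }
            new[0] = 1;                   /* leading coefficient stays 1 */
            for(i = 1; i <= d; ++i){
                /* i-products without a[d-1], plus a[d-1] times the
                   (i-1)-products of the earlier elements; old[d] is 0 */
                new[i] = old[i] + a[d-1]*old[i-1];
            }
            free(old);
            old = new;
        }
        return old;
    }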
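    When only one k is needed, the same recurrence can be truncated at index k and updated in place, so the extra memory drops from O(n) per step to O(k) overall. A sketch under that assumption, not the answer's own code: k_product_sum is a hypothetical name, it returns long long to delay overflow, and it returns 0 both for k outside 0..n (the empty sum) and on allocation failure, which is a simplification. Iterating j downward keeps e[j-1] at its value from before a[d] was added, which plays the role of old[i-1] in k_products_2.

```c
#include <stdlib.h>

/* Hypothetical single-k variant of k_products_2: sum of all k-products
   of a[0..n-1] in O(n*k) time and O(k) extra memory.
   e[j] holds the sum of the j-products of the elements seen so far.
   Returns 0 if k < 0, k > n, or allocation fails. */
long long k_product_sum(const int *a, int n, int k)
{
    if (k < 0 || k > n)
        return 0;
    long long *e = calloc((size_t)k + 1, sizeof *e);
    if (e == NULL)
        return 0;
    e[0] = 1;                                /* the empty product */
    for (int d = 0; d < n; ++d) {
        /* After d+1 elements, only indices up to min(d+1, k) are nonzero. */
        int top = d + 1 < k ? d + 1 : k;
        /* Downward, so e[j-1] is still the value before a[d] was added. */
        for (int j = top; j >= 1; --j)
            e[j] += (long long)a[d] * e[j - 1];
    }
    long long result = e[k];
    free(e);
    return result;
}
```

    For L = (1, 2, 3) and k = 2 this computes 1*2 + 1*3 + 2*3 = 11, the coefficient of X in (X + 1)(X + 2)(X + 3).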
    
