How to compute sum of binomial more efficiently?

太阳男子 2020-12-20 09:22

I must calculate an equation as follows:

where k1,k2 are given. I am using MATLAB to compute P. I think I have a correct implementation fo

  • 2020-12-20 09:43

    rahnema1's answer has a very good approach: create a table of values that you generate once and access later (as well as some other clever optimizations).

    One thing I would change is the way the binomial coefficients are calculated. Look at the factorials in nchoosek(n, k) and nchoosek(n, k+1): n! is recalculated both times, and for k+1 you recalculate k! only to multiply it by k+1. (The same goes for (n-k)!.)

    Rather than throw away the computations each time, we can iteratively compute nchoosek(n, k+1) based on the value of nchoosek(n, k).
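    For reference, the update rule used in combList comes from the ratio of two neighbouring coefficients. This is standard algebra, shown here only to make the loop easier to check:

    ```latex
    \frac{\binom{n}{k}}{\binom{n}{k-1}}
      = \frac{n!}{k!\,(n-k)!}\cdot\frac{(k-1)!\,(n-k+1)!}{n!}
      = \frac{n-k+1}{k},
    \qquad\text{so}\qquad
    \binom{n}{k} = \binom{n}{k-1}\,\frac{n-k+1}{k},\quad \binom{n}{0} = 1.
    ```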

    function L = combList(n, maxk)
    % Create a vector of length maxk containing
    %   [nchoosek(n, 1), nchoosek(n, 2), ..., nchoosek(n, maxk)]
    % Note: nchoosek(n, 0) == nchoosek(n, n) == 1
    assert(maxk <= n, 'maxk must be less than or equal to n');
    L = zeros(1, maxk);
    L(1) = n;                      % nchoosek(n, 1) == n
    for k = 2:maxk
       L(k) = L(k-1)*(n-k+1)/k;    % nchoosek(n, k) from nchoosek(n, k-1)
    end
    end

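    In your program, you would just create three lists, for k1, k2 and k1+k2, with the appropriate limits, and then index into those lists to generate the sums. Since each list starts at k = 1, the k = 0 terms (which are all 1) have to be handled separately.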
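    A minimal sketch of how the three lists could be used to compute P. It inlines the same recurrence instead of calling combList and stores C(n,0) = 1 at index 1, so C(n,k) sits at index k+1. The example values k1 = k2 = 150, D = 200 and the summation range j = max(i-k2+1,0):min(i,k1-1) (the range kept by the vectorized answer's valid mask) are assumptions; adjust the range to match your equation:

    ```matlab
    % Sketch: three precomputed lists of binomial coefficients (example values)
    k1 = 150; k2 = 150; D = 200;          % assumes D-1 <= k1+k2
    ns    = [k1, k2, k1+k2];              % the three "n" values
    maxks = [k1-1, k2-1, D-1];            % largest k needed for each n
    lists = cell(1, 3);
    for r = 1:3
        L = ones(1, maxks(r)+1);          % L(k+1) holds C(n,k); L(1) = C(n,0) = 1
        for k = 1:maxks(r)
            L(k+1) = L(k)*(ns(r)-k+1)/k;  % same recurrence as combList
        end
        lists{r} = L;
    end
    [L1, L2, L12] = lists{:};
    P = 0;
    for i = 0:D-1
        j = max(i-k2+1, 0):min(i, k1-1); % all valid j for this i (assumed range)
        P = P + sum(L1(j+1) .* L2(i-j+1)) / L12(i+1);
    end
    disp(P)
    ```

    Each inner sum becomes a single vectorized dot product over the lists, so only the outer loop over i remains and no binomial coefficient is computed twice.
    
    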

  • 2020-12-20 09:55

    You can vectorize the whole process, which makes it super-fast without any need for MEX.

    First, a vectorized replacement for nchoosek:

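    function C = nCk(n,k)
    % C(n,k) for a scalar n and a vector of k values (returned as a column).
    % Use the symmetry C(n,k) = C(n,n-k) so that every k <= n/2.
    k(k>n/2) = n-k(k>n/2);
    k = k(:);
    % Row r holds 1..(n-k(r)), zero-padded: the factors of (n-k(r))!
    kmat = ones(numel(k),1)*(1:max(n-k));
    kmat = kmat.*bsxfun(@le,kmat,(n-k));
    % Raise each factor to -1/(n-k), so each row multiplies to ((n-k)!)^(-1/(n-k))
    pw = bsxfun(@power,kmat,-1./(n-k));
    pw(kmat==0) = 1;
    % Row r holds the numerator factors k(r)+1..n, zero-padded
    kleft = ones(numel(k),1)*(min(k):n);
    kleft = kleft.*bsxfun(@gt,kleft,k);
    % Scale every numerator factor, spreading the division by (n-k)! evenly
    % so intermediate products stay within floating-point range
    t = bsxfun(@times,kleft,prod(pw,2));
    t(kleft==0) = 1;
    C = prod(t,2);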
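    The reason for spreading the division across the factors is overflow: in double precision, 300! is already Inf, so the textbook formula breaks down long before the binomial coefficient itself gets too large. A small check with n = 300, k = 150 as example values; the log-gamma value is only a cross-check, not part of the answer's method:

    ```matlab
    % Textbook formula vs. the "spread the division" trick (example: n = 300, k = 150)
    n = 300; k = 150;
    naive  = factorial(n) / (factorial(k) * factorial(n-k)); % Inf/Inf -> NaN in double
    scale  = prod((1:n-k).^(-1/(n-k)));  % ((n-k)!)^(-1/(n-k)), one shared factor
    spread = prod((k+1:n) * scale);      % n-k numerator factors, each scaled once
    ref    = exp(gammaln(n+1) - gammaln(k+1) - gammaln(n-k+1)); % log-gamma cross-check
    fprintf('naive = %g, spread = %g, ref = %g\n', naive, spread, ref);
    ```

    Because every scaled factor lies between roughly 2 and 6 here, the running product grows gradually and never overflows.
    
    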

    Then the computation of beta and P:

    function P = binomial_coefficient(k1,k2,D)
    warning ('off','MATLAB:nchoosek:LargeCoefficient');
    % all pairs (i,j) with 0 <= j <= i <= D-1, as column vectors
    i_ind = nonzeros(triu(ones(D,1)*(1:D)))-1;
    j_ind = nonzeros(tril(ones(D,1)*(1:D+1)).')-1;
    % keep only the terms with j < k1 and i-j < k2
    valid = ~(i_ind-j_ind>=k2 | j_ind>=k1);
    i_ind = i_ind(valid);
    j_ind = j_ind(valid);
    % beta(i,j) = C(k1,j)*C(k2,i-j)/C(k1+k2,i) for all pairs at once
    beta = @(ii,jj) nCk(k1,jj).*nCk(k2,ii-jj)./nCk((k1+k2),ii);
    b = beta(i_ind,j_ind);
    P = sum(b(:));

    With this, the execution time drops from 10.674 to 0.49696 seconds.
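    The two nonzeros(triu/tril(...)) lines are the least obvious part: they list every pair 0 <= j <= i <= D-1 without a loop. A small illustration with the example values D = 3, k1 = k2 = 2, compared with an explicit double loop:

    ```matlab
    % Index trick from binomial_coefficient, on small example values
    D = 3; k1 = 2; k2 = 2;
    i_ind = nonzeros(triu(ones(D,1)*(1:D))) - 1;      % [0;1;1;2;2;2]
    j_ind = nonzeros(tril(ones(D,1)*(1:D+1)).') - 1;  % [0;0;1;0;1;2]
    valid = ~(i_ind-j_ind >= k2 | j_ind >= k1);       % keep j < k1 and i-j < k2
    % The same pairs built with an explicit double loop, for comparison
    ii = zeros(0,1); jj = zeros(0,1);
    for i = 0:D-1
        for j = 0:i
            ii(end+1,1) = i; %#ok<AGROW>
            jj(end+1,1) = j; %#ok<AGROW>
        end
    end
    disp([i_ind j_ind valid])
    same = isequal(ii, i_ind) && isequal(jj, j_ind);
    ```

    nonzeros reads the matrices column by column, which is why the triu version repeats each i exactly i+1 times and the transposed tril version counts j up from 0 within each group.
    
    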


    Building on @rahnema1's idea, I managed to make this even faster by using a table for all unique nCk computations, so that none of them is done more than once. Using the same nCk function as above, the new binomial_coefficient function looks like this:

    function P = binomial_coefficient(k1,k2,D)
    warning ('off','MATLAB:nchoosek:LargeCoefficient');
    i_ind = nonzeros(triu(ones(D,1)*(1:D)))-1;
    j_ind = nonzeros(tril(ones(D,1)*(1:D+1)).')-1;
    valid = ~(i_ind-j_ind>=k2 | j_ind>=k1);
    i_ind = i_ind(valid);
    j_ind = j_ind(valid);
    ni = numel(i_ind);
    all_n = repelem([k1; k2; k1+k2],ni); % all n's to be used, in their order
    all_k = [j_ind; i_ind-j_ind; i_ind]; % all k's to be used, in their order
    % get all unique sets of 'n' and 'k':
    sets_tbl = unique([all_n all_k],'rows');
    uq_n = unique(sets_tbl(:,1));
    % table indexed as nCk_tbl(n, k+1)
    nCk_tbl = zeros([max(all_n) max(all_k)+1]);
    % compute all the needed values of nCk:
    for s = 1:numel(uq_n)
        current_n = uq_n(s);
        current_k = sets_tbl(sets_tbl(:,1)==current_n,2);
        nCk_tbl(current_n,current_k+1) = nCk(current_n,current_k).';
    end
    beta = @(ii,jj) nCk_tbl(k1,jj+1).*nCk_tbl(k2,ii-jj+1)./nCk_tbl((k1+k2),ii+1);
    b = beta(i_ind,j_ind);
    P = sum(b(:));

    Now it takes only 0.01212 seconds to run. That's not just super-fast code, it's flying-talking-super-fast code!
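    To get a feel for how much the table saves, here is a quick count reusing the index setup from binomial_coefficient, with k1 = k2 = 150, D = 200 as example values (the values in the test function of the next answer). Because k1 == k2 here, the pairs (150, j) and (150, i-j) share table entries:

    ```matlab
    % How many nCk values are needed in total vs. how many are unique (example values)
    k1 = 150; k2 = 150; D = 200;
    i_ind = nonzeros(triu(ones(D,1)*(1:D))) - 1;
    j_ind = nonzeros(tril(ones(D,1)*(1:D+1)).') - 1;
    valid = ~(i_ind-j_ind >= k2 | j_ind >= k1);
    i_ind = i_ind(valid);
    j_ind = j_ind(valid);
    ni = numel(i_ind);                                  % number of beta(i,j) terms
    all_n = repelem([k1; k2; k1+k2], ni);
    all_k = [j_ind; i_ind-j_ind; i_ind];
    n_calls  = numel(all_k);                            % values the first version computes
    n_unique = size(unique([all_n all_k], 'rows'), 1);  % values the table version computes
    fprintf('%d terms, %d binomials needed, %d unique\n', ni, n_calls, n_unique);
    % prints: 17550 terms, 52650 binomials needed, 350 unique
    ```

    So in this example the table version evaluates 350 binomial coefficients instead of 52650, which is consistent with the large drop in run time reported above.
    
    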

  • 2020-12-20 09:59

    You can save the results of nchoosek in a table to avoid evaluating the function repeatedly. An implementation of the binomial coefficient is also provided:

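    % binomial coefficient; the factorial in the denominator is spread over all factors
    function nk = nchoosek2(n, k)
        if n-k > k
            nk = prod((k+1:n) .* prod((1:n-k).^(-1/(n-k))));
        else
            nk = prod((n-k+1:n) .* prod((1:k).^(-1/k)));
        end
    end
    % store and retrieve results of nchoosek2 to/from a table (0 = not computed yet)
    function ret = choose(n, k, D, K1, K2)
        persistent binTable
        if isempty(binTable)
            binTable = zeros(max([D+1, K1+K2+1]), D+1);
        end
        if binTable(n+1, k+1) == 0
            binTable(n+1, k+1) = nchoosek2(n, k);
        end
        ret = binTable(n+1, k+1);
    end
    function P = tst()
        k1 = 150; k2 = 150; D = 200; P = 0;
        for i = 0:D-1
            for j = max(i - k2, 0):min(i, k1-1)
                % add beta(i,j) = C(k1,j)*C(k2,i-j)/C(k1+k2,i), using the table
                P = P + choose(k1, j, D, k1, k2) * choose(k2, i-j, D, k1, k2) ...
                      / choose(k1+k2, i, D, k1, k2);
            end
        end
    end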

    Your code with nchoosek2, compared with this version: online demo
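    The same lazy table can also be written inline, without a persistent variable, which makes it easy to count how often a binomial coefficient is really evaluated. A sketch under stated assumptions: the small example values k1 = 20, k2 = 25, D = 30 are chosen so that the built-in nchoosek stays exact, and the range j = max(i-k2+1,0):min(i,k1-1) is the one kept by the vectorized answer's valid mask (the loop above starts at max(i - k2, 0); use whichever matches your equation):

    ```matlab
    % Lazily filled lookup table inside the double loop (example values)
    k1 = 20; k2 = 25; D = 30;
    binTable  = zeros(k1+k2+1, D+1);  % binTable(n+1,k+1) caches C(n,k); 0 = not computed
    nComputed = 0;                    % binomials actually evaluated
    nLookups  = 0;                    % binomials the sum asked for in total
    P = 0;
    for i = 0:D-1
        for j = max(i-k2+1, 0):min(i, k1-1)
            nk = [k1 j; k2 i-j; k1+k2 i];     % the three (n,k) pairs of beta(i,j)
            c = zeros(1, 3);
            for r = 1:3
                n = nk(r,1); k = nk(r,2);
                if binTable(n+1, k+1) == 0    % C(n,k) >= 1, so 0 is a safe sentinel
                    binTable(n+1, k+1) = nchoosek(n, k);
                    nComputed = nComputed + 1;
                end
                c(r) = binTable(n+1, k+1);
                nLookups = nLookups + 1;
            end
            P = P + c(1)*c(2)/c(3);
        end
    end
    fprintf('P = %.10g, %d lookups, %d evaluations\n', P, nLookups, nComputed);
    ```

    Here nLookups ends at 1185 while nComputed is only 75: every C(n,k) is evaluated once and afterwards only read from the table.
    
    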
