Kth maximum sum of a contiguous subarray of +ve integers in O(nlogS)

故里飘歌 2021-01-13 14:43

I was reading this editorial and got confused with this statement:

If the array elements are all non-negative, we can use binary search to find the an

3 answers
  •  广开言路
    2021-01-13 15:18

    None of the existing answers are correct, so here is a correct approach.

    First of all, as @PhamTrung pointed out, we can generate the cumulative sums of the array in O(n), and by subtracting two cumulative sums we can calculate the sum of any contiguous subarray in O(1). At that point our answer is bounded between 0 and the sum S of everything.
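    As a small illustration of the cumulative-sum step, here is a minimal sketch. The names `prefix_sums` and `range_sum` are made up for this example and are not from the editorial.

```python
from itertools import accumulate

def prefix_sums(a):
    """Return P with P[0] = 0 and P[i] = a[0] + ... + a[i-1]; takes O(n)."""
    return [0] + list(accumulate(a))

def range_sum(P, i, j):
    """Sum of elements i .. j-1 of the original array, in O(1)."""
    return P[j] - P[i]

a = [3, 1, 4, 1, 5]
P = prefix_sums(a)
print(P)                   # [0, 3, 4, 8, 9, 14]
print(range_sum(P, 1, 4))  # 1 + 4 + 1 = 6
```

    The sum of the whole array is `P[-1]`, which is the upper bound S mentioned above.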

    Next, we know how many contiguous subarrays there are. Just choose the two endpoints among the n+1 cumulative sums (including the empty sum 0): there are n*(n+1)/2 such pairs.

    The heart of the problem is that, given X, we need to count in O(n) how many pairs have a sum less than X. To do that we use a pair of pointers, i and j. We run them up in parallel, keeping the sum from i to j below X but keeping them as far apart as possible given that. And then we keep adding how many pairs there were between them that would also be below X. In pseudocode that looks like this:

    count_below = 0
    i = 0
    j = 0 # P is the cumulative-sum array (P[0] = 0), so P[j] - P[i] is the sum of elements i .. j-1
    while i < n:
        while j < n and P[j+1] - P[i] < X:
            j += 1
            count_below += 1 # Add the (i, j) interval
        i += 1
        if j < i:
            j = i # The window was empty; restart it at the new i
        else:
            count_below += j - i # Add a block of (i, mid) pairs, i < mid <= j
    

    I can't swear that I got the indexing right, but that's the idea. The tricky bit is that every time we advance j we only add one, but every time we advance i we include every (i, mid) with i < mid <= j.
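    For anyone who wants to check the indexing, here is a runnable sketch of the counting step. It uses slightly different bookkeeping from the pseudocode above: it loops over the right boundary j and adds a whole block of j - i intervals per step. The function names are my own, and it assumes all elements are non-negative. The brute-force version is only there for cross-checking.

```python
from itertools import accumulate

def count_below(a, x):
    """Count contiguous subarrays of a with sum < x, in O(n).

    Assumes every element of a is non-negative, so the sum of a
    window never decreases when the window grows.
    """
    P = [0] + list(accumulate(a))      # P[j] - P[i] = sum of a[i:j]
    count = 0
    i = 0
    for j in range(1, len(a) + 1):     # j is the right boundary
        # Move the left boundary up until the window sum drops below x.
        while i < j and P[j] - P[i] >= x:
            i += 1
        # Every left boundary in i .. j-1 gives a sum below x.
        count += j - i
    return count

def count_below_brute(a, x):
    """O(n^3) reference implementation, only for testing."""
    n = len(a)
    return sum(1 for i in range(n) for j in range(i + 1, n + 1)
               if sum(a[i:j]) < x)

a = [1, 2, 3]
print(count_below(a, 4))  # subarray sums are 1, 2, 3, 3, 5, 6, so 4 are below 4
```

    Both pointers only ever move forward, which is why one call is O(n) even though there is a nested loop.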

    And now we do binary search on the value.

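    lower = 0
    upper = S
    while lower < upper:
        mid = ceil((upper + lower)/2) # Round up so the loop always makes progress
        if count_below(mid) <= count_intervals - k:
            lower = mid # At least k sums are >= mid
        else:
            upper = mid-1
    # lower is now the largest value with at most count_intervals - k sums below it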
    

    Assuming that the sums are integers, this will find the correct answer in O(log(S)) searches, each of which is O(n), for a total time of O(n log(S)).
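    Putting the pieces together, here is a sketch of the whole procedure under the same assumptions (non-negative integers, 1 <= k <= number of subarrays). The search looks for the largest value X such that at most count_intervals - k sums lie strictly below X; that X is the k-th largest sum. The function names are hypothetical.

```python
from itertools import accumulate

def count_below(a, x):
    """Count contiguous subarrays of a (non-negative entries) with sum < x."""
    P = [0] + list(accumulate(a))
    count = 0
    i = 0
    for j in range(1, len(a) + 1):
        while i < j and P[j] - P[i] >= x:
            i += 1
        count += j - i
    return count

def kth_largest_subarray_sum(a, k):
    """k-th largest sum over all contiguous subarrays, in O(n log S)."""
    n = len(a)
    count_intervals = n * (n + 1) // 2
    if not 1 <= k <= count_intervals:
        raise ValueError("k must be between 1 and n*(n+1)/2")
    lower, upper = 0, sum(a)
    while lower < upper:
        mid = (lower + upper + 1) // 2   # round up so lower = mid makes progress
        if count_below(a, mid) <= count_intervals - k:
            lower = mid                  # at least k sums are >= mid
        else:
            upper = mid - 1              # fewer than k sums are >= mid
    return lower

a = [1, 2, 3]   # subarray sums, descending: 6, 5, 3, 3, 2, 1
print([kth_largest_subarray_sum(a, k) for k in range(1, 7)])  # [6, 5, 3, 3, 2, 1]
```

    The result is always an actual subarray sum: more sums lie below lower + 1 than below lower, so at least one sum equals lower.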

    Note that if we're clever about the binary searching and keep track of both the count and the closest two sums, we can improve the time to O(n log(min(n, S))) by dropping upper to the max sum <= mid and raising lower to the next higher sum. Drop the floor and that approach also will work with floating point numbers to produce an answer in O(n log(n)). (With S being effectively infinite.)
