I was reading this editorial and got confused with this statement:
If the array elements are all non-negative, we can use binary search to find the an
None of the existing answers are correct, so here is a correct approach.
First of all, as @PhamTrung pointed out, we can generate the cumulative sums of the array in O(n), and by subtracting two cumulative sums we can calculate the sum of any contiguous subarray in O(1). At that point our answer is bounded between 0 and the sum S of everything.
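As a minimal sketch of that first step (the helper names prefix_sums and range_sum are made up for illustration, and a is assumed to be a plain Python list), the cumulative sums and the O(1) subarray sum could look like this:

```python
from itertools import accumulate

def prefix_sums(a):
    """Return p with p[0] = 0 and p[i] = a[0] + ... + a[i-1]. O(n)."""
    return [0] + list(accumulate(a))

def range_sum(p, i, j):
    """Sum of a[i..j] (both ends inclusive) in O(1), given prefix sums p."""
    return p[j + 1] - p[i]

a = [1, 2, 3]
p = prefix_sums(a)         # [0, 1, 3, 6]
total = range_sum(p, 0, len(a) - 1)  # the upper bound S
```

The leading 0 in p is a design choice: it lets subarrays that start at index 0 use the same subtraction as every other subarray.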
Next, we know how many contiguous subarrays there are. Just choose the endpoints: picking two distinct positions among the n+1 cumulative sums gives n*(n+1)/2 such pairs.
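A quick way to convince yourself of that count is to enumerate the pairs for a small n. This is only a sanity check, with hypothetical helper names, counting non-empty subarrays a[i..j] with i <= j:

```python
def count_intervals(n):
    """Number of non-empty contiguous subarrays of an array of length n."""
    return n * (n + 1) // 2

def all_intervals(n):
    """Every (i, j) with 0 <= i <= j < n, i.e. every subarray a[i..j]."""
    return [(i, j) for i in range(n) for j in range(i, n)]

# For n = 4 both give 10.
assert len(all_intervals(4)) == count_intervals(4)
```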
The heart of the problem is that, given X, we need to count in O(n) how many pairs have a sum less than X. To do that we use a pair of pointers, i and j. We run them up in parallel, keeping the sum from i to j below X but keeping them as far apart as possible given that. And then we keep adding how many pairs there were between them that would also be below X. In pseudocode that looks like this:
count_below = 0
i = 0
j = -1    # current interval is i..j, empty while j < i
while i < n:
    # Push j as far right as the sum from i stays below X.
    while j+1 < n and sum(from i to j+1) < X:
        count_below += 1    # Add the (i, j+1) interval
        j += 1
    if j+1 == n:
        count_below += (n-i-1) * (n-i) / 2    # Add all pairs with larger i
        i = n
    else:
        i += 1
        if j < i:
            j = i - 1    # nothing fits from the old i; restart with an empty interval
        else:
            count_below += j - i + 1    # Add a block of (i, mid) pairs

The indexing is easy to get wrong, but that's the idea. The tricky bit is that every time we advance j we only add one, but every time we advance i we include every (i, mid) with i <= mid <= j. Those are all below X because they sit inside the previous interval, whose sum was already below X.
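Here is a runnable sketch of the counting step. It is not a line-by-line port of the pseudocode: it uses the common variant that fixes the right end j and slides the left end i, which counts the same set of intervals. It assumes a is a list of non-negative numbers, and count_below_brute is a hypothetical O(n^3) helper included only for cross-checking.

```python
def count_below(a, x):
    """Count contiguous subarrays of non-negative a whose sum is < x. O(n)."""
    count = 0
    window = 0  # sum of a[i..j]
    i = 0
    for j, value in enumerate(a):
        window += value
        # Slide the left end right until the window sum drops below x.
        while i <= j and window >= x:
            window -= a[i]
            i += 1
        # Every start in i..j gives a subarray ending at j with sum < x.
        count += j - i + 1
    return count

def count_below_brute(a, x):
    """Same count by checking every subarray; only for testing."""
    n = len(a)
    return sum(1 for i in range(n) for j in range(i, n)
               if sum(a[i:j + 1]) < x)
```

Each pointer only ever moves right, so the two loops together do O(n) work even though they are nested.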
And now we do binary search on the value.
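# Finds the k-th largest sum (k counted from 1): the smallest value v
# such that more than count_intervals - k sums are <= v.
lower = 0
upper = S
while lower < upper:
    mid = floor((upper + lower)/2)
    if count_below(mid + 1) <= count_intervals - k:
        lower = mid+1
    else:
        upper = mid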
Assuming that the sums are integers, this will find the correct answer in O(log(S)) searches, each of which is O(n), for a total time of O(n log(S)).
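Putting the pieces together, here is a hedged end-to-end sketch. It assumes non-negative integers and that the k-th largest sum is wanted, with k counted from 1; the function name kth_largest_sum is made up for illustration. It searches for the smallest value v such that more than count_intervals - k sums are <= v, which for integer sums is the same as testing the number of sums below v + 1.

```python
def count_below(a, x):
    """Count contiguous subarrays of non-negative a whose sum is < x. O(n)."""
    count = window = i = 0
    for j, value in enumerate(a):
        window += value
        while i <= j and window >= x:
            window -= a[i]
            i += 1
        count += j - i + 1
    return count

def kth_largest_sum(a, k):
    """k-th largest contiguous subarray sum of non-negative integers a."""
    n = len(a)
    count_intervals = n * (n + 1) // 2
    if not 1 <= k <= count_intervals:
        raise ValueError("k must be between 1 and n*(n+1)/2")
    lower, upper = 0, sum(a)
    while lower < upper:
        mid = (lower + upper) // 2
        # Sums <= mid are exactly the sums < mid + 1 (integer sums).
        if count_below(a, mid + 1) <= count_intervals - k:
            lower = mid + 1   # too few sums <= mid, answer is larger
        else:
            upper = mid       # mid is large enough, try smaller
    return lower

print([kth_largest_sum([1, 2, 3], k) for k in range(1, 7)])
```

For [1, 2, 3] the subarray sums are 1, 2, 3, 3, 5, 6, so the loop above lists them from largest to smallest.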
Note that if we're clever about the binary searching and keep track of both the count and the closest two sums, we can improve the time to O(n log(min(n, S))) by dropping upper to the max sum <= mid and raising lower to the next higher sum. Drop the floor and that approach will also work with floating point numbers to produce an answer in O(n log(n)) (with S being effectively infinite).