How to find all ordered pairs of elements in an array of integers whose sum lies in a given range of values

抹茶落季 2021-02-05 21:58

Given an array of integers, find the number of ordered pairs of elements in the array whose sum lies in a given range [a, b].

I have an O(n^2) solution for this.
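A straightforward O(n^2) baseline tests every ordered pair of distinct indices. The sketch below is one way to write it in Python. The function name `count_pairs_brute_force` is hypothetical and this is not the asker's code. It treats "ordered pairs" as index pairs (i, j) with i != j, so (i, j) and (j, i) are counted separately.

```python
def count_pairs_brute_force(arr, a, b):
    """Count ordered pairs of distinct indices (i, j) with a <= arr[i] + arr[j] <= b.

    Hypothetical O(n^2) baseline: every ordered pair is checked once.
    """
    n = len(arr)
    count = 0
    for i in range(n):
        for j in range(n):
            # skip self-pairs; keep sums inside the closed range [a, b]
            if i != j and a <= arr[i] + arr[j] <= b:
                count += 1
    return count


print(count_pairs_brute_force([-20, 1, 3, 4, 8, 11], 6, 9))  # 4
```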

8 answers
  •  抹茶落季
    2021-02-05 21:59

    The problem of counting the pairs that work can be done in sort time + O(N). This is faster than the solution that Ani gives, which is sort time + O(N log N). The idea goes like this. First you sort. You then run nearly the same single pass algorithm twice. You then can use the results of the two single pass algorithms to calculate the answer.

    The first time we run the single pass algorithm, we create a new array that lists, for each index, the smallest index it can be paired with to give a sum of at least a. Example:

    a = 6
    array = [-20, 1, 3, 4, 8, 11]
    output = [6, 4, 2, 2, 1, 1]
    

    So, the number at array index 1 is 1 (0-based indexing). The smallest number it can pair with to reach at least 6 is the 8, which is at index 4. Hence output[1] = 4. -20 can't pair with anything, so output[0] = 6 (out of bounds). Another example: output[4] = 1, because 8 (index 4) can pair with the 1 (index 1), or any number after it, to reach at least 6.
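    The definition of output can be cross-checked with binary search. Because the array is sorted, the smallest partner index for array[i] is the first position holding a value of at least a - array[i], and Python's `bisect.bisect_left` returns exactly that position. This is only a check of the definition and costs O(N log N) rather than O(N). The helper name `lower_partner_bounds` is hypothetical.

```python
from bisect import bisect_left


def lower_partner_bounds(array, a):
    """For each i, the smallest index j with array[i] + array[j] >= a (len(array) if none).

    Hypothetical helper; assumes `array` is sorted in ascending order.
    """
    # array[i] + array[j] >= a  <=>  array[j] >= a - array[i]
    return [bisect_left(array, a - x) for x in array]


print(lower_partner_bounds([-20, 1, 3, 4, 8, 11], 6))  # [6, 4, 2, 2, 1, 1]
```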

    What you need to do now is convince yourself that this is O(N). It is. The code is:

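    output = [0] * len(array)
    i, j = 0, len(array) - 1  # two pointers at the ends of the sorted array
    while i <= j:
        if array[i] + array[j] >= a:
            output[j] = i      # i is the smallest partner index for j
            j -= 1
        else:
            output[i] = j + 1  # every partner for i lies after j
            i += 1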
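    One way to confirm the linear bound is to instrument the loop and count its iterations. The sketch below wraps the first pass in a function. The name `first_pass` and the `steps` counter are illustrative additions and not part of the answer. For a sorted input of length N, the function returns output together with the number of iterations, which is always N.

```python
def first_pass(array, a):
    """Two-pointer first pass; returns (output, number of loop iterations).

    Assumes `array` is sorted ascending. `steps` is instrumentation only.
    """
    output = [0] * len(array)
    i, j = 0, len(array) - 1
    steps = 0
    while i <= j:
        steps += 1  # each iteration shrinks the window [i, j] by exactly one
        if array[i] + array[j] >= a:
            output[j] = i      # i is the smallest partner index for j
            j -= 1
        else:
            output[i] = j + 1  # every partner for i lies after j
            i += 1
    return output, steps


output, steps = first_pass([-20, 1, 3, 4, 8, 11], 6)
print(output, steps)  # [6, 4, 2, 2, 1, 1] 6
```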
    

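    Just think of two pointers starting at the two ends and working inward: each step moves one pointer by one position, so the loop runs exactly N times. It's O(N). You now do the same thing with the condition array[i] + array[j] <= b, this time recording the largest index each element can be paired with: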

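    output2 = [0] * len(array)
    i, j = 0, len(array) - 1  # reset both pointers for the second pass
    while i <= j:
        if array[i] + array[j] <= b:
            output2[i] = j      # j is the largest partner index for i
            i += 1
        else:
            output2[j] = i - 1  # every partner for j lies before i
            j -= 1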
    

    In our example, this code gives you (array and b for reference):

    b = 9
    array = [-20, 1, 3, 4, 8, 11]
    output2 = [5, 4, 3, 3, 1, 0]
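    output2 can be cross-checked the same way. output2[i] is the last position holding a value of at most b - array[i], which is `bisect.bisect_right(...) - 1`, or -1 when there is no such position. As before, this O(N log N) check uses a hypothetical helper name, `upper_partner_bounds`.

```python
from bisect import bisect_right


def upper_partner_bounds(array, b):
    """For each i, the largest index j with array[i] + array[j] <= b (-1 if none).

    Hypothetical helper; assumes `array` is sorted in ascending order.
    """
    # array[i] + array[j] <= b  <=>  array[j] <= b - array[i]
    return [bisect_right(array, b - x) - 1 for x in array]


print(upper_partner_bounds([-20, 1, 3, 4, 8, 11], 9))  # [5, 4, 3, 3, 1, 0]
```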
    

    But now, output and output2 contain all the information we need, because together they give the range of valid partner indices: output[i] is the smallest index i can be paired with, and output2[i] is the largest. The difference + 1 is the number of pairings for that location. So for the first location (corresponding to -20), there are 5 - 6 + 1 = 0 pairings. (As long as a <= b, output[i] <= output2[i] + 1, so this count is never negative.) For 1, there are 4 - 4 + 1 = 1 pairing, with the number at index 4, which is 8. One more subtlety: this algorithm counts self-pairings, so if you don't want them, you have to subtract them. E.g. 3 appears to have 3 - 2 + 1 = 2 pairings, one at index 2 and one at index 3. Of course, 3 itself is at index 2, so one of those is the self-pairing and the other is the pairing with 4. You just need to subtract one whenever the range from output[i] to output2[i] contains the index i itself. In code, you can write:

    answer = [o2 - o + 1 - (o <= i <= o2) for i, (o, o2) in enumerate(zip(output, output2))]
    

    Which yields:

    answer = [0, 1, 1, 1, 1, 0]
    

    Which sums to 4, corresponding to the ordered pairs (1,8), (3,4), (4,3), (8,1).
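    To double-check those four pairs, the sketch below enumerates every ordered pair of distinct indices in the sorted example and keeps the ones whose sum lies in [6, 9]. The self-pairs (3,3) and (4,4) have sums inside the range, but they are excluded by the i != j test.

```python
array = [-20, 1, 3, 4, 8, 11]
a, b = 6, 9

# Ordered pairs of values at distinct indices whose sum lies in [a, b].
pairs = [
    (array[i], array[j])
    for i in range(len(array))
    for j in range(len(array))
    if i != j and a <= array[i] + array[j] <= b
]
print(pairs)  # [(1, 8), (3, 4), (4, 3), (8, 1)]
```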

    Anyhow, as you can see, this is sort + O(N), which is optimal.

    Edit: someone asked for a full implementation, so here it is for reference:

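    def count_ranged_pairs(x, a, b):
        """Count unordered pairs of distinct elements of x whose sum lies in [a, b] (assumes a <= b)."""
        x = sorted(x)  # sort a copy so the caller's list is not modified

        output = [0] * len(x)   # output[i]: smallest index j with x[i] + x[j] >= a
        output2 = [0] * len(x)  # output2[i]: largest index j with x[i] + x[j] <= b

        # First pass: lower bounds.
        i, j = 0, len(x) - 1
        while i <= j:
            if x[i] + x[j] >= a:
                output[j] = i
                j -= 1
            else:
                output[i] = j + 1
                i += 1

        # Second pass: upper bounds.
        i, j = 0, len(x) - 1
        while i <= j:
            if x[i] + x[j] <= b:
                output2[i] = j
                i += 1
            else:
                output2[j] = i - 1
                j -= 1

        # Partners of i are the indices output[i]..output2[i]; drop i itself if it falls in that range.
        answer = [o2 - o + 1 - (o <= i <= o2) for i, (o, o2) in enumerate(zip(output, output2))]
        # sum(answer) counts every pair in both orders; integer-divide by 2 for unordered pairs.
        return sum(answer) // 2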
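    The two-pointer bookkeeping is easy to get subtly wrong, so a randomized comparison against brute force is a useful sanity check. The sketch below is self-contained and includes a copy of the function that uses integer division. It compares the function against a count built from `itertools.permutations` on small random inputs with a <= b. The helper name `count_pairs_by_enumeration` and the test parameters are assumptions chosen for illustration.

```python
import random
from itertools import permutations


def count_ranged_pairs(x, a, b):
    # Copy of the answer's algorithm: unordered pairs of distinct positions, sum in [a, b].
    x = sorted(x)
    n = len(x)
    output = [0] * n   # smallest partner index with sum >= a
    output2 = [0] * n  # largest partner index with sum <= b

    i, j = 0, n - 1
    while i <= j:
        if x[i] + x[j] >= a:
            output[j] = i
            j -= 1
        else:
            output[i] = j + 1
            i += 1

    i, j = 0, n - 1
    while i <= j:
        if x[i] + x[j] <= b:
            output2[i] = j
            i += 1
        else:
            output2[j] = i - 1
            j -= 1

    answer = [o2 - o + 1 - (o <= k <= o2) for k, (o, o2) in enumerate(zip(output, output2))]
    return sum(answer) // 2


def count_pairs_by_enumeration(x, a, b):
    # permutations(x, 2) yields ordered pairs from distinct positions; halve for unordered.
    return sum(1 for p, q in permutations(x, 2) if a <= p + q <= b) // 2


rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(1000):
    x = [rng.randint(-10, 10) for _ in range(rng.randint(0, 8))]
    a = rng.randint(-15, 15)
    b = rng.randint(a, 20)  # the algorithm assumes a <= b
    assert count_ranged_pairs(x, a, b) == count_pairs_by_enumeration(x, a, b)
print("all random tests passed")
```

    Keeping the inputs small makes the O(n^2) enumeration cheap. Allowing repeated values exercises the duplicate-element cases where off-by-one errors tend to appear.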
    
