sum of absolute differences of a number in an array

野的像风 2021-02-04 05:56

I want to calculate the sum of absolute differences of the number at index i with all the numbers up to index i-1, in O(n). But I am not able to think of any approach better than O(n^2).
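For reference, here is a minimal sketch of the straightforward quadratic approach the question describes. The question shows no code, so the function name `abs_diff_sums_naive` is hypothetical. It computes, for every index i, the sum of |a[i] - a[j]| over all j < i:

```python
def abs_diff_sums_naive(a):
    """For each i, return sum(|a[i] - a[j]| for j < i). O(n^2) time overall."""
    return [sum(abs(a[i] - a[j]) for j in range(i)) for i in range(len(a))]


print(abs_diff_sums_naive([3, 5, 6, 7, 1]))  # [0, 2, 4, 7, 17]
```

This is also handy as a brute-force reference when testing a faster version.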

2 answers
  •  深忆病人
    2021-02-04 06:39

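    I can offer an O(n log n) solution for a start: let f_i be the i-th number of the result, i.e. the sum of |a_i - a_j| over all j < i. We have:

    $$ f_i = \sum_{j<i} |a_i - a_j| = \sum_{j<i,\ a_j \le a_i} (a_i - a_j) + \sum_{j<i,\ a_j > a_i} (a_j - a_i) $$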
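    Writing f_i out in terms of counts and sums shows why only those two kinds of queries are needed. The notation here is mine, not the answer's, and indices are 0-based, so there are exactly i earlier elements: c_i is the number of earlier elements with a_j <= a_i, s_i is their sum, and S_i is the sum of all earlier elements. The elements <= a_i contribute c_i a_i - s_i, the remaining i - c_i elements contribute (S_i - s_i) - (i - c_i) a_i, and collecting terms gives the expression that the `res.append(...)` line in the implementation evaluates:

```latex
f_i = \sum_{j<i} |a_i - a_j|
    = (c_i a_i - s_i) + \bigl((S_i - s_i) - (i - c_i)\, a_i\bigr)
    = S_i - 2 s_i + a_i (2 c_i - i)
```

    For example, with a = [3, 5, 6, 7, 1] and i = 4: S_4 = 21 and c_4 = s_4 = 0, so f_4 = 21 + 1 * (0 - 4) = 17, which matches |1-3| + |1-5| + |1-6| + |1-7| = 2 + 4 + 5 + 6.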

    Walking through the array from left to right while maintaining a binary search tree of the elements a_0 to a_{i-1}, we can evaluate every part of the formula in O(log n):

    • Keep subtree sizes to count the elements larger than/smaller than a given one
    • Keep cumulative subtree sums to answer the sum queries for elements larger than/smaller than a given one
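    Below is a minimal sketch of the augmented-tree idea, under the simplifying assumption of a plain unbalanced BST. Each operation therefore costs O(height), which is O(log n) on random-looking input but O(n) on sorted input. The O(log n) guarantee would need a balanced tree (treap, AVL, red-black) that keeps the same `size`/`total` fields up to date through rotations. All names (`Node`, `insert`, `count_and_sum_le`, `solve_bst`) are hypothetical:

```python
class Node:
    """BST node augmented with the size and key-sum of its subtree."""
    __slots__ = ("key", "left", "right", "size", "total")

    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.size = 1      # number of keys in this subtree
        self.total = key   # sum of keys in this subtree


def insert(root, key):
    """Insert key (equal keys go right) and return the root. No rebalancing."""
    if root is None:
        return Node(key)
    node = root
    while True:
        # every node on the search path gains one key in its subtree
        node.size += 1
        node.total += key
        if key < node.key:
            if node.left is None:
                node.left = Node(key)
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = Node(key)
                return root
            node = node.right


def count_and_sum_le(root, x):
    """Return (count, sum) of keys <= x, following a single root-to-leaf path."""
    count, total = 0, 0
    node = root
    while node is not None:
        if node.key <= x:
            # this node and its whole left subtree (keys < node.key) are <= x
            left = node.left
            count += 1 + (left.size if left else 0)
            total += node.key + (left.total if left else 0)
            node = node.right
        else:
            # this node and its right subtree are > x
            node = node.left
    return count, total


def solve_bst(a):
    root = None
    prefix_total = 0  # sum of a[0..i-1]
    res = []
    for i, x in enumerate(a):
        c, s = count_and_sum_le(root, x)
        # (x*c - s) from elements <= x, plus ((prefix_total - s) - (i - c)*x) from the rest
        res.append(prefix_total - 2 * s + x * (2 * c - i))
        root = insert(root, x)
        prefix_total += x
    return res


print(solve_bst([3, 5, 6, 7, 1]))  # [0, 2, 4, 7, 17]
```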

    We can replace the augmented search tree with some simpler data structures if we want to avoid the implementation cost:

    • Sort the array beforehand. Assign every number its rank in the sorted order
    • Keep a binary indexed tree of 0/1 values to calculate the number of elements smaller than a given value
    • Keep another binary indexed tree of the array values to calculate the sums of elements smaller than a given value

    TBH I don't think this can be solved in O(n) in the general case. At the very least you would need to sort the numbers at some point. But if the numbers are bounded, or you have some other restriction, you might be able to implement the sum and count operations in O(1).
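    As one illustration of the bounded case, here is a minimal sketch that assumes every value is an integer with 0 <= x < k for some small bound k. Both the bound and the name `solve_bounded` are hypothetical. Per-value count and sum arrays replace the trees. Each query scans at most k slots, so the total is O(nk). That is O(n), with O(1) per query, only if k is treated as a constant:

```python
def solve_bounded(a, k):
    """Assumes 0 <= x < k for every x in a, where k is a small fixed bound."""
    cnt = [0] * k   # cnt[v]: number of earlier elements equal to v
    tot = [0] * k   # tot[v]: sum of earlier elements equal to v
    prefix_total = 0  # sum of a[0..i-1]
    res = []
    for i, x in enumerate(a):
        # O(k) scan of values 0..x; constant per element when k is fixed
        c = sum(cnt[: x + 1])
        s = sum(tot[: x + 1])
        res.append(prefix_total - 2 * s + x * (2 * c - i))
        cnt[x] += 1
        tot[x] += x
        prefix_total += x
    return res


print(solve_bounded([3, 5, 6, 7, 1], k=8))  # [0, 2, 4, 7, 17]
```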

    An implementation of the Fenwick-tree approach for the general case:

    # binary-indexed tree, allows point updates and prefix sum queries
    class Fenwick:
      def __init__(self, n):
        self.tree = [0]*(n+1)
        self.n = n
      def update_point(self, i, val):  # O(log n)
        i += 1
        while i <= self.n:
          self.tree[i] += val
          i += i & -i
      def read_prefix(self, i):        # O(log n)
        i += 1
        total = 0  # avoid shadowing the built-in sum()
        while i > 0:
          total += self.tree[i]
          i -= i & -i
        return total
    
    def solve(a):
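      # rank of each value in sorted order; equal values share the rank of their last copy
      rank = { v : i for i, v in enumerate(sorted(a)) }
      res = []
      counts, sums = Fenwick(len(a)), Fenwick(len(a))
      total_sum = 0  # sum of a[0..i-1]
      for i, x in enumerate(a):
        r = rank[x]
        # count and sum of the earlier elements that are <= x
        num_smaller = counts.read_prefix(r)
        sum_smaller = sums.read_prefix(r)
        # (x*num_smaller - sum_smaller) from elements <= x, plus
        # ((total_sum - sum_smaller) - (i - num_smaller)*x) from elements > x
        res.append(total_sum - 2*sum_smaller + x * (2*num_smaller - i))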
        counts.update_point(r, 1)
        sums.update_point(r, x)
        total_sum += x
      return res
    
    print(solve([3,5,6,7,1]))  # [0, 2, 4, 7, 17]
    print(solve([2,0,1]))      # [0, 2, 2]
    
