Dynamic prefix sum

匆匆过客 提交于 2020-01-02 05:37:05

问题


Is there any data structure which is able to return the prefix sum [1] of array, update an element, and insert/remove elements to the array, all in O(log n)?

[1] "prefix sum" is the sum of all elements from the first one up to given index

For example, given the array of non-negative integers 8 1 10 7 the prefix sum for first three elements is 19 (8 + 1 + 10). Updating the first element to 7, inserting 3 as the second element and removing the third one gives 7 3 10 7. Again, the prefix sum of first three elements would be 20.

For prefix sum and update, there is Fenwick tree. But I don't know how to handle the addition/removal in O(log n) with it.

On the other hand, there are several binary search trees such as Red-black tree, all of which handle the update/insert/remove in logarithmic time. But I don't know how to maintain the given ordering and do the prefix sum in O(log n).


回答1:


A treap with implicit keys can perform all of these operations in O(log n) time per query. The idea of implicit keys is pretty simple: we do not store any keys in the nodes. Instead, we maintain the subtree size for every node and use this information to find the appropriate position when we add or remove an element.

Here is my implementation:

#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

// A single element of the implicit-key treap.
//
// No explicit key is stored: an element's position in the sequence is implied
// by the subtree sizes. Each node caches aggregates over its own subtree.
struct Node {
  int priority;                 // random heap priority, drawn once at construction
  long long val;                // the element's value; wide enough for any `long`
                                // passed to the constructor, so it never disagrees
                                // with `sum` because of truncation
  long long sum;                // sum of `val` over this node's subtree
  int size;                     // number of nodes in this node's subtree
  std::shared_ptr<Node> left;   // left child (empty when absent)
  std::shared_ptr<Node> right;  // right child (empty when absent)

  // Builds a leaf holding `val`: subtree size 1, subtree sum equal to `val`,
  // no children.
  Node(long val)
      : priority(std::rand()), val(val), sum(val), size(1), left(), right() {}
};

// Returns the cached subtree size of the node owned by t, or 0 when t is empty.
// The pointer is taken by const reference because the function only reads the
// node; passing the shared_ptr by value would copy it on every call.
int getSize(const std::shared_ptr<Node>& t) {
  return t ? t->size : 0;
}

// Returns the cached subtree sum of the node owned by t, or 0 when t is empty.
// The pointer is taken by const reference because the function only reads the
// node; passing the shared_ptr by value would copy it on every call.
long long getSum(const std::shared_ptr<Node>& t) {
  return t ? t->sum : 0;
}


// Recomputes the cached size and sum of the node owned by t from its own value
// and its children's cached aggregates. Does nothing when t is empty.
// The pointer itself is never reseated, so it is taken by const reference to
// avoid copying the shared_ptr; the pointed-to node is still modified.
void update(const std::shared_ptr<Node>& t) {
  if (!t)
    return;
  t->size = 1 + getSize(t->left) + getSize(t->right);
  t->sum = t->val + getSum(t->left) + getSum(t->right);
}

// Concatenates two treaps: every element of L comes before every element of R
// in the result. The root with the strictly larger priority stays on top; on a
// tie the root of R is chosen. Returns the root of the combined treap.
std::shared_ptr<Node> merge(std::shared_ptr<Node> L, 
    std::shared_ptr<Node> R) {
  // If either side is empty, the other side is already the answer.
  if (!L)
    return R;
  if (!R)
    return L;

  const bool rightOnTop = !(L->priority > R->priority);
  if (rightOnTop) {
    // L goes entirely to the left of R's root, in front of R's left subtree.
    R->left = merge(L, R->left);
    update(R);
    return R;
  }

  // R goes entirely to the right of L's root, behind L's right subtree.
  L->right = merge(L->right, R);
  update(L);
  return L;
}

// Splits a subtree rooted in t by pos. 
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> split(
    std::shared_ptr<Node> t,
    int pos, int add) {
  if (!t)
    return make_pair(std::shared_ptr<Node>(), std::shared_ptr<Node>());
  int cur = getSize(t->left) + add;
  std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> res;
  if (pos <= cur) {
    auto ret = split(t->left, pos, add);
    t->left = ret.second;
    res = make_pair(ret.first, t); 
  } else {
    auto ret = split(t->right, pos, cur + 1);
    t->right = ret.first;
    res = make_pair(t, ret.second); 
  }
  update(t);
  return res;
}

// Returns a prefix sum of [0 ... pos]
long long getPrefixSum(std::shared_ptr<Node>& root, int pos) {
  auto parts = split(root, pos + 1, 0);
  long long res = getSum(parts.first);
  root = merge(parts.first, parts.second);
  return res;
}

// Adds a new element at a position pos with a value newValue.
// Indices are zero-based.
void addElement(std::shared_ptr<Node>& root, int pos, int newValue) {
  auto parts = split(root, pos, 0);
  std::shared_ptr<Node> newNode = std::make_shared<Node>(newValue);
  auto temp = merge(parts.first, newNode);
  root = merge(temp, parts.second);
}

// Removes the element at zero-based index pos.
// A position outside the sequence removes nothing. Without the explicit guard
// a negative pos would split off an empty front part and then drop the first
// element, whereas a pos past the end was already a no-op.
void removeElement(std::shared_ptr<Node>& root, int pos) {
  if (pos < 0)
    return;
  // parts1.first  = elements before pos, parts1.second = elements from pos on.
  auto parts1 = split(root, pos, 0);
  // parts2.first  = the single element at pos (or empty), parts2.second = the rest.
  auto parts2 = split(parts1.second, 1, 0);
  // Re-join everything except the detached element.
  root = merge(parts1.first, parts2.second);
}

// Command-line driver.
// Input: an integer n followed by n commands, one of
//   add <pos> <val>   - insert val at zero-based index pos
//   remove <pos>      - remove the element at zero-based index pos
//   <other> <pos>     - any other word prints the prefix sum of [0 ... pos]
// Processing stops at the first token that cannot be read, so truncated or
// malformed input never leads to acting on unread (uninitialised) operands.
int main() {
  std::shared_ptr<Node> root;
  int n = 0;
  if (!(std::cin >> n))
    return 0;
  for (int i = 0; i < n; i++) {
    std::string s;
    if (!(std::cin >> s))
      break;
    if (s == "add") {
      int pos = 0;
      int val = 0;
      if (!(std::cin >> pos >> val))
        break;
      addElement(root, pos, val);
    } else if (s == "remove") {
      int pos = 0;
      if (!(std::cin >> pos))
        break;
      removeElement(root, pos);
    } else {
      int pos = 0;
      if (!(std::cin >> pos))
        break;
      // std::endl flushes, so each answer is emitted as soon as it is computed.
      std::cout << getPrefixSum(root, pos) << std::endl;
    }
  }
  return 0;
}



回答2:


An idea: to modify an AVL tree. Additions and deletions are done by index. Every node keeps the count and the sum of each subtree to allow all operations in O(log n).

Proof-of-concept with add_node and update_node and prefix_sum implemented:

class Node:
    """One element of an index-ordered balanced tree.

    Besides its value and children, each node caches per-side aggregates:
    the "left" aggregates cover the left subtree plus the node itself, the
    "right" aggregates cover the right subtree only.
    """

    def __init__(self, value):
        self.value = value
        self.left = self.right = None
        self.left_height = self.right_height = 0
        # A fresh leaf: the "left" side consists of the node itself.
        self.left_count, self.left_sum = 1, value
        self.right_count, self.right_sum = 0, 0

    def set_value(self, value):
        """Replace the value and refresh left_sum, which includes this node."""
        self.value = value
        child = self.left
        if child:
            self.left_sum = child.left_sum + child.right_sum + self.value
        else:
            self.left_sum = self.value

    def set_left(self, node):
        """Attach node (or None) as left child and recompute the left aggregates."""
        self.left = node
        if node:
            self.left_height = max(node.left_height, node.right_height) + 1
            self.left_count = node.left_count + node.right_count + 1
            self.left_sum = node.left_sum + node.right_sum + self.value
        else:
            self.left_height = 0
            self.left_count = 1
            self.left_sum = self.value

    def set_right(self, node):
        """Attach node (or None) as right child and recompute the right aggregates."""
        self.right = node
        if node:
            self.right_height = max(node.left_height, node.right_height) + 1
            self.right_count = node.left_count + node.right_count
            self.right_sum = node.left_sum + node.right_sum
        else:
            self.right_height = 0
            self.right_count = 0
            self.right_sum = 0

    def rotate_left(self):
        """Make the right child the new subtree root and return it."""
        pivot = self.right
        self.set_right(pivot.left)
        pivot.set_left(self)
        return pivot

    def rotate_right(self):
        """Make the left child the new subtree root and return it."""
        pivot = self.left
        self.set_left(pivot.right)
        pivot.set_right(self)
        return pivot

    def factor(self):
        """Balance factor: right height minus left height."""
        return self.right_height - self.left_height

def add_node(root, index, node):
    """Insert node so that it ends up at position index; return the new subtree root.

    The subtree is rebalanced with single or double rotations whenever the
    balance factor of root leaves the range [-1, 1] after the insertion.
    """
    if root is None:
        return node

    if index < root.left_count:
        # The new node belongs before root: descend into the left subtree.
        root.set_left(add_node(root.left, index, node))
        if root.factor() >= -1:
            return root
        child = root.left
        if child.factor() > 0:
            # Left child leans right: straighten it first (double rotation).
            root.set_left(child.rotate_left())
        return root.rotate_right()

    # Otherwise it belongs after root; skip the elements counted on the left side.
    root.set_right(add_node(root.right, index - root.left_count, node))
    if root.factor() <= 1:
        return root
    child = root.right
    if child.factor() < 0:
        # Right child leans left: straighten it first (double rotation).
        root.set_right(child.rotate_right())
    return root.rotate_left()

def update_node(root, index, value):
    """Set the element at position index to value; return the subtree root.

    The aggregates on the path from root to the changed node are refreshed by
    re-attaching each visited child. An empty subtree is returned unchanged.
    """
    if root is None:
        return None

    rank = index + 1
    if rank == root.left_count:
        # root itself is the element at this position.
        root.set_value(value)
    elif rank < root.left_count:
        root.set_left(update_node(root.left, index, value))
    else:
        root.set_right(update_node(root.right, index - root.left_count, value))
    return root


def prefix_sum(root, index):
    """Return the sum of the elements at positions 0..index (inclusive).

    An empty subtree contributes 0.
    """
    if root is None:
        return 0

    covered = root.left_count
    if index + 1 >= covered:
        # The whole left side (left subtree plus root) lies inside the prefix;
        # the remainder of the prefix, if any, is in the right subtree.
        tail = prefix_sum(root.right, index - covered)
        return root.left_sum + tail
    return prefix_sum(root.left, index)


import random  # not used by the demo below

# Demo: insert by index to build [10, 20, 70, 40], then change index 2 to 30.
tree = None
tree = add_node(tree, 0, Node(10))
tree = add_node(tree, 1, Node(40))
tree = add_node(tree, 1, Node(20))
tree = add_node(tree, 2, Node(70))

tree = update_node(tree, 2, 30)

# print(...) with a single argument behaves the same on Python 2 and Python 3,
# whereas the bare `print x` statement is a syntax error on Python 3.
print(prefix_sum(tree, 0))
print(prefix_sum(tree, 1))
print(prefix_sum(tree, 2))
print(prefix_sum(tree, 3))
# Index 4 is past the last element.
print(prefix_sum(tree, 4))


来源:https://stackoverflow.com/questions/27990143/dynamic-prefix-sum

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!