How to implement segment trees with lazy propagation?

后端 未结 5 1334
刺人心
刺人心 2021-01-30 05:25

I have searched the internet for implementations of segment trees, but found nothing when it came to lazy propagation. There were some previous questions on Stack Overflow, but th… [the rest of the question is truncated in the source]

5条回答
  •  深忆病人
    2021-01-30 06:27

    Although I haven't solved it yet, I believe this problem is much easier than it looks; you probably don't even need a segment tree or interval tree. In fact, I tried two ways of implementing a segment tree, one using a pointer-based tree structure and the other using an array, and both solutions quickly got TLE (Time Limit Exceeded). I have a feeling it could be solved greedily, but I'm not sure yet. Anyway, if you want to see how this is done with a segment tree, feel free to study my solution. Neither version pushes pending additions down to the children: each node keeps its own pending addition (`added` / `added_to_interval`) and folds it into its max/min, so the root always holds the true extremes. Note that max_tree[1] and min_tree[1] hold the overall maximum and minimum.

    #include <algorithm>
    #include <cassert>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    
    // getc_unlocked is POSIX; MSVC's CRT provides the equivalent _fgetc_nolock.
    // #ifdef accepts a single macro name, so "or" must be spelled with defined().
    #if defined(_WIN32) || defined(_WIN64)
    #define getc_unlocked _fgetc_nolock
    #endif
    
    using namespace std;
    
    // Largest number of elements the static arrays below can hold.
    const int MAX_RANGE = 1000000;
    // Sentinel value for "no node" in the array-backed tree. It cannot be written with
    // memset(), which fills single bytes (the low byte of NIL is 0); use std::fill.
    const int NIL = -(1 << 29);
    // Scratch array for the brute-force / common-sense solutions and the self-tests.
    // NOTE: with `using namespace std;`, an unqualified `data` is ambiguous with
    // std::data (C++17); refer to this array as `::data`.
    int data[MAX_RANGE] = {0};
    // Array-backed segment tree: node p has children 2p and 2p + 1, root is 1.
    int min_tree[3 * MAX_RANGE + 1];
    int max_tree[3 * MAX_RANGE + 1];
    int added_to_interval[3 * MAX_RANGE + 1];
    
    // One node of the pointer-based segment tree.
    struct node {
        // Extremes of the covered range; each already includes this node's own
        // `added` (but not any ancestor's).
        int max_value, min_value;
        // Amount added to the whole covered range at this node, never pushed to children.
        int added;
        // Children; both NULL for a leaf.
        node *left, *right;
    };
    
    // Builds a pointer-based segment tree over positions [l, r] and returns its root,
    // or NULL when l > r. Leaf k starts with values[k]; when `values` is NULL it falls
    // back to the default initial value k + 1. Every node starts with added = 0.
    // The caller owns the returned nodes (release them with clean_up()).
    node* build_tree(int l, int r, int values[]) {
        // Check before allocating so an empty range does not leak a node.
        if (l > r) {
            return NULL;
        }
    
        node *root = new node;
        root->added = 0;
    
        if (l == r) {
            const int initial = (values != NULL) ? values[l] : l + 1;
            root->max_value = initial;
            root->min_value = initial;
            root->left = NULL;
            root->right = NULL;
            return root;
        }
    
        const int mid = (l + r) / 2;
        root->left = build_tree(l, mid, values);
        root->right = build_tree(mid + 1, r, values);
        // l < r, so both halves are non-empty and both children exist.
        root->max_value = max(root->left->max_value, root->right->max_value);
        root->min_value = min(root->left->min_value, root->right->min_value);
        return root;
    }
    
    // Builds a pointer-based segment tree over positions [l, r] where leaf k starts
    // with the value k + 1, and returns its root (NULL when l > r). Every node starts
    // with added = 0. The caller owns the returned nodes (release them with clean_up()).
    node* build_tree(int l, int r) {
        // Check before allocating so an empty range does not leak a node.
        if (l > r) {
            return NULL;
        }
    
        node *root = new node;
        root->added = 0;
    
        if (l == r) {
            root->max_value = l + 1;
            root->min_value = l + 1;
            root->left = NULL;
            root->right = NULL;
            return root;
        }
    
        const int mid = (l + r) / 2;
        root->left = build_tree(l, mid);
        root->right = build_tree(mid + 1, r);
        // l < r, so both halves are non-empty and both children exist.
        root->max_value = max(root->left->max_value, root->right->max_value);
        root->min_value = min(root->left->min_value, root->right->min_value);
        return root;
    }
    
    // Adds `amount` to every element whose index lies in [i, j].
    // `root` covers [begin, end]; pass the same bounds that were used to build the tree.
    //
    // Lazy scheme without push-down: a fully covered node records the addition in its
    // own `added` and bumps its max/min; a partially covered node recurses and then
    // recomputes its extremes from its children plus its own `added`. Each node's
    // max/min therefore include its own pending additions (not its ancestors'), and
    // root->max_value / root->min_value are the true extremes of the whole array.
    void update_tree(node* root, int begin, int end, int i, int j, int amount) {
        // Missing node, empty range, or no overlap with [i, j]: nothing to do.
        if (root == NULL || begin > end || begin > j || end < i) {
            return;
        }
    
        // Fully inside [i, j]: apply lazily at this node only.
        if (i <= begin && end <= j) {
            root->max_value += amount;
            root->min_value += amount;
            root->added += amount;
            return;
        }
    
        // A leaf covers a single index, which is either inside or outside [i, j], so
        // this is only reachable with mismatched bounds. Its extremes must stay as they
        // are: re-adding `added` here would count the pending addition twice.
        if (root->left == NULL && root->right == NULL) {
            return;
        }
    
        const int mid = (begin + end) / 2;
        update_tree(root->left, begin, mid, i, j, amount);      // NULL-safe
        update_tree(root->right, mid + 1, end, i, j, amount);   // NULL-safe
    
        if (root->left != NULL && root->right != NULL) {
            root->max_value = max(root->left->max_value, root->right->max_value) + root->added;
            root->min_value = min(root->left->min_value, root->right->min_value) + root->added;
        }
        else {
            const node *child = (root->left != NULL) ? root->left : root->right;
            root->max_value = child->max_value + root->added;
            root->min_value = child->min_value + root->added;
        }
    }
    
    // In-order dump of the tree: left subtree, this node's "(max, min)", right subtree.
    // A NULL root prints nothing.
    void print_tree(node* root) {
        if (root == NULL) {
            return;
        }
    
        print_tree(root->left);
        cout << "\t(max, min): " << root->max_value << ", " << root->min_value << endl;
        print_tree(root->right);
    }
    
    // Deletes every node of the tree (children before parent) and resets the
    // caller's pointer to NULL. Safe to call on an empty tree.
    void clean_up(node*& root) {
        if (root == NULL) {
            return;
        }
    
        clean_up(root->left);
        clean_up(root->right);
        delete root;
        root = NULL;
    }
    
    // Reference implementation for the tests: adds z to data[x..y] (inclusive, 0-based),
    // then rescans data[0..n) and stores its minimum in `smallest` and its maximum in
    // `largest`. Both start from data[0], which is read even when n <= 0.
    void update_bruteforce(int x, int y, int z, int &smallest, int &largest, int data[], int n) {
        int pos = x;
        while (pos <= y) {
            data[pos++] += z;
        }
    
        smallest = data[0];
        largest = data[0];
        for (int k = 0; k < n; ++k) {
            const int v = data[k];
            smallest = (v < smallest) ? v : smallest;
            largest = (v > largest) ? v : largest;
        }
    }
    
    // Builds the array-backed segment tree: node `position` covers [left, right],
    // its children are 2 * position and 2 * position + 1. Leaf k starts with the value
    // k + 1. Every node built here gets added_to_interval = 0 explicitly, so the build
    // does not depend on the caller having cleared that array beforehand.
    void build_tree_as_array(int position, int left, int right) {
        if (left > right) {
            return;
        }
    
        added_to_interval[position] = 0;
    
        if (left == right) {
            max_tree[position] = left + 1;
            min_tree[position] = left + 1;
            return;
        }
    
        const int mid = (left + right) / 2;
        const int left_branch = 2 * position;
        const int right_branch = left_branch + 1;
        build_tree_as_array(left_branch, left, mid);
        build_tree_as_array(right_branch, mid + 1, right);
        max_tree[position] = max(max_tree[left_branch], max_tree[right_branch]);
        min_tree[position] = min(min_tree[left_branch], min_tree[right_branch]);
    }
    
    // Adds `value` to every element in [i, j] of the array-backed tree.
    // `position` is the current node (root = 1) and [b, e] is the range it covers.
    //
    // Same lazy scheme as the pointer tree, without push-down: a fully covered node
    // stores the addition in added_to_interval[position] and bumps its own max/min;
    // a partially covered node recurses and recomputes its extremes from its children
    // plus its own pending addition. max_tree[1] / min_tree[1] are the true extremes.
    void update_tree_as_array(int position, int b, int e, int i, int j, int value) {
        // Empty range or no overlap with [i, j]: nothing to do.
        if (b > e || b > j || e < i) {
            return;
        }
    
        // Fully covered: record the addition on this node only.
        if (i <= b && e <= j) {
            max_tree[position] += value;
            min_tree[position] += value;
            added_to_interval[position] += value;
            return;
        }
    
        // Partial overlap. A single-index range is always either fully covered or
        // disjoint, so here b < e and build_tree_as_array created both children.
        // (No artificial index cap: a cap below the array size made large trees stop
        // recursing early and report wrong extremes.)
        const int left_branch = 2 * position;
        const int right_branch = left_branch + 1;
        const int mid = (b + e) / 2;
        update_tree_as_array(left_branch, b, mid, i, j, value);
        update_tree_as_array(right_branch, mid + 1, e, i, j, value);
        max_tree[position] = max(max_tree[left_branch], max_tree[right_branch]) + added_to_interval[position];
        min_tree[position] = min(min_tree[left_branch], min_tree[right_branch]) + added_to_interval[position];
    }
    
    // Debug helper: prints a "[current data]" header, then data[0..n) on one line,
    // each value followed by ", ".
    void show_data(int data[], int n) {
        cout << "[current data]\n";
        int k = 0;
        while (k < n) {
            cout << data[k++] << ", ";
        }
        cout << endl;
    }
    
    // Fast integer reader for stdin. Skips whitespace and control characters
    // (anything below '!'), accepts an optional leading '-', then reads decimal
    // digits into *n. The character that ends the number is consumed.
    // At end of input *n is set to 0 instead of looping forever.
    inline void input(int* n) {
        // int, not char: getc_unlocked returns EOF (-1) at end of input, which a
        // char cannot represent reliably.
        int c = getc_unlocked(stdin);
        while (c != EOF && c < 33) {
            c = getc_unlocked(stdin);
        }
    
        *n = 0;
        if (c == EOF) {
            return;
        }
    
        bool negative = false;
        if (c == '-') {
            negative = true;
            c = getc_unlocked(stdin);
        }
    
        while (c >= '0' && c <= '9') {
            *n = (*n * 10) + (c - '0');
            c = getc_unlocked(stdin);
        }
    
        if (negative) {
            *n = -*n;
        }
    }
    
    // Consumes m operations (four integers each) without applying them and prints 0,
    // the answer when the range has a single element (its max equals its min).
    void handle_special_case(int m) {
        int ignored;
        for (int op = m; op > 0; --op) {
            for (int field = 0; field < 4; ++field) {
                input(&ignored);
            }
        }
        printf("0\n");
    }
    
    void find_largest_range_use_tree() {
        int n;
        int m;
        int type;
        int x;
        int y;
        int added_amount;
    
        input(&n);
        input(&m);
    
        if (n == 1) {
            handle_special_case(m);
            return;
        }
    
        node *root = build_tree(0, n - 1);
        for (int i = 0; i < m; ++i) {
            input(&type);
            input(&x);
            input(&y);
            input(&added_amount);
            if (type == 1) {    
                added_amount *= 1;
            }
            else {
                added_amount *= -1;
            }
    
            update_tree(root, 0, n - 1, x - 1, y - 1, added_amount);
        }
    
        printf("%d\n", root->max_value - root->min_value);
    }
    
    // Solves one test case with the array-backed tree. Reads n and m, then m
    // operations "type x y amount": type 1 adds `amount` to positions x..y (1-based),
    // any other type subtracts it. Prints max - min of the final values, which are
    // max_tree[1] and min_tree[1]. Position k starts with the value k.
    // NOTE(review): assumes n <= MAX_RANGE; larger n would overrun the global arrays.
    void find_largest_range_use_array() {
        int n;
        int m;
        input(&n);
        input(&m);
    
        // One element (or none): max == min, so the answer is 0.
        if (n <= 1) {
            handle_special_case(m);
            return;
        }
    
        // memset() fills bytes, so memset(tree, NIL, ...) stores 0 rather than NIL;
        // std::fill writes whole int values.
        const int tree_size = 3 * n + 1;
        std::fill(min_tree, min_tree + tree_size, NIL);
        std::fill(max_tree, max_tree + tree_size, NIL);
        std::fill(added_to_interval, added_to_interval + tree_size, 0);
        build_tree_as_array(1, 0, n - 1);
    
        for (int op = 0; op < m; ++op) {
            int type;
            int x;
            int y;
            int added_amount;
            input(&type);
            input(&x);
            input(&y);
            input(&added_amount);
    
            const int delta = (type == 1) ? added_amount : -added_amount;
            update_tree_as_array(1, 0, n - 1, x - 1, y - 1, delta);
        }
    
        printf("%d\n", max_tree[1] - min_tree[1]);
    }
    
    void update_slow(int x, int y, int value) {
        for (int i = x - 1; i < y; ++i) {
            data[i] += value;
        }
    }
    
    void find_largest_range_use_common_sense() {
        int n;
        int m;
        int type;
        int x;
        int y;
        int added_amount;
    
        input(&n);
        input(&m);
    
        if (n == 1) {
            handle_special_case(m);
            return;
        }
    
        memset(data, 0, sizeof(int) * n);
        for (int i = 0; i < m; ++i) {
            input(&type);
            input(&x);
            input(&y);
            input(&added_amount);
    
            if (type == 1) {    
                added_amount *= 1;
            }
            else {
                added_amount *= -1;
            }
    
            update_slow(x, y, added_amount);
        }
    
         // update min/max
        int smallest = data[0] + 1;
        int largest = data[0] + 1;
        for (int i = 1; i < n; ++i) {
            if (data[i] + i + 1 < smallest) {
                smallest = data[i] + i + 1;
            }
    
            if (data[i] + i + 1 > largest) {
                largest = data[i] + i + 1;
            }
        }
    
        printf("%d\n", largest - smallest); 
    }
    
    // Entry point for judge input: reads the number of test cases, then solves each
    // one with find_largest_range_use_common_sense().
    void inout_range_of_data() {
        int test_cases;
        input(&test_cases);
    
        for (; test_cases != 0; --test_cases) {
            find_largest_range_use_common_sense();
        }
    }
    
    // Manual self-checks for the segment-tree implementations. They print a lot of
    // diagnostics; test_against_brute_force and test_tree_as_array wait for Enter
    // (cin.get()) after every operation.
    // The global array is referred to as ::data throughout: with
    // `using namespace std;` an unqualified `data` is ambiguous with std::data (C++17).
    namespace unit_test {
        // Builds a pointer tree over the values 1..MAX_RANGE and prints every node in order.
        void test_build_tree() {
            for (int i = 0; i < MAX_RANGE; ++i) {
                ::data[i] = i + 1;
            }
    
            node *root = build_tree(0, MAX_RANGE - 1, ::data);
            print_tree(root);
            clean_up(root);  // release the tree instead of leaking it
        }
    
        // Applies 100 random range additions/subtractions over MAX_RANGE elements to both
        // the brute-force array and the pointer tree, asserting after each one that the
        // tree's root max/min equal the brute-force extremes.
        void test_against_brute_force() {
            // arrange
            int number_of_operations = 100;
            for (int i = 0; i < MAX_RANGE; ++i) {
                ::data[i] = i + 1;
            }
    
            node *root = build_tree(0, MAX_RANGE - 1, ::data);
    
            // act
            int operation;
            int x;
            int y;
            int added_amount;
            int smallest = 1;
            int largest = MAX_RANGE;
    
            // assert
            while (number_of_operations--) {
                operation = rand() % 2;
                x = 1 + rand() % MAX_RANGE;
                y = x + (rand() % (MAX_RANGE - x + 1));
                added_amount = 1 + rand() % MAX_RANGE;
                // operation 1 adds, operation 0 subtracts
                if (operation != 1) {
                    added_amount = -added_amount;
                }
    
                update_bruteforce(x - 1, y - 1, added_amount, smallest, largest, ::data, MAX_RANGE);
                update_tree(root, 0, MAX_RANGE - 1, x - 1, y - 1, added_amount);
                assert(largest == root->max_value);
                assert(smallest == root->min_value);
                for (int k = 0; k < MAX_RANGE; ++k) {
                    cout << ::data[k] << ", ";
                }
                cout << endl << endl;
                cout << "correct:\n";
                cout << "\t largest = " << largest << endl;
                cout << "\t smallest = " << smallest << endl;
                cout << "testing:\n";
                cout << "\t largest = " << root->max_value << endl;
                cout << "\t smallest = " << root->min_value << endl;
                cout << "testing:\n";
                cout << "\n------------------------------------------------------------\n";
                cout << "final result: " << largest - smallest << endl;
                cin.get();
            }
    
            clean_up(root);
        }
    
        // For every n in 1..10000, applies 100 random operations to a fresh pointer tree
        // and to the brute-force array, asserting that the extremes agree after each one.
        void test_automation() {
            // arrange
            int test_cases;
            int number_of_operations = 100;
            int n;
    
            test_cases = 10000;
            for (int i = 0; i < test_cases; ++i) {
                n = i + 1;
    
                int operation;
                int x;
                int y;
                int added_amount;
                int smallest = 1;
                int largest = n;
    
                // initialize data for brute-force
                for (int k = 0; k < n; ++k) {
                    ::data[k] = k + 1;
                }
    
                // build tree
                node *root = build_tree(0, n - 1, ::data);
                for (int op = 0; op < number_of_operations; ++op) {
                    operation = rand() % 2;
                    x = 1 + rand() % n;
                    y = x + (rand() % (n - x + 1));
                    added_amount = 1 + rand() % n;
                    // operation 1 adds, operation 0 subtracts
                    if (operation != 1) {
                        added_amount = -added_amount;
                    }
    
                    update_bruteforce(x - 1, y - 1, added_amount, smallest, largest, ::data, n);
                    update_tree(root, 0, n - 1, x - 1, y - 1, added_amount);
                    assert(largest == root->max_value);
                    assert(smallest == root->min_value);
    
                    cout << endl << endl;
                    cout << "For n = " << n << endl;
                    cout << ", where data is : \n";
                    for (int k = 0; k < n; ++k) {
                        cout << ::data[k] << ", ";
                    }
                    cout << endl;
                    cout << " and query is " << x - 1 << ", " << y - 1 << ", " << added_amount << endl;
                    cout << "correct:\n";
                    cout << "\t largest = " << largest << endl;
                    cout << "\t smallest = " << smallest << endl;
                    cout << "testing:\n";
                    cout << "\t largest = " << root->max_value << endl;
                    cout << "\t smallest = " << root->min_value << endl;
                    cout << "\n------------------------------------------------------------\n";
                    cout << "final result: " << largest - smallest << endl;
                }
    
                clean_up(root);
            }
    
            cout << "DONE............\n";
        }
    
        // Runs random operations against the array-backed tree with n = MAX_RANGE and
        // prints the brute-force and tree extremes side by side (asserts are disabled).
        void test_tree_as_array() {
            // arrange
            int test_cases;
            int number_of_operations = 100;
            int n;
            test_cases = 1000;
            for (int i = 0; i < test_cases; ++i) {
                n = MAX_RANGE;
                // memset() fills bytes and cannot store NIL; std::fill_n writes whole ints.
                std::fill_n(min_tree, sizeof(min_tree) / sizeof(min_tree[0]), NIL);
                std::fill_n(max_tree, sizeof(max_tree) / sizeof(max_tree[0]), NIL);
                memset(added_to_interval, 0, sizeof(added_to_interval));
                memset(::data, 0, sizeof(::data));
    
                int operation;
                int x;
                int y;
                int added_amount;
                int smallest = 1;
                int largest = n;
    
                // initialize data for brute-force
                for (int k = 0; k < n; ++k) {
                    ::data[k] = k + 1;
                }
    
                // build tree using array
                build_tree_as_array(1, 0, n - 1);
                for (int op = 0; op < number_of_operations; ++op) {
                    operation = rand() % 2;
                    x = 1 + rand() % n;
                    y = x + (rand() % (n - x + 1));
                    added_amount = 1 + rand() % n;
                    // operation 1 adds, operation 0 subtracts
                    if (operation != 1) {
                        added_amount = -added_amount;
                    }
    
                    update_bruteforce(x - 1, y - 1, added_amount, smallest, largest, ::data, n);
                    update_tree_as_array(1, 0, n - 1, x - 1, y - 1, added_amount);
                    //assert(max_tree[1] == largest);
                    //assert(min_tree[1] == smallest);
    
                    cout << endl << endl;
                    cout << "For n = " << n << endl;
                    // show_data(::data, n);
                    cout << endl;
                    cout << " and query is " << x - 1 << ", " << y - 1 << ", " << added_amount << endl;
                    cout << "correct:\n";
                    cout << "\t largest = " << largest << endl;
                    cout << "\t smallest = " << smallest << endl;
                    cout << "testing:\n";
                    cout << "\t largest = " << max_tree[1] << endl;
                    cout << "\t smallest = " << min_tree[1] << endl;
                    cout << "\n------------------------------------------------------------\n";
                    cout << "final result: " << largest - smallest << endl;
                    cin.get();
                }
            }
    
            cout << "DONE............\n";
        }
    }
    
    // Solves the judge input. To exercise the tree implementations instead, call one of:
    //   unit_test::test_against_brute_force();
    //   unit_test::test_automation();
    //   unit_test::test_tree_as_array();
    int main() {
        inout_range_of_data();
    }
    

提交回复
热议问题