hdu3911 线段树的区间更新

╄→尐↘猪︶ㄣ 提交于 2020-01-05 03:18:18

  题目链接如下:http://acm.hdu.edu.cn/showproblem.php?pid=3911   大意是给你一个01串, 以及两种操作, 第一种操作是询问区间内连续的1的个数, 第二种操作是翻转一个区间内的0 和 1, 我们直接在维护一个区间左端开始连续的0 和 1的个数, 右端开始连续的0 和 1的个数, 以及当前区间连续的1 和 0的最大数量, 即可 , 在pushup更新的时候应该注意, 当前区间的最大的连续的1的数量 = 左子树最大的1的数量 右子树最大的1的数量 以及中间的最大的连续的1的数量 (这里wa了好久), 在query的时候应该注意控制区间长度。 代码如下:

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
const int maxn = 100000 + 100;
int n;
int a[maxn];
struct Segment{
    int lsum0, lsum1, rsum0, rsum1;
    int msum0, msum1;
    int l, r;
    int ck;   //翻转标记
}tree[3*maxn];

void push_up(int rt){   //回溯的时候利用儿子结点信息更新父亲
    int ll = tree[rt<<1].r - tree[rt<<1].l + 1;//求左子树的线段长度
    int rl = tree[rt<<1|1].r - tree[rt<<1|1].l + 1;//求右子树的线段长度


    int chl = 2*rt, chr = 2*rt+1;
//    int tpmsum0 = max(tree[chl].lsum0, tree[chr].rsum0);
//    tree[rt].msum0 = max(tpmsum0, tree[chl].rsum0+tree[chr].lsum0);   //这种写法是错误的 注意
    tree[rt].msum0 = max( (tree[rt<<1].rsum0 + tree[rt<<1|1].lsum0) , max(tree[rt<<1].msum0,tree[rt<<1|1].msum0));

    int tpmsum1 = max(tree[chl].msum1, tree[chr].msum1);
    tree[rt].msum1 = max(tpmsum1, tree[chl].rsum1+tree[chr].lsum1);
//    tree[rt].msum1 = max((tree[rt<<1].rsum1 + tree[rt<<1|1].lsum1) , max(tree[rt<<1].msum1,tree[rt<<1|1].msum1));

    tree[rt].lsum0 = tree[chl].lsum0;
    if(tree[chl].lsum0 == tree[chl].r-tree[chl].l+1) tree[rt].lsum0 += tree[chr].lsum0;

    tree[rt].lsum1 = tree[chl].lsum1;
    if(tree[chl].lsum1 == tree[chl].r-tree[chl].l+1) tree[rt].lsum1 += tree[chr].lsum1;

    tree[rt].rsum0 = tree[chr].rsum0;
    if(tree[chr].rsum0 == tree[chr].r-tree[chr].l+1) tree[rt].rsum0 += tree[chl].rsum0;

    tree[rt].rsum1 = tree[chr].rsum1;
    if(tree[chr].rsum1 == tree[chr].r-tree[chr].l+1) tree[rt].rsum1 += tree[chl].rsum1; //我的内心是崩溃的
}

void build(int rt, int l, int r){
    tree[rt].l = l; tree[rt].r = r;
    tree[rt].ck = 0;
    if(l == r){
        if(a[l] == 0){
            tree[rt].lsum0 = tree[rt].rsum0 = 1;
            tree[rt].lsum1 = tree[rt].rsum1 = 0;
            tree[rt].msum0 = 1; tree[rt].msum1 = 0;
        }else{
            tree[rt].lsum0 = tree[rt].rsum0 = 0;
            tree[rt].lsum1 = tree[rt].rsum1 = 1;
            tree[rt].msum0 = 0; tree[rt].msum1 = 1;
        }
        return ;
    }
    int mid = (l+r)/2;
    build(2*rt, l, mid);
    build(2*rt+1, mid+1, r);
    push_up(rt);            //回溯时利用儿子结点的信息更新父亲结点
}

void push_down(int rt){
    if(tree[rt].ck){
        int chl = 2*rt, chr = 2*rt+1;
        swap(tree[chl].lsum0, tree[chl].lsum1);
        swap(tree[chl].rsum0, tree[chl].rsum1);
        swap(tree[chl].msum0, tree[chl].msum1);
//        tree[chl].ck ^= 1;
        tree[chl].ck = !tree[chl].ck;
        swap(tree[chr].lsum0, tree[chr].lsum1);
        swap(tree[chr].rsum0, tree[chr].rsum1);
        swap(tree[chr].msum0, tree[chr].msum1);
//        tree[chr].ck ^= 1;
        tree[chr].ck = !tree[chr].ck;
//        tree[rt].ck ^= 1;
        //tree[rt].ck = !tree[rt].ck;
        tree[rt].ck = 0;
    }
}

void update(int rt, int l, int r){   //l - r都翻转 0 -> 1 1 -> 0
    if(tree[rt].l==l && tree[rt].r==r){
        swap(tree[rt].lsum0, tree[rt].lsum1);
        swap(tree[rt].rsum0, tree[rt].rsum1);
        swap(tree[rt].msum0, tree[rt].msum1);
//        tree[rt].ck ^= 1;
        tree[rt].ck = !tree[rt].ck;
        return ;
    }
    push_down(rt);
    int mid = (tree[rt].l + tree[rt].r) / 2;
    if(r <= mid)
        update(2*rt, l, r);
    else if(l > mid)
        update(2*rt+1, l, r);
    else{
        update(2*rt, l, mid);
        update(2*rt+1, mid+1, r);
    }
    push_up(rt);
}

int query(int rt, int l, int r){   //查询 l - r的连续的1的个数
    if(tree[rt].l==l && tree[rt].r==r)
        return tree[rt].msum1;
    push_down(rt);
    int res;
    int mid = (tree[rt].l + tree[rt].r)/2;
    if(r <= mid)
        res =  query(2*rt, l, r);
    else if(l > mid)
        res =  query(2*rt+1, l, r);
    else{
        int v1 = query(2*rt, l, mid);
        int v2 = query(2*rt+1, mid+1, r);
        int v3 = min(tree[2*rt].rsum1, mid-l+1) + min(tree[2*rt+1].lsum1, r-mid-1+1);
        int v4 = max(v1, v2);
        res = max(v4, v3);
    }
    push_up(rt);
    return res;
}

int main() {
    while(scanf("%d", &n) != EOF){
        for(int i=1; i<=n; i++)
            scanf("%d", &a[i]);
        build(1, 1, n);
        int m;
        scanf("%d", &m);
        for(int i=0; i<m; i++){
            int x, l, r;
            scanf("%d%d%d", &x, &l, &r);
            if(x == 0)
                printf("%d\n", query(1, l, r));
            else
                update(1, l, r);
        }
    }
    return 0;
}

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!