树状数组的整理

依然范特西╮ 提交于 2020-02-27 02:53:43

* 如m = 11000, 则C[m] = C[10100] + C[10110] + C[10111] + A[11000];
    则S[m] = C[11000] + C[10000];
 
1.区间求和
    向上更新每一个父节点,向下统计每一个子节点之和;
 
2.查询单点
    向上更新区间(update(l,1) /*以左端点为起点++*/,update(r+1,-1)/*以右端点为起--*/),向下统计子节点之和;
    反过来,向下更新向上统计也可以;
 
* 1.树状数组的起点从1开始,到最大值结束,因此终点n不是个数而是最大值;
    2.空间复杂度为N,即数组大小为N;
 
代码模版:
void update(int pos,int val)
{
   while(pos <= n)
    {
        c[pos] += val;
        pos += lowbit(pos);
    }
}

int sum(int end)
{
    int ret = 0;
    while(end > 0)
    {
        ret += c[end];
        end -= lowbit(end);
    }
    return ret;
}
 
3.求逆序数
    
  对于一个序列求每个数前面比它大或小的数的个数的总和,将数字离散化得到大小关系用树状数组求和;
  /*离散化:当数据只与它们之间的相对大小有关,而与具体是多少无关时,可以进行离散化。*/
代码模版:
       for(int i=0;i<n;i++)
        {
            scanf("%d",&arr[i].val);
            arr[i].pos = i;
        }
        sort(arr,arr+n,cmp);
        for(int i=0;i<n;i++)
            reflect[arr[i].pos] = i;
        for(int i=0;i<n;i++)
        {
            ans += sum(++reflect[i]);
            update(reflect[i]);
        }

 

4.二维树状数组:
    原理与一维相同,将数组变为二维数组;
 
代码模版:
void update(int x,int y,int val)
{
    for(int i = x; i <= n; i += lowbit(i) )
        for(int j = y; j <= n; j += lowbit(j) )
            c[i][j] += val;
}
 
int sum(int x,int y)
{
    int ret = 0;
    for(int i = x; i <= n; i -= lowbit(i) )
        for(int j = y; j <= n; j -= lowbit(j) )
            ret += c[i][j];
    return ret;
}

 

 

树状数组相关题集:

 

HDU1556
输入N,有1-N的数,输入N个区间,每次把区间内的数+1;
输出每个数的值;
 
/* 就是更新区间,查询单点 */
 
代码:
const int maxn = 1e5+7;
int c[maxn],n;

void add(int end,int val)
{
    while(end > 0)
    {
        c[end] += v;
        end -= lowbit(end);
    }
}

int sum(int x)
{
    int ret = 0;
    while(x <= n)
    {
        ret += c[x];
        x += lowbit(x);
    }
    return ret;
}

int main()
{
    while(scanf("%d",&n),n)
    {
        memset(c,0,sizeof(c));
        int l,r;
        for(int i=0;i<n;i++)
        {
            scanf("%d%d",&l,&r);
            add(r,1);
            add(l-1,-1);
        }
        for(int i=1;i<n;i++)
            printf("%d ",sum(i));
        printf("%d\n",sum(n));
    }
}

 

POJ 2352
输入N,输入N个点的x和y,输入按y1 == y2 ? X1 < x2 : y1 < y2的顺序;
输出对于每个点,在它左下方的点的个数;
 
/* 输入按照y的从小到大的顺序,那么后输入的数必然在前输入的数的右边或上面,因此可以将所有的点全都投影到X轴上,只需要统计在此之前输入的点中在当前点左边的点的个数*/
 
代码:
const int maxn = 32005;
int c[maxn],level[maxn],n;

void add(int pos)
{
    while(pos <= maxn)
    {
        c[pos]++;
        pos += lowbit(pos);
    }
}

int sum(int end)
{
    int ret = 0;
    while(end > 0)
    {
        ret += c[end];
        end -= lowbit(end);
    }
    return ret;
}

int main()
{
    scanf("%d",&n);
    int x,y;
    for(int i=0;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        level[sum(++x)]++;
        add(x);
    }
    for(int i=0;i<n;i++)
        printf("%d\n",level[i]);
}

 

POJ2155
对于一个N*N的零矩阵,输入N和操作次数T,对于每次操作,若输入为C x1 y1 x2 y2,表示将该子矩阵范围内的数反转,若输入为Q x1,y1,表示输出点(x1,y1)的值
 
/*一个二维的树状数组,每次将区间内的数+1,查询单点时输出数取mod2*/
 
代码:
const int maxn = 1007;
int c[maxn][maxn],n;

void update(int x,int y)
{
    for(int i=x;i<=n;i+=lowbit(i))
        for(int j=y;j<=n;j+=lowbit(j))
            c[i][j]++;
}

int sum(int x,int y)
{
    int ret = 0;
    for(int i=x;i>0;i-=lowbit(i))
        for(int j=y;j>0;j-=lowbit(j))
            ret += c[i][j];
    return ret;
}

int main()
{
    int t;
    scanf("%d",&t);
    while(t--)
    {
        memset(c,0,sizeof(c));
        int m,x1,x2,y1,y2;
        char C[2];
        scanf("%d%d",&n,&m);
        for(int i=0;i<m;i++)
        {
            scanf("%s",C);
            if(C[0] == 'C')
            {
                scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
                update(++x2,++y2);
                update(x1,y1);
                update(x2,y1);
                update(x1,y2);
            }
            else
            {
                scanf("%d%d",&x1,&y1);
                printf("%d\n",sum(x1,y1)%2);
            }
        }
        printf("\n");
    }
}

 

POJ2299
输入N,输入N个数;
求这个序列的逆序数;
 
/*数据范围有1e9不可以直接保存在数组里,但是只要得到序列的相对大小,所以离散化后用树状数组求和*/
 
代码:
const int maxn = 5e5+7;
struct node{
    int pos,val;
}arr[maxn];
int n,c[maxn],reflect[maxn];

bool cmp(node a,node b)
{
    return a.val > b.val;
}

void update(int pos)
{
    while(pos <= n)
    {
        c[pos]++;
        pos += lowbit(pos);
    }
}

ll sum(int end)
{
    ll ret = 0;
    while(end > 0)
    {
        ret += c[end];
        end -= lowbit(end);
    }
    return ret;
}

int main()
{
    while(scanf("%d",&n),n)
    {
        ll ans = 0;
        memset(c,0,sizeof(c));
        for(int i=0;i<n;i++)
        {
            scanf("%d",&arr[i].val);
            arr[i].pos = i;
        }
        sort(arr,arr+n,cmp);
        for(int i=0;i<n;i++)
            reflect[arr[i].pos] = i;
        for(int i=0;i<n;i++)
        {
            ans += sum(++reflect[i]);
            update(reflect[i]);
        }
        cout << ans << endl;
    }
}

 

POJ3067

输入N,M,K,左边有1-N的点,右边有1-M的点,有K条线段连接左右的点,每次输入线段的左端点和右端点;

输出每两条线段的交点的总个数和;

 

/*排序之后用树状数组计算每个点前出现的比当前点小的点的个数,用i-当前个数加到总和里*/

 

代码:

const int maxn = 1e6+7;
struct node{
    int l,r;
}a[maxn];
int n,m,k,c[1007];

bool cmp(node x,node y)
{
    return x.l == y.l ? x.r < y.r : x.l < y.l;
}

void update(int pos)
{
    while(pos <= m)
    {
        c[pos] ++;
        pos += lowbit(pos);
    }
}

ll sum(int end)
{
    ll ret = 0;
    while(end > 0)
    {
        ret += c[end];
        end -= lowbit(end);
    }
    return ret;
}

int main()
{
    int T;
    scanf("%d",&T);
    for(int t=1;t<=T;t++)
    {
        ll ans = 0;
        memset(c,0,sizeof(c));
        scanf("%d%d%d",&n,&m,&k);
        for(int i=0;i<k;i++)
            scanf("%d%d",&a[i].l,&a[i].r);
        sort(a,a+k,cmp);
        for(int i=0;i<k;i++)
        {
            ans += i-sum(a[i].r);
            update(a[i].r);
        }
        printf("Test case %d: %lld\n",t,ans);
    }
}

 

POJ1195

 

1 X Y A 表示把(X,Y)的值+A,2 L B R T表示输出(L,B)(R,T)范围内的数的和;
 
/*二维的树状数组求和*/
 
代码:
const int maxn = 1100;
int c[maxn][maxn],s;

void update(int x,int y,int val)
{
    for(int i=x;i<=s;i+=lowbit(i))
        for(int j=y;j<=s;j+=lowbit(j))
            c[i][j] += val;
}

ll sum(int x,int y)
{
    ll ret = 0;
    for(int i=x;i>0;i-=lowbit(i))
        for(int j=y;j>0;j-=lowbit(j))
            ret += c[i][j];
    return ret;
}

int main()
{
    int n,x1,y1,x2,y2,a;
    while(scanf("%d",&n))
    {
        if(n == 3) return 0;
        else if(n == 0)
            scanf("%d",&s);
        else if(n == 1)
        {
            scanf("%d%d%d",&x1,&y1,&a);
            update(++x1,++y1,a);
        }
        else if(n == 2)
        {
            scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
            ll ans = 0;
            ans += sum(++x2,++y2) + sum(x1,y1) - sum(x1,y2) - sum(x2,y1);
            printf("%lld\n",ans);
        }
    }
}

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!