学习过线段树之后,应该会觉得线段树在各种维护问题上代码量比较大,而且比较麻烦。主要原因就是因为线段树把每个大区间都分成两个小区间,直到分成单独点。但是在实际操作的时候,很多申请的区间节点都是用不上的,造成了空间的浪费,那么如何解决这一问题呢?
先引入前缀和的概念:
前缀和:对于某一数组a[n],其中前缀和数组s[n]定义为s0=0,si=a[1]+…+a[i](1<=i<=n)。即a数组的前i(1<=i<=n)项和叫做该数组的前缀和。容易知道,a数组中的任意区间和都可以通过该前缀树组中的元素相减得到
有了前缀和,我们可以简化下列线段树:
怎么简化呢?我们把该线段树上的所有的右子树去掉,只保留所有的左子树,这样之后,右子树本该保留的值就拿父节点保留的值减去左子树保留的值可以得到
如上图,蓝色部分是线段树删去的保存的叶子节点,箭头表示它保存在上面的父区间中。现在的结构就叫做树状数组,下图是更直观的树状数组图:
我们可能会想,为什么要去掉右子树,去掉所有的左子树不行吗?在解释这个问题之前,我们先看下面这个函数
Lowbit()函数
Lowbit(x):将x在二进制分解下最低位的大小,通俗来讲就是找到从右向左第一个1所在位次k,其对应的大小为2k
例如Lowbit(9),9的2进制为1001,显然最低位的大小为20=1
下面给出一个结论来求Lowbit(x),已知x的二进制表示,我们把其二进制按位取反,接着再加1。应该都想到了——求x的负数的二进制补码。然后有个很神奇的操作,就是我们把x和-x的二进制按位与,除了最低位,其余每一位一定都是相反的,也就是与之后结果为0,这样得到的结果就是最低位的大小了
Lowbit(x) = x&(-x);
仍然以上面的9做例子:
1001
& 0111
--------
0001
为了简化代码,以后的Lowbit()函数直接一行代码搞定:
#define lowbit(x) (x&(-x))
树状数组
对于上述树状数组t,我们可以看出:
t[1] = a[1];
t[2] = a[1] + a[2];
t[3] = a[3];
t[4] = a[1] + a[2] + a[3] + a[4];
t[5] = a[5];
t[6] = a[5] + a[6];
t[7] = a[7];
t[8] = a[1] + a[2] + a[3] + a[4] + a[5] + a[6] + a[7] + a[8];
然后我们可以发现如下规律:
性质一:t[i]保存的区间长度(子节点个数)为lowbit(i)
性质二:t[i] = a[i-2k+1] + a[i-2k+2] + … + a[i]。2k为lowbit(i)
性质三:除根节点外,每个子节点t(i)的父节点是t(i+lowbit(i) )
性质四:每个树状数组t[i]保存的是区间[i-lowbit(i)+1,i]
性质五:前缀和si = t[i] + t[i-2k1] + t[(i - 2k1) - 2k2] + … +t[0]。其中2k1=lowbit(i),2k2=lowbit(i-2k1),…直到2kn=i,此外i也在一直更新
性质六:树的深度为O(logn)
有了上面的规律,我们很明显知道为什么刚刚我们去掉所有的右子树而不是左子树了。利用lowbit函数我们把树状数组和二进制联系起来,这样的话维护树状数组更加得高效简便。树状数组原理上是用来维护前缀和的,但是引入差分之后也能维护区间
初始化
所谓初始化,也就是输入a数组后更新t数组。下面介绍两种更新方法:
方法一:由上述性质二,当输入完a数组后我们还要用一个for循环去初始化t数组,时间复杂度为O(n),因此这个方法略麻烦
void init(int i){
int k=i;
i=k-lowbit(k)+1;
for(int j=1;i<=k;i=k-lowbit(k)+j){
t[k]+=a[i]; j++;
}
}
方法二:由上述性质三,假设我们已知某叶子节点a[i],我们可以一直往上追溯其父节点直到到达根节点,到达的条件是i等于区间长度n,时间复杂度为O(logn),建议采用这种方法
void update(int i,int k){
while(i<=n){
t[i]+=k;
i+=lowbit(i);
}
}
下面展示两种初始化的具体区别:
#define lowbit(x) (x&(-x))
const int N=1e5+10;
int a[N],t[N];
int n;
//方法一
void init(int i){
int k=i;
i=k-lowbit(k)+1;
for(int j=1;i<=k;i=k-lowbit(k)+j){
t[k]+=a[i]; j++;
}
}
//方法二
void update(int i,int k){
while(i<=n){
t[i]+=k;
i+=lowbit(i);
}
}
int main()
{
cin>>n;
for(int i=1;i<=n;i++){
cin>>a[i];
update(i,a[i]);//方法二
}
for(int i=1;i<=n;i++) init(i); //方法一
//for(int i=1;i<=n;i++) cout<<t[i]<<" ";
return 0;
}
单点修改+区间查询
前缀和
所谓区间查询,也就是查询出区间的两个前缀和再相减即可
由性质五,我们只要一直更新i=i-lowbit(i)直到i=0,如下图所示,就能不断到达当前区间之后的最大区间
ll getSum(int i){
ll ans=0;
for(;i;i-=lowbit(i))
ans+=t[i];
return ans;
}
如果我们要查询区间[x,y]的和,就getSum(y)-getSum(x-1)即可
单点修改
容易发现我们上面的初始化方法二就是使用单点修改的思想
void update(int i,int k){ //第i个节点加上k
while(i<=n){
t[i]+=k;
i+=lowbit(i);
}
}
区间更新+单点查询
首先我们引入差分的概念:
差分即相邻两个数的差,由a数组我们能得到a的差分数组d[i]=a[i]-a[i-1],还可以得到二者之间的关系:
a[i]=d[1]+…+d[i]
那么我们会发现,如果对一个区间[x,y]内的所有数都执行加法,那么显然只有d[x]和d[y+1]的值会改变,[x+1,y]区间的值都不变
因此我们用d数组代替上面的t数组维护树状数组,当我们进行区间加法时,很明显只用上面的update函数更新d[x]和d[y+1],即d[x]+k,d[y+1]-k
因此区间更新为:
#define lowbit(x) (x&(-x))
const int N=?;
int a[N],d[N];
int n;
void update(int i,int k){ //初始化和区间更新的函数
while(i<=n){
d[i]+=k;
i+=lowbit(i);
}
}
ll ask(int i){ //求a[i]
ll ans=0;
for(;i;i-=lowbit(i))
ans+=d[i];
return ans;
}
int main(){
memset(d,0,sizeof(d)); //多样例输入不要忘记清空d数组
//初始化
for(int i=1;i<=){
cin>>a[i];
update(i,a[i]-a[i-1]);
}
//执行[x,y]区间加k
update(x,k);
update(y+1,-k);
ask(x); //求a[x]
}
区间更新+区间查询
当我们使用差分构造树状数组后,区间查询即求前缀和相减
由a[i]=d[1]+…+d[i],得
前缀和=∑a[i]=a[1]+…+a[n]
=(d[1])+(d[1]+d[2])+…+(d[1]+d[2]+…+d[n])
=n*d[1]+(n-1)*d[2]+…+2*d[n-1]+d[n]
=(n+1)*(d[1]+d[2]+…+d[n])-(d[1]+2*d[2]+…+n*d[n])
=(n+1)*∑d[i] - ∑i*d[i]
于是我们需要维护两个树状数组,分别是d[i]和c[i]=i*d[i]
初始化及区间更新函数:
下面的先用t保存i是因为树状数组结构,每个点的上面有很多父区间,我们将父区间更新时也传入刚开始的i*d[i],因此需要暂时保存i
//区间更新的话只需要更新两个点
void update(int i,int k){
int t=i; //由于i下面变化但是外面乘的i不变
while(i<=n){
d[i]+=k;
c[i]+=t*k;
i+=lowbit(i);
}
}
求前缀和函数:
ll getSum(int i){
ll ans=0;
int t=i+1;
for(;i;i-=lowbit(i)){
ans+=t*d[i]-c[i];
}
return ans;
}
代码示例:
//如果数据很大,就把两个函数传入的变量都设置为long long
typedef long long ll;
#define lowbit(x) (x&(-x))
const int N=1e5+10;
ll a[N],d[N],c[N];
ll n;
void update(int i,int k){
int t=i;
while(i<=n){
d[i]+=k;
c[i]+=t*k;
i+=lowbit(i);
}
}
ll getSum(int i){ //求前缀和
ll ans=0;
int t=i+1;
for(;i;i-=lowbit(i)){
ans+=t*d[i]-c[i];
}
return ans;
}
int main(){
//初始化
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
update(i,a[i]-a[i-1]);
}
//[x.y]区间每个数加上k
update(x,k);
update(y+1,-k);
//求[x,y]区间和
printf("%lld\n",getSum(y)-getSum(x-1));
return 0;
}
来源:CSDN
作者:Happig丶
链接:https://blog.csdn.net/qq_44691917/article/details/104134671