@(简述树链剖分)
题目链接:luogu P3384 【模板】树链剖分
先上完整代码,变量名解释1
#include<cstdio> #include<algorithm> #include<iostream> using namespace std; typedef long long ll; #define N 500005 #define RI register int int tot=0,n,m,rt,md; int fa[ N ],deep[ N ],head[ N ],size[ N ],son[ N ],id[ N ],w[ N ],nw[ N ],top[ N ]; struct EDGE{ int to,next; }e[ N ]; inline void add( int from , int to ){ e[ ++ tot ].to = to; e[ tot ].next = head[ from ]; head[ from ] = tot; } template<class T> inline void read(T &res){ static char ch;T flag = 1; while( ( ch = getchar() ) < '0' || ch > '9' ) if( ch == '-' ) flag = -1; res = ch - 48; while( ( ch = getchar() ) >= '0' && ch <= '9' ) res = res * 10 + ch - 48; res *= flag; } struct NODE{ ll sum,flag; NODE *ls,*rs; NODE(){ sum = flag = 0; ls = rs = NULL; } inline void pushdown( int l , int r ) { if( flag ) { int midd = ( l + r ) >> 1; ls->flag += flag; rs->flag += flag; ls->sum += flag * ( midd - l + 1 ); rs->sum += flag * ( r - midd ); flag = 0; } } inline void update() { sum = ls->sum + rs->sum; } }tree[ N * 2 + 5 ],*p = tree,*root; NODE *build( int l , int r ) { NODE *nd = ++p; if( l == r ) { nd->sum = nw[ l ]; return nd; } int mid = ( l + r ) >> 1; nd->ls = build( l , mid ); nd->rs = build( mid + 1 , r ); nd->update(); return nd; } ll sum( int l , int r , int x , int y , NODE *nd ) { if( x <= l && r <= y ) { return nd->sum; } nd->pushdown( l , r ); int mid = ( l + r ) >> 1; ll res = 0; if( x <= mid ) res += sum( l , mid , x , y , nd->ls ); if( y >= mid + 1 ) res += sum( mid + 1 , r , x , y , nd->rs ); return res; } void modify( int l , int r , int x , int y , ll add , NODE *nd ) { if( x <= l && r <= y ) { nd->sum += ( r - l + 1 ) * add; nd->flag += add; return; } int mid = ( l + r ) >> 1; nd->pushdown( l , r ); if( x <= mid ) modify( l , mid , x , y , add , nd->ls ); if( y > mid ) modify( mid + 1 , r , x , y , add , nd->rs ); nd->update(); } void dfs1( int p ){ size[ p ] = 1; deep[ p ] = deep[ fa[ p ] ] + 1; for( int i = head[ p ] ; i ; i = e[ i ].next ){ int k = e[ i ].to; if( k == fa[ p ] ) continue; fa[ k ] = p; dfs1( k ); size[ p ] += size[ k ]; if( size[ son[ p ] ] < size[ k ] || !son[ p ] ) son[ p ] = k; } } void dfs2( int p , int tp ){ id[ p ] = ++tot; nw[ tot ] = w[ p ]; top[ p ] = tp; if( son[ p ] ) dfs2( son[ p ] , tp ); for( int i = head[ p ] ; i ; i = e[ i ].next ){ int k = e[ i ].to; if( k == fa[ p ] || k == son[ p ] ) continue; dfs2( k , k ); } } inline void ope1( int x , int y , ll add ){ while( top[ x ] != top[ y ] ){ if( deep[ top[ x ] ] < deep[ top[ y ] ] ) swap( x , y ); modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root ); x = fa[ top[ x ] ]; } if( deep[ x ] > deep[ y ] ) swap( x , y ); modify( 1 , n , id[ x ] , id[ y ] , add , root ); } inline ll ope2( int x , int y ){ ll res = 0; while( top[ x ] != top[ y ] ){ if( deep[ top[ x ] ] < deep[ top[ y ] ] ) swap( x , y ); res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root ); x = fa[ top[ x ] ]; } if( deep[ x ] > deep[ y ] ) swap( x , y ); res += sum( 1 , n , id[ x ] , id[ y ] , root ); return res; } inline void ope3( int x , int add ){ modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root ); } inline ll ope4( int x ){ return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root ); } int main() { cin>>n>>m>>rt>>md; for( RI i = 1 ; i <= n ; i ++ ) read( w[ i ] ); for( RI i = 1 ; i <= n - 1 ; i ++ ){ int x,y; read( x ),read( y ); add( x , y ); add( y , x ); } dfs1( rt ),tot = 0; dfs2( rt , rt ); root = build( 1 , n ); for( RI i = 1 ; i <= m ; i ++ ){ int f; read( f ); switch( f ){ case 1:{ int x,y; ll add; read( x ),read( y ),read( add ); ope1( x , y , add ); break; } case 2:{ int x,y; read( x ),read( y ); printf( "%lld\n" , ope2( x , y ) % md ); break; } case 3:{ int x; ll add; read( x ),read( add ); ope3( x , add ); break; } case 4:{ int x; read( x ); printf( "%lld\n" , ope4( x ) % md ); break; } } } return 0; }
前置知识
请先能够熟练写出线段树并了解\(dfs\)序的性质
预处理
预处理分两次\(dfs\)
第一次处理出各个结点的深度,\(size\),重儿子,父亲。
第二次处理出重链,\(dfs\)序和每个点的\(top\)。
dfs1:
void dfs1( int p ){ size[ p ] = 1; deep[ p ] = deep[ fa[ p ] ] + 1; for( int i = head[ p ] ; i ; i = e[ i ].next ){ int k = e[ i ].to; if( k == fa[ p ] ) continue; fa[ k ] = p; dfs1( k ); size[ p ] += size[ k ]; if( size[ son[ p ] ] < size[ k ] || !son[ p ] ) son[ p ] = k; } }
dfs2:
void dfs2( int p , int tp ){ id[ p ] = ++tot;//每个点在dfs序里的位置 nw[ tot ] = w[ p ]; top[ p ] = tp; if( son[ p ] ) dfs2( son[ p ] , tp );//重链 for( int i = head[ p ] ; i ; i = e[ i ].next ){ int k = e[ i ].to; if( k == fa[ p ] || k == son[ p ] ) continue; dfs2( k , k );//轻链 } }
维护
为了更加高效的查询,我们选择用线段树来维护\(dfs\)序(树状数组等数据结构也可)。
没什么技术含量,直接套模板即可。
struct NODE{ ll sum,flag; NODE *ls,*rs; NODE(){ sum = flag = 0; ls = rs = NULL; } inline void pushdown( int l , int r ) { if( flag ) { int midd = ( l + r ) >> 1; ls->flag += flag; rs->flag += flag; ls->sum += flag * ( midd - l + 1 ); rs->sum += flag * ( r - midd ); flag = 0; } } inline void update() { sum = ls->sum + rs->sum; } }tree[ N * 2 + 5 ],*p = tree,*root; NODE *build( int l , int r ) { NODE *nd = ++p; if( l == r ) { nd->sum = nw[ l ]; return nd; } int mid = ( l + r ) >> 1; nd->ls = build( l , mid ); nd->rs = build( mid + 1 , r ); nd->update(); return nd; } ll sum( int l , int r , int x , int y , NODE *nd ) { if( x <= l && r <= y ) { return nd->sum; } nd->pushdown( l , r ); int mid = ( l + r ) >> 1; ll res = 0; if( x <= mid ) res += sum( l , mid , x , y , nd->ls ); if( y >= mid + 1 ) res += sum( mid + 1 , r , x , y , nd->rs ); return res; } void modify( int l , int r , int x , int y , ll add , NODE *nd ) { if( x <= l && r <= y ) { nd->sum += ( r - l + 1 ) * add; nd->flag += add; return; } int mid = ( l + r ) >> 1; nd->pushdown( l , r ); if( x <= mid ) modify( l , mid , x , y , add , nd->ls ); if( y > mid ) modify( mid + 1 , r , x , y , add , nd->rs ); nd->update(); }
查询
这是核心操作(敲黑板)。
子树有关操作
子树查询
由于\(dfs\)序的性质,以一个点为根的子树在\(dfs\)序中一定是连续的,所以我们只需要进行一次区间查询,需要查询的区间为:
[根结点在\(dfs\)序中的位置,根结点在\(dfs\)序中的位置+\(size\) - 1 ]
复杂度为\(O(logn)\)
代码如下:
inline ll ope4( int x ){ return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root ); }
子树修改
同理,进行一次区间修改
复杂度为\(O(logn)\)
代码如下:
inline void ope3( int x , int add ){ modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root ); }
树链有关操作
这才是树剖的精髓所在啊!(战术后仰 )
这里主要会利用重链在\(dfs\)序中一定是连续的性质,一定要记住,否则你将无法理解接下来的操作
链查询
操作流程:
- 若两个点的top不同,则让top较深的点爬升到它的top的father,每次爬升进行一次区间查询2,把结果加到res上,直到top相等为止
- 此时两点的top为原来两点的LCA,且其中深度较浅的点就是LCA,再进行一次区间查询即可。
最坏时间复杂度\(O(log_{2}n)\)
代码如下:
inline ll ope2( int x , int y ){ ll res = 0; while( top[ x ] != top[ y ] ){ if( deep[ top[ x ] ] < deep[ top[ y ] ] )//把x调整为top深度更深的的点 swap( x , y ); res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root ); x = fa[ top[ x ] ]; } if( deep[ x ] > deep[ y ] ) swap( x , y ); res += sum( 1 , n , id[ x ] , id[ y ] , root ); return res; }
链修改
同理,爬升过程一模一样,只需要将链查询的区间查询改为区间修改即可。
最坏时间复杂度O(log^2^n)
代码如下:
inline void ope1( int x , int y , ll add ){ while( top[ x ] != top[ y ] ){ if( deep[ top[ x ] ] < deep[ top[ y ] ] ) swap( x , y ); modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root ); x = fa[ top[ x ] ]; } if( deep[ x ] > deep[ y ] ) swap( x , y ); modify( 1 , n , id[ x ] , id[ y ] , add , root ); }