天天爱跑步
这或许现在不是\(NOIP\)最毒瘤的题了叭.
(当然你说是,我还可以肛你说\(NOIP\)没了)
嗯...一个很显然的暴力思路是:
对于每一个玩家,暴力跟着跑,走到\(w_i\)等于当前时间的点就统计.
这显然是对的...但它太慢了,完全跑不过去.
我们发现题目里给的条件其实是个这个:
设一条路径的起点终点分别是\(s_i,t_i\),长度是\(dis_{i,j}\),\(LCA\)是\(lca_{i,j}\),某一点\(x\)是\(deep_x\).
\[dis(s_i,t_i)-deep_{t_i}=w_p-deep_p\]
和
\[deep_s=deep_p+w_p\]
然后我们的问题就变成了对于每一个点,求它的子树内有多少点的权值和它相等,当然,权值有两种.
然后...这个很难做对叭.
我们想怎么去做它,因为我们发现这是不可避免的.
这时候大佬可能会说了,直接线段树维护桶,线段树合并即可.
但这不是我们要说的.(拒绝大力数据结构,偶尔写写还是挺开心的
我们的做法也是维护桶,但并非这么暴力,这种维护方法非常优美.
用桶去维护权值,下标就是权值,存储的值是有多少个.
我们发现,直接去维护这个全局桶是显然没法做的.
我们考虑,枚举观察点,去它的子树里统计.
由于观察点是所有点,所以我们只需要一遍\(dfs\)去枚举即可.
而\(dfs\)由于其优美的特性,恰好能完成对子树的统计.
这样可以发现,统计答案部分的复杂度是\(\Theta(n)\)的,也就是说这题的复杂度瓶颈在于\(LCA\).
考虑如何统计:
对每一种权值,显然直接统计即可.
但当我们回溯的时候,我们发现,这时候,以当前点\(x\)为\(LCA\)的所有点都已经没有了贡献,所以我们需要把这一部分的贡献减去.
于是我们需要维护一个链表,维护以\(x\)为\(LCA\)的路径信息.
每次回溯的时候,遍历这个链表,把所有的以当前点为\(LCA\)的路径的贡献减去即可.
那么加入贡献呢?按照我们刚才的式子统计即可.
这里也要维护一个链表,维护以当前点为终点的路径信息,每次走到一个点就把以它为终点的贡献都加上.同时把以它为起点的贡献也都加上.
还有一点,我们发现统计答案的时候,并不是直接加上桶中的元素即可.
因为这样很显然会算到不该对它产生贡献的点.
所以一个显然的思路是:
统计答案时相加的应该是进去这个点前的桶和回溯出来时的差.
这时显然桶里统计到的贡献都是合法的位于它的子树内的.
\(Code:\)
#include <algorithm> #include <iostream> #include <cstdlib> #include <cstring> #include <cstdio> #include <string> #include <vector> #include <queue> #include <cmath> #include <ctime> #include <map> #include <set> #define MEM(x,y) memset ( x , y , sizeof ( x ) ) #define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i) #define per(i,a,b) for (int i = (a) ; i >= (b) ; -- i) #define pii pair < int , int > #define one first #define two second #define rint read<int> #define int long long #define pb push_back #define db double using std::queue ; using std::set ; using std::pair ; using std::max ; using std::min ; using std::priority_queue ; using std::vector ; using std::swap ; using std::sort ; using std::unique ; using std::greater ; template < class T > inline T read () { T x = 0 , f = 1 ; char ch = getchar () ; while ( ch < '0' || ch > '9' ) { if ( ch == '-' ) f = - 1 ; ch = getchar () ; } while ( ch >= '0' && ch <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( ch - 48 ) ; ch = getchar () ; } return f * x ; } const int N = 300005 ; int n , m , tot , head[N] , f[25][N] , t[N] ; int deep[N] , w[N] , ans[N] , val[N] , s[N] ; int cnt1[N<<1] , cnt2[N<<1] , dis[N] , ss[N] ; vector < int > G[N] ; inline void dfs (int cur , int anc , int dep) { f[0][cur] = anc ; deep[cur] = dep ; for (int i = 1 ; ( 1 << i ) <= dep ; ++ i) f[i][cur] = f[i-1][f[i-1][cur]] ; for (int k : G[cur] ) { if ( k == anc ) continue ; dfs ( k , cur , dep + 1 ) ; } return ; } inline int LCA (int x , int y) { if ( deep[x] < deep[y] ) swap ( x , y ) ; int k = log2 ( deep[x] ) ; for (int i = k ; i >= 0 ; -- i) if ( deep[f[i][x]] >= deep[y] ) x = f[i][x] ; if ( x == y ) return x ; for (int i = k ; i >= 0 ; -- i) if ( f[i][x] != f[i][y] ) x = f[i][x] , y = f[i][y] ; return f[0][x] ; } int head1[N] , head2[N] , tot1 , tot2 ; struct mylist { int to , next ; } l1[N] , l2[N] ; inline void insert1 (int x , int y) { l1[++tot1].next = head1[x] ; l1[tot1].to = y ; head1[x] = tot1 ; return ; } inline void insert2 (int x , int y) { l2[++tot2].next = head2[x] ; l2[tot2].to = y ; head2[x] = tot2 ; return ; } inline void solve (int cur) { int tmp = cnt1[w[cur]+deep[cur]] , _tmp = cnt2[w[cur]-deep[cur]+N] ; for (int k : G[cur] ) { if ( k == f[0][cur] ) continue ; solve ( k ) ; } cnt1[deep[cur]] += ss[cur] ; for (int i = head1[cur] ; i ; i = l1[i].next) { int k = l1[i].to ; ++ cnt2[dis[k]-deep[t[k]]+N] ; } ans[cur] += ( ( cnt1[w[cur]+deep[cur]] - tmp ) + ( cnt2[w[cur]-deep[cur]+N] - _tmp ) ) ; for (int i = head2[cur] ; i ; i = l2[i].next) { int k = l2[i].to ; -- cnt1[deep[s[k]]] ; -- cnt2[dis[k]-deep[t[k]]+N] ; } return ; } signed main (int argc , char * argv[]) { n = rint () ; m = rint () ; rep ( i , 2 , n ) { int u = rint () , v = rint () ; G[u].pb ( v ) ; G[v].pb ( u ) ; } dfs ( 1 , 0 , 1 ) ; f[0][1] = 1 ; rep ( i , 1 , n ) w[i] = rint () ; rep ( i , 1 , m ) { s[i] = rint () ; t[i] = rint () ; ++ ss[s[i]] ; int tmp = LCA ( s[i] , t[i] ) ; dis[i] = deep[s[i]] + deep[t[i]] - deep[tmp] * 2 ; insert1 ( t[i] , i ) ; insert2 ( tmp , i ) ; if ( deep[tmp] + w[tmp] == deep[s[i]] ) -- ans[tmp] ; } solve ( 1 ) ; rep ( i , 1 , n ) printf ("%lld " , ans[i] ) ; return 0 ; }