LibreOJ#2359天天爱跑步

一个人想着一个人 提交于 2019-11-30 19:44:06

天天爱跑步

这或许现在不是\(NOIP\)最毒瘤的题了叭.
(当然你说是,我还可以肛你说\(NOIP\)没了)

嗯...一个很显然的暴力思路是:
对于每一个玩家,暴力跟着跑,走到\(w_i\)等于当前时间的点就统计.
这显然是对的...但它太慢了,完全跑不过去.

我们发现题目里给的条件其实是个这个:
设一条路径的起点终点分别是\(s_i,t_i\),长度是\(dis_{i,j}\),\(LCA\)\(lca_{i,j}\),某一点\(x\)\(deep_x\).
\[dis(s_i,t_i)-deep_{t_i}=w_p-deep_p\]

\[deep_s=deep_p+w_p\]
然后我们的问题就变成了对于每一个点,求它的子树内有多少点的权值和它相等,当然,权值有两种.
然后...这个很难做对叭.
我们想怎么去做它,因为我们发现这是不可避免的.
这时候大佬可能会说了,直接线段树维护桶,线段树合并即可.
但这不是我们要说的.(拒绝大力数据结构,偶尔写写还是挺开心的
我们的做法也是维护桶,但并非这么暴力,这种维护方法非常优美.
用桶去维护权值,下标就是权值,存储的值是有多少个.
我们发现,直接去维护这个全局桶是显然没法做的.
我们考虑,枚举观察点,去它的子树里统计.
由于观察点是所有点,所以我们只需要一遍\(dfs\)去枚举即可.
\(dfs\)由于其优美的特性,恰好能完成对子树的统计.
这样可以发现,统计答案部分的复杂度是\(\Theta(n)\)的,也就是说这题的复杂度瓶颈在于\(LCA\).
考虑如何统计:
对每一种权值,显然直接统计即可.
但当我们回溯的时候,我们发现,这时候,以当前点\(x\)\(LCA\)的所有点都已经没有了贡献,所以我们需要把这一部分的贡献减去.
于是我们需要维护一个链表,维护以\(x\)\(LCA\)的路径信息.
每次回溯的时候,遍历这个链表,把所有的以当前点为\(LCA\)的路径的贡献减去即可.
那么加入贡献呢?按照我们刚才的式子统计即可.
这里也要维护一个链表,维护以当前点为终点的路径信息,每次走到一个点就把以它为终点的贡献都加上.同时把以它为起点的贡献也都加上.
还有一点,我们发现统计答案的时候,并不是直接加上桶中的元素即可.
因为这样很显然会算到不该对它产生贡献的点.
所以一个显然的思路是:
统计答案时相加的应该是进去这个点前的桶和回溯出来时的差.
这时显然桶里统计到的贡献都是合法的位于它的子树内的.
\(Code:\)

#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <queue>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
#define MEM(x,y) memset ( x , y , sizeof ( x ) )
#define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i)
#define per(i,a,b) for (int i = (a) ; i >= (b) ; -- i)
#define pii pair < int , int >
#define one first
#define two second
#define rint read<int>
#define int long long
#define pb push_back
#define db double

using std::queue ;
using std::set ;
using std::pair ;
using std::max ;
using std::min ;
using std::priority_queue ;
using std::vector ;
using std::swap ;
using std::sort ;
using std::unique ;
using std::greater ;

template < class T >
    inline T read () {
        T x = 0 , f = 1 ; char ch = getchar () ;
        while ( ch < '0' || ch > '9' ) {
            if ( ch == '-' ) f = - 1 ;
            ch = getchar () ;
        }
       while ( ch >= '0' && ch <= '9' ) {
            x = ( x << 3 ) + ( x << 1 ) + ( ch - 48 ) ;
            ch = getchar () ;
       }
       return f * x ;
}

const int N = 300005 ;
int n , m , tot , head[N] , f[25][N] , t[N] ;
int deep[N] , w[N] , ans[N] , val[N] , s[N] ;
int cnt1[N<<1] , cnt2[N<<1] , dis[N] , ss[N] ;
vector < int > G[N] ;

inline void dfs (int cur , int anc , int dep) {
    f[0][cur] = anc ; deep[cur] = dep ;
    for (int i = 1 ; ( 1 << i ) <= dep ; ++ i)
        f[i][cur] = f[i-1][f[i-1][cur]] ;
    for (int k : G[cur] ) {
        if ( k == anc ) continue ;
        dfs ( k , cur , dep + 1 ) ;
    }
    return ;
}

inline int LCA (int x , int y) {
    if ( deep[x] < deep[y] ) swap ( x , y ) ;
    int k = log2 ( deep[x] ) ;
    for (int i = k ; i >= 0 ; -- i)
        if ( deep[f[i][x]] >= deep[y] ) 
            x = f[i][x] ;
    if ( x == y ) return x ;
    for (int i = k ; i >= 0 ; -- i)
        if ( f[i][x] != f[i][y] )
            x = f[i][x] , y = f[i][y] ;
    return f[0][x] ;
}

int head1[N] , head2[N] , tot1 , tot2 ;
struct mylist { int to , next ; } l1[N] , l2[N] ;

inline void insert1 (int x , int y) {
    l1[++tot1].next = head1[x] ;
    l1[tot1].to = y ; head1[x] = tot1 ;
    return ;
}

inline void insert2 (int x , int y) {
    l2[++tot2].next = head2[x] ;
    l2[tot2].to = y ; head2[x] = tot2 ;
    return ;
}

inline void solve (int cur) {
    int tmp = cnt1[w[cur]+deep[cur]] , _tmp = cnt2[w[cur]-deep[cur]+N] ;
    for (int k : G[cur] ) { if ( k == f[0][cur] ) continue ; solve ( k ) ; }
    cnt1[deep[cur]] += ss[cur] ;
    for (int i = head1[cur] ; i ; i = l1[i].next) {
        int k = l1[i].to ;
        ++ cnt2[dis[k]-deep[t[k]]+N] ;
    }
    ans[cur] += ( ( cnt1[w[cur]+deep[cur]] - tmp ) + ( cnt2[w[cur]-deep[cur]+N] - _tmp ) ) ;
    for (int i = head2[cur] ; i ; i = l2[i].next) {
        int k = l2[i].to ;
        -- cnt1[deep[s[k]]] ;
        -- cnt2[dis[k]-deep[t[k]]+N] ;
    }
    return ;
}

signed main (int argc , char * argv[]) {
    n = rint () ; m = rint () ;
    rep ( i , 2 , n ) {
        int u = rint () , v = rint () ;
        G[u].pb ( v ) ; G[v].pb ( u ) ;
    }
    dfs ( 1 , 0 , 1 ) ; f[0][1] = 1 ;
    rep ( i , 1 , n ) w[i] = rint () ;
    rep ( i , 1 , m ) {
        s[i] = rint () ; t[i] = rint () ;
        ++ ss[s[i]] ; int tmp = LCA ( s[i] , t[i] ) ;
        dis[i] = deep[s[i]] + deep[t[i]] - deep[tmp] * 2 ;
        insert1 ( t[i] , i ) ; insert2 ( tmp , i ) ;
        if ( deep[tmp] + w[tmp] == deep[s[i]] ) -- ans[tmp] ;
    }
    solve ( 1 ) ;
    rep ( i , 1 , n ) printf ("%lld " , ans[i] ) ;
    return 0 ;
}
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!