差分+简单数学即可.
首先有个性质:
两条链相交等价于其中一条链的\(LCA\)在另一条链上.
于是我们就对每一条链的\(LCA\)都加\(1\).
最后对每一条链查询链上所有点的权值和,再减去这条链自身\(LCA\)上的那个\(1\),就得到\(LCA\)落在这条链上的其他链的条数.用树链剖分+线段树实现.
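但这样会算重复.若相交的两条链\(a,b\)的\(LCA\)不同,则恰有一条链的\(LCA\)在另一条链上,这对只被算一次.若两条链的\(LCA\)相同,则两条链的\(LCA\)互相都在对方链上,查询\(a\)和查询\(b\)时各算一次,共算了两次(相当于有序数对\((a,b)\)和\((b,a)\)).
所以设以点\(v\)为\(LCA\)的链共有\(c_v\)条,最后从答案中减去\(\sum_v \binom{c_v}{2}=\sum_v \frac{c_v(c_v-1)}{2}\)即可.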
\(Code:\)
// Counts unordered pairs of tree paths that share at least one vertex.
//
// Key fact: paths a and b intersect iff LCA(a) lies on b or LCA(b) lies on a,
// and both hold at the same time iff LCA(a) == LCA(b). So every path puts +1
// on its LCA; summing the weights along path i (minus its own +1) counts the
// other paths whose LCA lies on path i. Pairs with equal LCAs are counted
// twice, so C(c_v, 2) is subtracted for every vertex v carrying c_v LCAs.
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <queue>
#include <cmath>
#include <ctime>
#include <map>
#include <set>

#define MEM(x,y) memset ( x , y , sizeof ( x ) )
#define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i)
#define per(i,a,b) for (int i = (a) ; i >= (b) ; -- i)
#define pii pair < int , int >
#define X first
#define Y second
#define rint read<int>
#define int long long
#define pb push_back
#define ls ( rt << 1 )
#define rs ( rt << 1 | 1 )
#define mid ( ( l + r ) >> 1 )

using std::queue ;
using std::set ;
using std::pair ;
using std::max ;
using std::min ;
using std::priority_queue ;
using std::vector ;
using std::swap ;
using std::sort ;
using std::unique ;
using std::greater ;

// Reads one (optionally negative) integer from stdin via getchar().
// Assumes well-formed input: a digit must eventually appear.
template < class T > inline T read () {
    T x = 0 , f = 1 ; char ch = getchar () ;
    while ( ch < '0' || ch > '9' ) { if ( ch == '-' ) f = - 1 ; ch = getchar () ; }
    while ( ch >= '0' && ch <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( ch - 48 ) ; ch = getchar () ; }
    return f * x ;
}

const int N = 1e6 + 100 ;
vector < int > G[N] ;                              // adjacency lists of the tree
int f[N] , deep[N] , ans , idx[N] , cnt ;          // parent, depth (root = 1), answer, HLD position, position counter
int n , m , p[N][2] , siz[N] , son[N] , top[N] ;   // #vertices, #paths, path endpoints, subtree size, heavy child, chain head

// Segment tree over HLD positions supporting range add and range sum.
// `tag` is a pending increment that applies to EVERY element of the node.
struct seg {
    int left , right , data , tag ;
    inline int size () { return right - left + 1 ; }
} t[N<<2] ;

// First HLD pass, done iteratively so that a path-shaped tree with ~1e6
// vertices cannot overflow the call stack. For the tree rooted at `cur`
// (whose parent is `anc` and depth is `dep`) it fills f, deep, siz, and son
// (heavy child; stays 0 for leaves).
inline void dfs (int cur , int anc , int dep) {
    vector < int > order ;   // BFS order: every parent precedes its children
    order.reserve ( n ) ;
    f[cur] = anc ; deep[cur] = dep ;
    order.pb ( cur ) ;
    for (size_t head = 0 ; head < order.size () ; ++ head) {
        int u = order[head] ;   // copy: order may reallocate below
        for (int k : G[u]) {
            if ( k == f[u] ) continue ;
            f[k] = u ; deep[k] = deep[u] + 1 ;
            order.pb ( k ) ;
        }
    }
    // Reverse BFS order guarantees every child's size is final before its parent's.
    for (auto it = order.rbegin () ; it != order.rend () ; ++ it) {
        int u = *it , maxson = - 1 ;
        siz[u] = 1 ;
        for (int k : G[u]) {
            if ( k == f[u] ) continue ;
            siz[u] += siz[k] ;
            // Strict '>' keeps the first child of maximum size, as the recursive version did.
            if ( siz[k] > maxson ) maxson = siz[k] , son[u] = k ;
        }
    }
}

// Second HLD pass (iterative): assigns top (chain head) and idx (segment-tree
// position) in the same preorder as the recursive version. The heavy child is
// popped right after its parent, so each heavy chain gets consecutive
// positions; every light child starts a new chain.
inline void _dfs (int cur , int topf) {
    vector < pii > stk ;   // (vertex, head of its chain)
    stk.pb ( pii ( cur , topf ) ) ;
    while ( ! stk.empty () ) {
        int u = stk.back ().X , chain = stk.back ().Y ;
        stk.pop_back () ;
        top[u] = chain ; idx[u] = ++ cnt ;
        if ( ! son[u] ) continue ;   // leaf
        // Light children are pushed in reverse so they pop in adjacency order...
        for (auto it = G[u].rbegin () ; it != G[u].rend () ; ++ it)
            if ( *it != son[u] && *it != f[u] ) stk.pb ( pii ( *it , *it ) ) ;
        // ...after the heavy child, which is pushed last and therefore popped next.
        stk.pb ( pii ( son[u] , chain ) ) ;
    }
}

inline void pushup (int rt) { t[rt].data = t[ls].data + t[rs].data ; return ; }

// Builds an all-zero tree over positions [l, r].
inline void build (int rt , int l , int r) {
    t[rt].left = l ; t[rt].right = r ; t[rt].tag = 0 ;
    if ( l == r ) { t[rt].data = 0 ; return ; }
    build ( ls , l , mid ) ; build ( rs , mid + 1 , r ) ;
    pushup ( rt ) ;
    return ;
}

// Pushes the pending per-element increment of `rt` into both children.
inline void pushdown (int rt) {
    t[ls].tag += t[rt].tag ; t[rs].tag += t[rt].tag ;
    t[ls].data += t[ls].size () * t[rt].tag ;
    t[rs].data += t[rs].size () * t[rt].tag ;
    t[rt].tag = 0 ;
    return ;
}

// Adds `val` to every position in [ll, rr] (must lie inside rt's range).
inline void update (int rt , int ll , int rr , int val) {
    int l = t[rt].left , r = t[rt].right ;
    if ( l == ll && r == rr ) {
        t[rt].tag += val ;
        // The node's sum grows by val for each covered element, matching pushdown's tag semantics.
        t[rt].data += val * t[rt].size () ;
        return ;
    }
    if ( t[rt].tag ) pushdown ( rt ) ;
    if ( rr <= mid ) update ( ls , ll , rr , val ) ;
    else if ( ll > mid ) update ( rs , ll , rr , val ) ;
    else { update ( ls , ll , mid , val ) ; update ( rs , mid + 1 , rr , val ) ; }
    pushup ( rt ) ;
    return ;
}

// Returns the sum over positions [ll, rr] (must lie inside rt's range).
inline int query (int rt , int ll , int rr) {
    int l = t[rt].left , r = t[rt].right ;
    if ( ll == l && r == rr ) return t[rt].data ;
    if ( t[rt].tag ) pushdown ( rt ) ;
    if ( rr <= mid ) return query ( ls , ll , rr ) ;
    else if ( ll > mid ) return query ( rs , ll , rr ) ;
    else return query ( ls , ll , mid ) + query ( rs , mid + 1 , rr ) ;
}

// Sum of vertex weights on the tree path x..y (both ends included).
inline int qrange (int x , int y) {
    int res = 0 ;
    while ( top[x] != top[y] ) {
        if ( deep[top[x]] < deep[top[y]] ) swap ( x , y ) ;
        res += query ( 1 , idx[top[x]] , idx[x] ) ;
        x = f[top[x]] ;
    }
    if ( deep[x] > deep[y] ) swap ( x , y ) ;
    return res + query ( 1 , idx[x] , idx[y] ) ;
}

// Lowest common ancestor of x and y via heavy-light chains.
inline int LCA (int x , int y) {
    while ( top[x] != top[y] ) {
        if ( deep[top[x]] < deep[top[y]] ) y = f[top[y]] ;
        else x = f[top[x]] ;
    }
    return deep[x] < deep[y] ? x : y ;
}

// Input: n m, then n-1 edges, then m paths (u v). Output: #intersecting unordered path pairs.
signed main () {
    n = rint () ; m = rint () ;
    rep ( i , 2 , n ) {
        int u = rint () , v = rint () ;
        G[u].pb ( v ) ; G[v].pb ( u ) ;
    }
    dfs ( 1 , 0 , 1 ) ;
    _dfs ( 1 , 1 ) ;
    build ( 1 , 1 , cnt ) ;
    rep ( i , 1 , m ) {
        p[i][0] = rint () ; p[i][1] = rint () ;
        int lca = LCA ( p[i][0] , p[i][1] ) ;   // named to avoid shadowing the tree array `t`
        update ( 1 , idx[lca] , idx[lca] , 1 ) ;
    }
    // Other paths whose LCA lies on path i; "- 1" drops path i's own LCA.
    rep ( i , 1 , m ) ans += qrange ( p[i][0] , p[i][1] ) - 1 ;
    // Pairs sharing the same LCA were counted from both sides; remove one copy.
    rep ( i , 1 , n ) {
        int c = query ( 1 , idx[i] , idx[i] ) ;
        ans -= c * ( c - 1 ) / 2 ;
    }
    printf ("%lld\n" , ans ) ;
    return 0 ;
}