题目
点这里看题目。
分析
不难发现,设两人取得的下标集合为\(S_a\)和\(S_b\),那么符合要求的下标集合对需要满足\(S_a\)和\(S_b\)对应的值全部异或起来为 0 。
因此,我们可以考虑异或为\(0\)的下标集合\(S\),它对答案的贡献就是\(2^{|S|}\)。
根据这个思想,我们可以考虑如下的 DP :
\(f(i,j)\):前\(i\)个数,异或为\(j\)的集合的个数。转移如下:
这样当然是没有办法做的。不过我们可以考虑将这样的转移写成生成函数的形式:
其中的\(\bigoplus\)运算符为异或卷积,可以理解为系数相乘,指数异或。
说着异或卷积,貌似就可以 FWT ?
显然不行,这样时间有\(O(n^2\log_2n)\),会 T 的,我们需要继续挖掘性质。
我们想想 FWT 之后的序列的性质:
由于\(A\)中只有一个 1 和一个 2 ,那么我们我们 FWT 后每一项只会是 -1 (1-2) 或者 3 (1+2)。
因此我们\(FWT(F)_i\)一定可以表示为\((-1)^{p_i}\times 3^{q_i}\),其中\(p_i\)表示\(FWT(A(1))\sim FWT(A(n))\)中第\(i\)项上\(-1\)的个数,\(q_i\)同理。
设\(s_i\)为\(FWT(A(1))\sim FWT(A(n))\)中第\(i\)项的和。我们可以得到:
假如我们快速求出\(s_i\),我们就可以解出\(p_i\)和\(q_i\),进而算出\(FWT(F)\)和\(F\)。
这里用到了一个性质: FWT 的和等于和的 FWT 。
设\(B_i=\sum_{j=1}^n A(j)_i\),这个性质即:
这其实比较好理解,因为合起来 FWT 的时候来自不同\(A\)的值是可以看成互不影响的。
然后就可以\(O(n\log_2n)\)解决了。
代码
#include <cstdio> const int mod = 998244353, inv2 = 499122177, inv4 = 748683265; const int MAXN = 4e6 + 5; template<typename _T> void read( _T &x ) { x = 0;char s = getchar();int f = 1; while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();} while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();} x *= f; } template<typename _T> void write( _T x ) { if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; } if( 9 < x ){ write( x / 10 ); } putchar( x % 10 + '0' ); } template<typename _T> _T MAX( const _T a, const _T b ) { return a > b ? a : b; } int F[MAXN]; int N, len; int fix( const int x ) { return ( x % mod + mod ) % mod; } int qkpow( int base, int indx ) { int ret = 1; while( indx ) { if( indx & 1 ) ret = 1ll * ret * base % mod; base = 1ll * base * base % mod, indx >>= 1; } return ret; } void FWT( int *f, const int mode ) { int t1, t2; for( int s = 2 ; s <= len ; s <<= 1 ) for( int i = 0, t = s >> 1 ; i < len ; i += s ) for( int j = i ; j < i + t ; j ++ ) { t1 = f[j], t2 = f[j + t]; if( mode > 0 ) f[j] = ( t1 + t2 ) % mod, f[j + t] = fix( t1 - t2 ); else f[j] = 1ll * ( t1 + t2 ) * inv2 % mod, f[j + t] = 1ll * fix( t1 - t2 ) * inv2 % mod; } } signed main() { int mx = 0; read( N ); for( int i = 1, a ; i <= N ; i ++ ) read( a ), mx = MAX( mx, a ), F[a] += 2; for( len = 1 ; len <= mx ; len <<= 1 ); F[0] += N, FWT( F, 1 ); //这里千万不能写成赋值!a可以为0! for( int i = 0 ; i < len ; i ++ ) { int cnt3 = 1ll * ( F[i] + N ) * inv4 % mod; int cnt1 = fix( N - cnt3 ); F[i] = qkpow( 3, cnt3 ); if( cnt1 & 1 ) F[i] = mod - F[i]; } FWT( F, -1 ); write( fix( F[0] - 1 ) ), putchar( '\n' ); return 0; }
来源:https://www.cnblogs.com/crashed/p/12593830.html