题目描述
将\(n\times n\)的网格黑白染色,使得不存在任意一行、任意一列、任意一条大对角线的所有格子同色,求方案数对\(998244353\)取模的结果。
输入
一行一个整数\(n\)。
输出
一行一个整数表示答案对\(998244353\)取模的值。
样例
样例输入
3
样例输出
32
数据范围
对于\(100\%\)的数据,\(1\leq n\leq 300\)。
比第一题难了不知道多少……
这种东西怎么看都是容斥嘛。
我们先考虑对角线没有限制的情况:
枚举行和列有多少个是同色的,若行+列是奇数,则减去方案数,若行+列是偶数,则加上方案数,其他没有限制的点任意选,容斥一波即可(注意,若既有行又有列,则只能是同一颜色;但如果只有行或只有列则可以随意指定颜色,方案数应随之变动)。
那么有对角线该怎么办?
仍然容斥,\(0\)条对角线-\(1\)条对角线+\(2\)条对角线即是最终答案。其中\(0\)条对角线就是上述的算法。
考虑\(1\)条对角线的情况,不妨设是主对角线,令\(f_{i,j,k}\)表示考虑到前\(i\)行\(i\)列,选中了\(j\)行\(k\)列的方案数。然后考虑转移到\(i+1\),则枚举选\(0/1\)行,\(0/1\)列,分别转移。但是会存在一个问题,如果既没有选择行又没有选择列,会导致判断\((i,i)\)这个格子可以任选,但事实上主对角线被锁死了,所以要乘上\(\frac{1}{2}\)的系数。最后\(DP\)完后再根据\(j+k\)的奇偶性确定容斥系数,然后再乘上\(2^{(n-j)(n-k)}\)的系数表示未被考虑的格子的方案。
然后考虑\(2\)条对角线的情况,类似\(1\)条对角线,只是变成从中心开始一圈一圈往外\(DP\),讨论的情况更多(并且要特判,奇数的时候两条对角线颜色必须一样,偶数的时候两条对角线颜色不必一样),就没有什么差别了。
\(Code:\)
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; #define ll long long #define mod 998244353 #define inv2 499122177 #define inv16 935854081 int n, f[2][305][305]; int C[305][305], mul[100005]; void Add(int &a, int b){a = (a + b) % mod;} int Solve0() { int ans = 0; for (int i = 0; i <= n; i++) for (int j = 0; j <= n; j++) { int k; if (!i || !j) k = i ^ j; else k = 1; int x = (n - i) * (n - j) + k; ans = (ans + (ll)C[n][i] * C[n][j] % mod * mul[x] % mod * (1 - (i + j) % 2 * 2) % mod) % mod; } return ans; } int Solve1() { memset(f, 0, sizeof(f)); int z = 0; f[0][0][0] = 2; for (int i = 0; i < n; i++) { for (int j = 0; j <= i; j++) { for (int k = 0; k <= i; k++) { Add(f[z ^ 1][j][k], (ll)f[z][j][k] * inv2 % mod); Add(f[z ^ 1][j][k + 1], -f[z][j][k]); Add(f[z ^ 1][j + 1][k], -f[z][j][k]); Add(f[z ^ 1][j + 1][k + 1], f[z][j][k]); f[z][j][k] = 0; } } z ^= 1; } int ans = 0; for (int i = 0; i <= n; i++) for (int j = 0; j <= n; j++) Add(ans, (ll)f[z][i][j] * mul[(n - i) * (n - j)] % mod); return ans; } int Solve2() { memset(f, 0, sizeof(f)); int z = 0; if (n & 1) { f[0][0][0] = 1; f[0][0][1] = f[0][1][0] = -2; f[0][1][1] = 2; } else f[0][0][0] = 2; for (int k = n & 1; k < n; k += 2) { for (int i = 0; i <= k; i++) { for (int j = 0; j <= k; j++) { Add(f[z ^ 1][i][j], (ll)f[z][i][j] * inv16 % mod); Add(f[z ^ 1][i + 1][j], -(ll)f[z][i][j] * inv2 % mod); Add(f[z ^ 1][i][j + 1], -(ll)f[z][i][j] * inv2 % mod); Add(f[z ^ 1][i + 2][j], f[z][i][j]); Add(f[z ^ 1][i][j + 2], f[z][i][j]); Add(f[z ^ 1][i + 2][j + 1], -(ll)f[z][i][j] * 2ll % mod); Add(f[z ^ 1][i + 1][j + 2], -(ll)f[z][i][j] * 2ll % mod); Add(f[z ^ 1][i + 1][j + 1], f[z][i][j] * 2ll % mod); Add(f[z ^ 1][i + 2][j + 2], f[z][i][j]); f[z][i][j] = 0; } } z ^= 1; } int ans = 0; for (int i = 0; i <= n; i++) for (int j = 0; j <= n; j++) if (!j && !i && (!(n & 1))) Add(ans, mul[n * (n - 2) + 2]); else Add(ans, (ll)f[z][i][j] * mul[(n - i) * (n - j)] % mod); return ans; } int main() { scanf("%d", &n); C[0][0] = 1; for (int i = 1; i <= n; i++) { C[i][0] = 1; for (int j = 1; j <= i; j++) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod; } mul[0] = 1; for (int i = 1; i <= n * n + 1; i++) mul[i] = mul[i - 1] * 2 % mod; int ans = Solve0(); ans = (ans - 2ll * Solve1()) % mod; ans = (ans + Solve2()) % mod; if (ans < 0)ans += mod; printf("%d\n", ans); }