题意:(复制sunset的)有$T$天,每天有$K$个小时,第$i$天有$D+i−1$道菜,第一个小时你选择$L$道菜吃,接下来每个小时你可以选择吃一道菜或者选择$A$个活动中的一个参加,不能连续两个小时吃菜,问每天的方案数之和。$K$,$A$预先给定,$Q$次询问,每次给$D$,$L$,$T$。
题解:显然$ans=\sum_{i=D}^{D+T-1}\binom{i}{L}F(i)$,其中$F(i)$是一个不超过$k-1$次的多项式。
把组合数暴力拆开,变为$\sum_{i=D}^{D+T-1}\frac{i!}{L!(i-L)!}F(i)$。因为有阶乘,所以考虑把$F(i)$写成上升幂多项式的形式来消掉阶乘。具体地,设$F(x)=\sum_{i=0}^{k-1}a_i(x+1)\dots(x+i)=\sum_{i=0}^{k-1}a_i\frac{(x+i)!}{x!}$,则$ans=\frac{1}{L!}\sum_{i=D}^{D+T-1}\sum_{j=0}^{k-1}a_j\frac{(i+j)!}{(i-L)!}$。考虑在$\frac{(i+j)!}{(i-L)!}$的分母处补上$(j+L)!$变为组合数,则$ans=\frac{1}{L!} \sum_{j=0}^{k-1}a_j(j+L)!\sum_{i=D}^{D+T-1}\binom{i+j}{j+L}$。后面是组合数上指标求和,可以$O(1)$计算。
剩下的问题是怎样求$a$。上升幂多项式可以考虑用连续点值来求。具体地,假设我们求出了$F(-1),F(-2),\dots,F(-k)$,显然有式子$F(-u)=\sum_{i=0}^{u-1}\frac{(u-1)!}{(u-1-i)!}(-1)^ia_i$。设$x_i=(-1)^ia_i,y_i=\frac{1}{i!},z_i=F(-(u+1))$,则$Z=X*Y,X=\frac{Z}{Y}$。多项式求逆即可。(其实可以不用求逆,可以发现$Y=e^x,Y^{-1}=e^{-x}$。)
剩下的问题是怎样求点值。设$b_i$为考虑了前$i$个小时的方案数,对于要求的点值$x$,有递推式$b_i=Ab_{i-1}+Axb_{i-2}$,可以用矩阵快速幂在$O(\log k)$的时间内求出单个点值。
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N = 1e6 + 10;
const int M = 1e7 + 1e5 + 10;
const db pi = acos(-1);
int k, a, mod, q, l, r[N], fac[M], inv[M], ifac[M], x[N], y[N], z[N];
int gi() {
int x = 0, o = 1;
char ch = getchar();
while((ch < '0' || ch > '9') && ch != '-') {
ch = getchar();
}
if(ch == '-') {
o = -1, ch = getchar();
}
while(ch >= '0' && ch <= '9') {
x = x * 10 + ch - '0', ch = getchar();
}
return x * o;
}
struct com {
db x, y;
com(db x = 0, db y = 0): x(x), y(y) {}
com operator+(const com &A) const {
return com(x + A.x, y + A.y);
}
com operator-(const com &A) const {
return com(x - A.x, y - A.y);
}
com operator*(const com &A) const {
return com(x * A.x - y * A.y, x * A.y + y * A.x);
}
com conj() {
return com(x, -y);
}
} w[N];
void init(int n) {
l = 0;
for(int i = 1; i < n; i <<= 1) {
++l;
}
for(int i = 0; i < n; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)), w[i] = com(cos(pi * i / n), sin(pi * i / n));
}
}
void FFT(com *a, int n) {
for(int i = 0; i < n; i++) if(i < r[i]) {
swap(a[i], a[r[i]]);
}
for(int i = 1; i < n; i <<= 1)
for(int p = i << 1, j = 0; j < n; j += p)
for(int k = 0; k < i; k++) {
com x = a[j + k], y = w[n / i * k] * a[j + k + i];
a[j + k] = x + y, a[j + k + i] = x - y;
}
}
void mul(int *a, int *b, int *c, int n) {
static com s1[N], s2[N], s3[N], s4[N], s5[N], s6[N];
init(n);
for(int i = 0; i < n; i++) {
s1[i] = com(a[i] & 32767, a[i] >> 15);
s2[i] = com(b[i] & 32767, b[i] >> 15);
}
FFT(s1, n), FFT(s2, n);
for(int i = 0; i < n; i++) {
int j = (n - 1) & (n - i);
com da = (s1[i] + s1[j].conj()) * com(0.5, 0);
com db = (s1[i] - s1[j].conj()) * com(0, -0.5);
com dc = (s2[i] + s2[j].conj()) * com(0.5, 0);
com dd = (s2[i] - s2[j].conj()) * com(0, -0.5);
s3[i] = da * dc, s4[i] = da * dd, s5[i] = db * dc, s6[i] = db * dd;
}
for(int i = 0; i < n; i++) {
s1[i] = s3[i] + s4[i] * com(0, 1);
s2[i] = s5[i] + s6[i] * com(0, 1);
}
FFT(s1, n), FFT(s2, n);
reverse(s1 + 1, s1 + n), reverse(s2 + 1, s2 + n);
for(int i = 0; i < n; i++) {
int da = (ll)(s1[i].x / n + 0.5) % mod;
int db = (ll)(s1[i].y / n + 0.5) % mod;
int dc = (ll)(s2[i].x / n + 0.5) % mod;
int dd = (ll)(s2[i].y / n + 0.5) % mod;
c[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
}
}
struct mat {
int v[2][2];
mat() {
memset(v, 0, sizeof(v));
}
mat operator*(const mat &A) const {
mat ret;
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++) {
ull tmp = 0;
for(int k = 0; k < 2; k++) {
tmp += 1ll * v[i][k] * A.v[k][j];
}
ret.v[i][j] = tmp % mod;
}
return ret;
}
} S, T;
mat qpow(mat a, int b) {
mat ret;
for(int i = 0; i < 2; i++) {
ret.v[i][i] = 1;
}
while(b) {
if(b & 1) {
ret = ret * a;
}
a = a * a, b >>= 1;
}
return ret;
}
void init() {
const int n = 1e7 + 1e5 + 1;
fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1;
for(int i = 2; i <= n; i++) {
fac[i] = 1ll * fac[i - 1] * i % mod;
inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
ifac[i] = 1ll * ifac[i - 1] * inv[i] % mod;
}
}
int C(int n, int m) {
if(m < 0 || n < m) {
return 0;
}
return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("a.in", "r", stdin);
freopen("a.out", "w", stdout);
#endif
cin >> k >> a >> mod >> q;
init();
S.v[0][0] = 1, S.v[0][1] = a, T.v[1][0] = 1, T.v[1][1] = a;
for(int i = 0; i < k; i++) {
T.v[0][1] = 1ll * a * (mod - i - 1) % mod;
z[i] = 1ll * (S * qpow(T, k - 1)).v[0][0] * ifac[i] % mod;
y[i] = 1ll * ((i & 1) ? mod - 1 : 1) * ifac[i] % mod;
}
int N = 1;
while(N <= 2 * k - 2) {
N <<= 1;
}
mul(y, z, x, N);
for(int i = 0; i < k; i++) {
x[i] = 1ll * x[i] * ((i & 1) ? mod - 1 : 1) % mod;
}
while(q--) {
int l = gi(), d = gi(), t = gi(), ans = 0;
for(int i = 0; i < k; i++) {
ans = (ans + 1ll * x[i] * fac[i + l] % mod * (C(d + t + i, i + l + 1) - C(d + i, i + l + 1) + mod)) % mod;
}
cout << 1ll * ans*ifac[l] % mod << '\n';
}
return 0;
}
来源:oschina
链接:https://my.oschina.net/u/4277479/blog/3524889