poj1625
题意
给一个含n个字符的字符集,p个字符串,问长度为m的字符串有多少种不包含任意一个字符串pi。
题解
首先这题应该想到dp,dp[i][j]表示长度为i的字符串以节点j(节点j表示一个AC自动机上的一个状态节点)结尾的满足条件的字符串种类。以p个字符串建立AC自动机,标记危险节点。想像一个这样的问题:假设当前在u节点,u节点就表示当前长度为x的字符串,从u经过状态转移到其他节点,那么字符串的长度就变为x+1,遍历每个节点,如果该节点不是危险节点,则dp[i][j] = dp[i][j] + dp[i - 1][u],(j是u的一个子节点),这样从根结点(空串)转移m次到达的合法节点就是所求解。
注意:①数据很大,要用到大数模板
②字符有坑,会大于127,要用unsigned char 读入
③这里的AC自动机的模板还要改动一下,因为每个合法节点要进行状态转移,有些节点每个某个子节点,这样无法转移,所以将其的子节点指向它fail节点的对应的子节点(具体看代码注释)。
#include <iostream>
#include <cstdio>
#include <string.h>
#include <map>
#include <queue>
using namespace std;
typedef unsigned char uchar;
typedef long long ll;
const int mod = 100000;
const int N = 100 + 5;
const int MAXN = 1e2 + 10;
int trie[MAXN][55];
int tag[MAXN];
int fail[MAXN];
int cnt = 0;
int n, m, p;
int Hash[256], M;
void set_hash(int n, uchar s[]) {//这题字符有陷阱,要用unsigned char
M = n; for (int i = 0; i < n; i++) Hash[s[i]] = i;
}
void insertWords(uchar *s)
{
int root = 0;
for (int i = 0; s[i]; i++) {
int next = Hash[s[i]];
if(!trie[root][next])
trie[root][next] = ++cnt;
root = trie[root][next];
}
tag[root]++;//标记危险节点,危险节点就是该节点是模式串的最后一个节点
}
void getFail()//一个节点的fail指针是指向 这个节点表示的字符串的最长后缀串的最后一个节点
{
queue<int> q;
for(int i = 0; i < M; i++) {
if(trie[0][i]) {
fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}
while (!q.empty())
{
int now = q.front();
q.pop();
if(tag[fail[now]]) tag[now] = 1;//这很重要***如果这个串的后缀是危险的,那么这个串也是危险的
for (int i = 0; i < M; i++) {
int u = trie[now][i];
if(!u) {
trie[now][i] = trie[fail[now]][i];//多了一种状态转移的方式,就是如果该节点下没有字符i,就连到fail节点的字符i
continue; //为什么呢?因为这样才能使得字符串长度增加一啊,不然走到这个节点长度不能增加就推不到长度m了
}
if(trie[now][i]) {
fail[trie[now][i]] = trie[fail[now]][i];
q.push(trie[now][i]);
}
else trie[now][i] = trie[fail[now]][i];
}
}
}
struct BigInteger{
int A[25];
enum{MOD = 10000};
BigInteger(){memset(A, 0, sizeof(A)); A[0]=1;}
void set(int x){memset(A, 0, sizeof(A)); A[0]=1; A[1]=x;}
void print(){
printf("%d", A[A[0]]);
for (int i=A[0]-1; i>0; i--){
if (A[i]==0){printf("0000"); continue;}
for (int k=10; k*A[i]<MOD; k*=10) printf("0");
printf("%d", A[i]);
}
printf("\n");
}
int& operator [] (int p) {return A[p];}
const int& operator [] (int p) const {return A[p];}
BigInteger operator + (const BigInteger& B){
BigInteger C;
C[0]=max(A[0], B[0]);
for (int i=1; i<=C[0]; i++)
C[i]+=A[i]+B[i], C[i+1]+=C[i]/MOD, C[i]%=MOD;
if (C[C[0]+1] > 0) C[0]++;
return C;
}
BigInteger operator * (const BigInteger& B){
BigInteger C;
C[0]=A[0]+B[0];
for (int i=1; i<=A[0]; i++)
for (int j=1; j<=B[0]; j++){
C[i+j-1]+=A[i]*B[j], C[i+j]+=C[i+j-1]/MOD, C[i+j-1]%=MOD;
}
if (C[C[0]] == 0) C[0]--;
return C;
}
};
uchar ss[100];
int main()
{
scanf("%d%d%d", &n, &m, &p);
cin >> ss;
set_hash(n, ss);
while (p--) {
cin >> ss;
insertWords(ss);
}
getFail();
BigInteger f[51][101];
f[0][0].set(1);//空串初始化为1
for (int i = 1; i <= m; i++) {
for (int j = 0; j <= cnt; j++) {
for (int k = 0; k < n; k++) {
int u = trie[j][k];
if(!tag[u]) f[i][u] = f[i][u] + f[i - 1][j];//不是危险节点就进行状态转移
}
}
}
BigInteger ans;
for (int i = 0; i <= cnt; i++)
if(!tag[i]) ans = ans + f[m][i];//对于不是危险节点的每个节点为结尾的 长度为m的字符串求和
ans.print();
}
来源:CSDN
作者:D_Bamboo_
链接:https://blog.csdn.net/D_Bamboo_/article/details/104704889