Poj1625 AC自动机+大数+DP

醉酒当歌 提交于 2020-03-07 03:09:39

poj1625
题意
给一个含n个字符的字符集,p个字符串,问长度为m的字符串有多少种不包含任意一个字符串pi。
题解
首先这题应该想到dp,dp[i][j]表示长度为i的字符串以节点j(节点j表示一个AC自动机上的一个状态节点)结尾的满足条件的字符串种类。以p个字符串建立AC自动机,标记危险节点。想像一个这样的问题:假设当前在u节点,u节点就表示当前长度为x的字符串,从u经过状态转移到其他节点,那么字符串的长度就变为x+1,遍历每个节点,如果该节点不是危险节点,则dp[i][j] = dp[i][j] + dp[i - 1][u],(j是u的一个子节点),这样从根结点(空串)转移m次到达的合法节点就是所求解。

注意:①数据很大,要用到大数模板
②字符有坑,会大于127,要用unsigned char 读入
③这里的AC自动机的模板还要改动一下,因为每个合法节点要进行状态转移,有些节点每个某个子节点,这样无法转移,所以将其的子节点指向它fail节点的对应的子节点(具体看代码注释)。

#include <iostream>
#include <cstdio>
#include <string.h>
#include <map>
#include <queue>
using namespace std;
typedef unsigned char uchar;
typedef long long ll;
const int mod = 100000;
const int N = 100 + 5;
const int MAXN = 1e2 + 10;
int trie[MAXN][55];
int tag[MAXN];
int fail[MAXN];
int cnt = 0;
int n, m, p;
int Hash[256], M;
void set_hash(int n, uchar s[]) {//这题字符有陷阱,要用unsigned char
    M = n; for (int i = 0; i < n; i++) Hash[s[i]] = i;
}
void insertWords(uchar *s)
{
    int root = 0;
    for (int i = 0; s[i]; i++) {
        int next = Hash[s[i]];
        if(!trie[root][next])
            trie[root][next] = ++cnt;
        root = trie[root][next];
    }
    tag[root]++;//标记危险节点,危险节点就是该节点是模式串的最后一个节点
}
void getFail()//一个节点的fail指针是指向 这个节点表示的字符串的最长后缀串的最后一个节点
{
    queue<int> q;
    for(int i = 0; i < M; i++) {
        if(trie[0][i]) {
            fail[trie[0][i]] = 0;
            q.push(trie[0][i]);
        }
    }
    while (!q.empty())
    {
        int now = q.front();
        q.pop();
        if(tag[fail[now]]) tag[now] = 1;//这很重要***如果这个串的后缀是危险的,那么这个串也是危险的
        for (int i = 0; i < M; i++) {
            int u = trie[now][i];
            if(!u) {
                trie[now][i] = trie[fail[now]][i];//多了一种状态转移的方式,就是如果该节点下没有字符i,就连到fail节点的字符i
                continue;                          //为什么呢?因为这样才能使得字符串长度增加一啊,不然走到这个节点长度不能增加就推不到长度m了
            }
            if(trie[now][i]) {
                fail[trie[now][i]] = trie[fail[now]][i];
                q.push(trie[now][i]);
            }
            else trie[now][i] = trie[fail[now]][i];
        }
    }
}

struct BigInteger{
    int A[25];
    enum{MOD = 10000};
    BigInteger(){memset(A, 0, sizeof(A)); A[0]=1;}
    void set(int x){memset(A, 0, sizeof(A)); A[0]=1; A[1]=x;}
    void print(){
        printf("%d", A[A[0]]);
        for (int i=A[0]-1; i>0; i--){
            if (A[i]==0){printf("0000"); continue;}
            for (int k=10; k*A[i]<MOD; k*=10) printf("0");
            printf("%d", A[i]);
        }
        printf("\n");
    }
    int& operator [] (int p) {return A[p];}
    const int& operator [] (int p) const {return A[p];}
    BigInteger operator + (const BigInteger& B){
        BigInteger C;
        C[0]=max(A[0], B[0]);
        for (int i=1; i<=C[0]; i++)
            C[i]+=A[i]+B[i], C[i+1]+=C[i]/MOD, C[i]%=MOD;
        if (C[C[0]+1] > 0) C[0]++;
        return C;
    }
    BigInteger operator * (const BigInteger& B){
        BigInteger C;
        C[0]=A[0]+B[0];
        for (int i=1; i<=A[0]; i++)
            for (int j=1; j<=B[0]; j++){
                C[i+j-1]+=A[i]*B[j], C[i+j]+=C[i+j-1]/MOD, C[i+j-1]%=MOD;
            }
        if (C[C[0]] == 0) C[0]--;
        return C;
    }
};

uchar ss[100];
int main()
{
    scanf("%d%d%d", &n, &m, &p);
    cin >> ss;
    set_hash(n, ss);
    while (p--) {
        cin >> ss;
        insertWords(ss);
    }
    getFail();
    BigInteger f[51][101];
    f[0][0].set(1);//空串初始化为1
    for (int i = 1; i <= m; i++) {
        for (int j = 0; j <= cnt; j++) {
            for (int k = 0; k < n; k++) {
                int u = trie[j][k];
                if(!tag[u]) f[i][u] = f[i][u] + f[i - 1][j];//不是危险节点就进行状态转移
            }
        }
    }
    BigInteger ans;
    for (int i = 0; i <= cnt; i++)
        if(!tag[i]) ans = ans + f[m][i];//对于不是危险节点的每个节点为结尾的 长度为m的字符串求和
    ans.print();
}

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!