程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> C++入門知識 >> hdu 4416 Good Article Good sentence (後綴自動機 SAM)

hdu 4416 Good Article Good sentence (後綴自動機 SAM)

編輯:C++入門知識

2012杭州網絡賽的一道題,後綴數組後綴自動機都行吧。

題目大意:給一個字符串S和一系列字符串T1~Tn,問在S中有多少個不同子串滿足它不是T1~Tn中任意一個字符串的子串。

 

思路:我們先構造S的後綴自動機,然後將每一個Ti在S的SAM上做匹配,類似於LCS,在S中的每一個狀態記錄一個變量deep,表示T1~Tn,在該狀態能匹配的最大長度是多少,將每一個Ti匹配完之後,我們將S的SAM做拓撲排序,自底向上更新每個狀態的deep,同時計算在該狀態上有多少個子串滿足題目要求。具體步驟如下:

1:對於當前狀態,設為p,設p的par為q,則更新q->deep為q->deep和p->deep中的較大值。

2:若p->deep<p->val,則表示在狀態p中,長度為p->deep+1~p->val的子串不是T1~Tn中任意字符串的子串,所以答案加上p->val-p->deep。否則表示狀態p中所有字串均不滿足要求,跳過即可。

(注意若p->deep==0,表示狀態p中所有的子串均滿足題目要求,但是答案不是加上p->val-0,而是加上 p->val-p->par->val,這表示狀態p中的字符串個數,所以對於p->deep==0要特殊處理)

最後輸出答案即可。

代碼如下:

[cpp]
#include <iostream>  
#include <string.h>  
#include <stdio.h>  
#define maxn 200010  
#define Smaxn 26  
using namespace std; 
struct node 

    node *par,*go[Smaxn]; 
    int deep; 
    int val; 
}*root,*tail,que[maxn],*top[maxn]; 
int tot; 
char str[maxn>>1]; 
void add(int c,int l) 

    node *p=tail,*np=&que[tot++]; 
    np->val=l; 
    while(p&&p->go[c]==NULL) 
    p->go[c]=np,p=p->par; 
    if(p==NULL) np->par=root; 
    else 
    { 
        node *q=p->go[c]; 
        if(p->val+1==q->val) np->par=q; 
        else 
        { 
            node *nq=&que[tot++]; 
            *nq=*q; 
            nq->val=p->val+1; 
            np->par=q->par=nq; 
            while(p&&p->go[c]==q) p->go[c]=nq,p=p->par; 
        } 
    } 
    tail=np; 

int c[maxn],len; 
void init(int n) 

    int i; 
    for(i=0;i<=n;i++) 
    { 
        que[i].deep=que[i].val=0; 
        que[i].par=NULL; 
        memset(que[i].go,0,sizeof(que[i].go)); 
    } 
    tot=0; 
    len=1; 
    root=tail=&que[tot++]; 

int max(int a,int b) 

    return a>b?a:b; 

void solve(int q) 

    memset(c,0,sizeof(c)); 
    int i; 
    for(i=0;i<tot;i++) 
    c[que[i].val]++; 
    for(i=1;i<len;i++) 
    c[i]+=c[i-1]; 
    for(i=0;i<tot;i++) 
    top[--c[que[i].val]]=&que[i]; 
    while(q--) 
    { 
        node *p=root; 
        scanf("%s",str); 
        int l=strlen(str),tmp=0; 
        for(i=0;i<l;i++) 
        { 
             int x=str[i]-'a'; 
             if(p->go[x]) 
             { 
                 tmp++; 
                 p=p->go[x]; 
                 p->deep=max(p->deep,tmp); 
             } 
             else 
             { 
                 while(p&&p->go[x]==0) 
                 { 
                     p=p->par; 
                 } 
                 if(p) 
                 { 
                     tmp=p->val+1; 
                     p=p->go[x]; 
                     p->deep=max(tmp,p->deep); 
                 } 
                 else 
                 { 
                     tmp=0; 
                     p=root; 
                 } 
             } 
        } 
    } 
    long long sum=0; 
    for(i=tot-1;i>0;i--) 
    { 
        node *q=top[i]; 
        if(q->deep>0) 
        { 
            q->par->deep=max(q->par->deep,q->deep); 
            if(q->deep<q->val) 
            { 
                sum+=q->val-q->deep; 
            } 
        } 
        else 
        { 
            sum+=q->val-q->par->val; 
        } 
    } 
    printf("%I64d\n",sum); 

int main() 

    freopen("dd.txt","r",stdin); 
    int ncase,time=0; 
    scanf("%d",&ncase); 
    while(ncase--) 
    { 
        printf("Case %d: ",++time); 
        int n; 
        scanf("%d",&n); 
        scanf("%s",str); 
        int i,l=strlen(str); 
        init(l*2); 
        for(i=0;i<l;i++) 
        add(str[i]-'a',len++); 
        solve(n); 
    } 
    return 0; 

#include <iostream>
#include <string.h>
#include <stdio.h>
#define maxn 200010
#define Smaxn 26
using namespace std;
struct node
{
    node *par,*go[Smaxn];
    int deep;
    int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
void add(int c,int l)
{
    node *p=tail,*np=&que[tot++];
    np->val=l;
    while(p&&p->go[c]==NULL)
    p->go[c]=np,p=p->par;
    if(p==NULL) np->par=root;
    else
    {
        node *q=p->go[c];
        if(p->val+1==q->val) np->par=q;
        else
        {
            node *nq=&que[tot++];
            *nq=*q;
            nq->val=p->val+1;
            np->par=q->par=nq;
            while(p&&p->go[c]==q) p->go[c]=nq,p=p->par;
        }
    }
    tail=np;
}
int c[maxn],len;
void init(int n)
{
    int i;
    for(i=0;i<=n;i++)
    {
        que[i].deep=que[i].val=0;
        que[i].par=NULL;
        memset(que[i].go,0,sizeof(que[i].go));
    }
    tot=0;
    len=1;
    root=tail=&que[tot++];
}
int max(int a,int b)
{
    return a>b?a:b;
}
void solve(int q)
{
    memset(c,0,sizeof(c));
    int i;
    for(i=0;i<tot;i++)
    c[que[i].val]++;
    for(i=1;i<len;i++)
    c[i]+=c[i-1];
    for(i=0;i<tot;i++)
    top[--c[que[i].val]]=&que[i];
    while(q--)
    {
        node *p=root;
        scanf("%s",str);
        int l=strlen(str),tmp=0;
        for(i=0;i<l;i++)
        {
             int x=str[i]-'a';
             if(p->go[x])
             {
                 tmp++;
                 p=p->go[x];
                 p->deep=max(p->deep,tmp);
             }
             else
             {
                 while(p&&p->go[x]==0)
                 {
                     p=p->par;
                 }
                 if(p)
                 {
                     tmp=p->val+1;
                     p=p->go[x];
                     p->deep=max(tmp,p->deep);
                 }
                 else
                 {
                     tmp=0;
                     p=root;
                 }
             }
        }
    }
    long long sum=0;
    for(i=tot-1;i>0;i--)
    {
        node *q=top[i];
        if(q->deep>0)
        {
            q->par->deep=max(q->par->deep,q->deep);
            if(q->deep<q->val)
            {
                sum+=q->val-q->deep;
            }
        }
        else
        {
            sum+=q->val-q->par->val;
        }
    }
    printf("%I64d\n",sum);
}
int main()
{
    freopen("dd.txt","r",stdin);
    int ncase,time=0;
    scanf("%d",&ncase);
    while(ncase--)
    {
        printf("Case %d: ",++time);
        int n;
        scanf("%d",&n);
        scanf("%s",str);
        int i,l=strlen(str);
        init(l*2);
        for(i=0;i<l;i++)
        add(str[i]-'a',len++);
        solve(n);
    }
    return 0;
}
 

 

 

 

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved