程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> C++入門知識 >> codeoforces 271 D (後綴自動機 SAM)

codeoforces 271 D (後綴自動機 SAM)

編輯:C++入門知識

題意:給你一個字符串S,然後定義每一個字符是”好的‘或是”壞的“,求S中包含不超過k個壞字符的不同字串的個數。

 

思路:這道題可以用哈希和SET過的,但是太慢啦,我覺得正解應該是SA或者SAM,下面介紹SAM的做法。

我們構造S的SAM,然後在SAM的每一個狀態維護sum,表示該狀態 下的子串包含多少個”不好“的字符,po表示該狀態所表示的子串出現的位置中的一個(隨便哪一個)。我們再將SAM進行拓撲排序,然後自頂下下遍歷,我們遍歷到一個狀態p的時候,我們檢查該狀態的par節點的sum值,若sum已經超過k,則顯然這個狀態的所有子串均不滿足要求,,我們不妨把p的sum設為k+1,然後繼續遍歷下一個節點,否則,我們設tmp=sum,mi為p表示的子串的最小長度(p->par->val+1),ma為p所表示的子串的最大長度(p->val),由小到大開始枚舉每一個子串,即從p->po-mi+1到p->po-ma+1,若發現一個”好的“字符,則ans+=1,否則tmp++,若tmp超過了k,則設p->sum=k+1,跳過該狀態,否則ans+=1,最後設sum=tmp。繼續遍歷下一個狀態。最後我們輸出ans即可。代碼如下:

[cpp] 
#include <iostream>  
#include <stdio.h>  
#include <string.h>  
#include <algorithm>  
#define maxn 3010  
#define Smaxn 26  
using namespace std; 
struct node 

    node *par,*go[Smaxn]; 
    int po; 
    int sum; 
    int val; 
}*root,*tail,que[maxn],*top[maxn]; 
int tot; 
char str[maxn>>1]; 
int vis[26]; 
void add(int c,int l,int po) 

    node *p=tail,*np=&que[tot++]; 
    np->val=l; 
    np->po=po; 
    while(p&&p->go[c]==NULL) 
    p->go[c]=np,p=p->par; 
    if(p==NULL) np->par=root; 
    else 
    { 
        node *q=p->go[c]; 
        if(p->val+1==q->val) np->par=q; 
        else 
        { 
            node *nq=&que[tot++]; 
            *nq=*q; 
            nq->val=p->val+1; 
            np->par=q->par=nq; 
            while(p&&p->go[c]==q) p->go[c]=nq,p=p->par; 
        } 
    } 
    tail=np; 

int c[maxn],len; 
void init() 

    len=1; 
    tot=0; 
    memset(que,0,sizeof(que)); 
    root=tail=&que[tot++]; 

void solve(int limit) 

    int i,j; 
    memset(c,0,sizeof(c)); 
    for(i=0;i<tot;i++) 
    c[que[i].val]++; 
    for(i=1;i<len;i++) 
    c[i]+=c[i-1]; 
    for(i=0;i<tot;i++) 
    top[--c[que[i].val]]=&que[i]; 
    int sum=0; 
    for(i=1;i<tot;i++) 
    { 
        node *p=top[i]; 
        if(p->par->sum>limit) 
        { 
            p->sum=limit+1; 
            continue; 
        } 
        int mi=p->par->val+1,ma=p->val,tmp=p->par->sum,po=p->po; 
        for(j=mi;j<=ma;j++) 
        { 
            if(vis[str[po-j+1]-'a']) 
            { 
                tmp++; 
                if(tmp>limit) 
                { 
                    break; 
                } 
                else 
                sum++; 
            } 
            else 
            sum++; 
        } 
        p->sum=tmp; 
    } 
    printf("%d\n",sum); 

int main() 

    //freopen("dd.txt","r",stdin);  
    scanf("%s",str); 
    int i,k,l=strlen(str); 
    init(); 
    for(i=0;i<l;i++) 
    { 
        add(str[i]-'a',len++,i); 
    } 
    char tmp[26]; 
    scanf("%s",tmp); 
    for(i=0;i<26;i++) 
    vis[i]=1-(tmp[i]-'0'); 
    scanf("%d",&k); 
    solve(k); 
    return 0; 

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define maxn 3010
#define Smaxn 26
using namespace std;
struct node
{
    node *par,*go[Smaxn];
    int po;
    int sum;
    int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
int vis[26];
void add(int c,int l,int po)
{
    node *p=tail,*np=&que[tot++];
    np->val=l;
    np->po=po;
    while(p&&p->go[c]==NULL)
    p->go[c]=np,p=p->par;
    if(p==NULL) np->par=root;
    else
    {
        node *q=p->go[c];
        if(p->val+1==q->val) np->par=q;
        else
        {
            node *nq=&que[tot++];
            *nq=*q;
            nq->val=p->val+1;
            np->par=q->par=nq;
            while(p&&p->go[c]==q) p->go[c]=nq,p=p->par;
        }
    }
    tail=np;
}
int c[maxn],len;
void init()
{
    len=1;
    tot=0;
    memset(que,0,sizeof(que));
    root=tail=&que[tot++];
}
void solve(int limit)
{
    int i,j;
    memset(c,0,sizeof(c));
    for(i=0;i<tot;i++)
    c[que[i].val]++;
    for(i=1;i<len;i++)
    c[i]+=c[i-1];
    for(i=0;i<tot;i++)
    top[--c[que[i].val]]=&que[i];
    int sum=0;
    for(i=1;i<tot;i++)
    {
        node *p=top[i];
        if(p->par->sum>limit)
        {
            p->sum=limit+1;
            continue;
        }
        int mi=p->par->val+1,ma=p->val,tmp=p->par->sum,po=p->po;
        for(j=mi;j<=ma;j++)
        {
            if(vis[str[po-j+1]-'a'])
            {
                tmp++;
                if(tmp>limit)
                {
                    break;
                }
                else
                sum++;
            }
            else
            sum++;
        }
        p->sum=tmp;
    }
    printf("%d\n",sum);
}
int main()
{
    //freopen("dd.txt","r",stdin);
    scanf("%s",str);
    int i,k,l=strlen(str);
    init();
    for(i=0;i<l;i++)
    {
        add(str[i]-'a',len++,i);
    }
    char tmp[26];
    scanf("%s",tmp);
    for(i=0;i<26;i++)
    vis[i]=1-(tmp[i]-'0');
    scanf("%d",&k);
    solve(k);
    return 0;
}

 

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved