当前位置:编程学习 > C/C++ >>

codeoforces 271 D (后缀自动机 SAM)

题意:给你一个字符串S,然后定义每一个字符是”好的‘或是”坏的“,求S中包含不超过k个坏字符的不同字串的个数。

 

思路:这道题可以用哈希和SET过的,但是太慢啦,我觉得正解应该是SA或者SAM,下面介绍SAM的做法。

我们构造S的SAM,然后在SAM的每一个状态维护sum,表示该状态 下的子串包含多少个”不好“的字符,po表示该状态所表示的子串出现的位置中的一个(随便哪一个)。我们再将SAM进行拓扑排序,然后自顶下下遍历,我们遍历到一个状态p的时候,我们检查该状态的par节点的sum值,若sum已经超过k,则显然这个状态的所有子串均不满足要求,,我们不妨把p的sum设为k+1,然后继续遍历下一个节点,否则,我们设tmp=sum,mi为p表示的子串的最小长度(p->par->val+1),ma为p所表示的子串的最大长度(p->val),由小到大开始枚举每一个子串,即从p->po-mi+1到p->po-ma+1,若发现一个”好的“字符,则ans+=1,否则tmp++,若tmp超过了k,则设p->sum=k+1,跳过该状态,否则ans+=1,最后设sum=tmp。继续遍历下一个状态。最后我们输出ans即可。代码如下:

[cpp] 
#include <iostream>  
#include <stdio.h>  
#include <string.h>  
#include <algorithm>  
#define maxn 3010  
#define Smaxn 26  
using namespace std; 
struct node 

    node *par,*go[Smaxn]; 
    int po; 
    int sum; 
    int val; 
}*root,*tail,que[maxn],*top[maxn]; 
int tot; 
char str[maxn>>1]; 
int vis[26]; 
void add(int c,int l,int po) 

    node *p=tail,*np=&que[tot++]; 
    np->val=l; 
    np->po=po; 
    while(p&&p->go[c]==NULL) 
    p->go[c]=np,p=p->par; 
    if(p==NULL) np->par=root; 
    else 
    { 
        node *q=p->go[c]; 
        if(p->val+1==q->val) np->par=q; 
        else 
        { 
            node *nq=&que[tot++]; 
            *nq=*q; 
            nq->val=p->val+1; 
            np->par=q->par=nq; 
            while(p&&p->go[c]==q) p->go[c]=nq,p=p->par; 
        } 
    } 
    tail=np; 

int c[maxn],len; 
void init() 

    len=1; 
    tot=0; 
    memset(que,0,sizeof(que)); 
    root=tail=&que[tot++]; 

void solve(int limit) 

    int i,j; 
    memset(c,0,sizeof(c)); 
    for(i=0;i<tot;i++) 
    c[que[i].val]++; 
    for(i=1;i<len;i++) 
    c[i]+=c[i-1]; 
    for(i=0;i<tot;i++) 
    top[--c[que[i].val]]=&que[i]; 
    int sum=0; 
    for(i=1;i<tot;i++) 
    { 
        node *p=top[i]; 
        if(p->par->sum>limit) 
        { 
            p->sum=limit+1; 
            continue; 
        } 
        int mi=p->par->val+1,ma=p->val,tmp=p->par->sum,po=p->po; 
        for(j=mi;j<=ma;j++) 
        { 
            if(vis[str[po-j+1]-'a']) 
            { 
                tmp++; 
                if(tmp>limit) 
                { 
                    break; 
                } 
                else 
                sum++; 
            } 
            else 
            sum++; 
        } 
        p->sum=tmp; 
    } 
    printf("%d\n",sum); 

int main() 

    //freopen("dd.txt","r",stdin);  
    scanf("%s",str); 
    int i,k,l=strlen(str); 
    init(); 
    for(i=0;i<l;i++) 
    { 
        add(str[i]-'a',len++,i); 
    } 
    char tmp[26]; 
    scanf("%s",tmp); 
    for(i=0;i<26;i++) 
    vis[i]=1-(tmp[i]-'0'); 
    scanf("%d",&k); 
    solve(k); 
    return 0; 

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define maxn 3010
#define Smaxn 26
using namespace std;
struct node
{
    node *par,*go[Smaxn];
    int po;
    int sum;
    int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
int vis[26];
void add(int c,int l,int po)
{
    node *p=tail,*np=&que[tot++];
    np->val=l;
    np->po=po;
    while(p&&p->go[c]==NULL)
    p->go[c]=np,p=p->par;
    if(p==NULL) np->par=root;
    else
    {
        node *q=p->go[c];
        if(p->val+1==q->val) np->pa

补充:软件开发 , C++ ,
CopyRight © 2022 站长资源库 编程知识问答 zzzyk.com All Rights Reserved
部分文章来自网络,