codeoforces 271 D (后缀自动机 SAM)
题意:给你一个字符串S,然后定义每一个字符是”好的‘或是”坏的“,求S中包含不超过k个坏字符的不同字串的个数。
思路:这道题可以用哈希和SET过的,但是太慢啦,我觉得正解应该是SA或者SAM,下面介绍SAM的做法。
我们构造S的SAM,然后在SAM的每一个状态维护sum,表示该状态 下的子串包含多少个”不好“的字符,po表示该状态所表示的子串出现的位置中的一个(随便哪一个)。我们再将SAM进行拓扑排序,然后自顶下下遍历,我们遍历到一个状态p的时候,我们检查该状态的par节点的sum值,若sum已经超过k,则显然这个状态的所有子串均不满足要求,,我们不妨把p的sum设为k+1,然后继续遍历下一个节点,否则,我们设tmp=sum,mi为p表示的子串的最小长度(p->par->val+1),ma为p所表示的子串的最大长度(p->val),由小到大开始枚举每一个子串,即从p->po-mi+1到p->po-ma+1,若发现一个”好的“字符,则ans+=1,否则tmp++,若tmp超过了k,则设p->sum=k+1,跳过该状态,否则ans+=1,最后设sum=tmp。继续遍历下一个状态。最后我们输出ans即可。代码如下:
[cpp]
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define maxn 3010
#define Smaxn 26
using namespace std;
struct node
{
node *par,*go[Smaxn];
int po;
int sum;
int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
int vis[26];
void add(int c,int l,int po)
{
node *p=tail,*np=&que[tot++];
np->val=l;
np->po=po;
while(p&&p->go[c]==NULL)
p->go[c]=np,p=p->par;
if(p==NULL) np->par=root;
else
{
node *q=p->go[c];
if(p->val+1==q->val) np->par=q;
else
{
node *nq=&que[tot++];
*nq=*q;
nq->val=p->val+1;
np->par=q->par=nq;
while(p&&p->go[c]==q) p->go[c]=nq,p=p->par;
}
}
tail=np;
}
int c[maxn],len;
void init()
{
len=1;
tot=0;
memset(que,0,sizeof(que));
root=tail=&que[tot++];
}
void solve(int limit)
{
int i,j;
memset(c,0,sizeof(c));
for(i=0;i<tot;i++)
c[que[i].val]++;
for(i=1;i<len;i++)
c[i]+=c[i-1];
for(i=0;i<tot;i++)
top[--c[que[i].val]]=&que[i];
int sum=0;
for(i=1;i<tot;i++)
{
node *p=top[i];
if(p->par->sum>limit)
{
p->sum=limit+1;
continue;
}
int mi=p->par->val+1,ma=p->val,tmp=p->par->sum,po=p->po;
for(j=mi;j<=ma;j++)
{
if(vis[str[po-j+1]-'a'])
{
tmp++;
if(tmp>limit)
{
break;
}
else
sum++;
}
else
sum++;
}
p->sum=tmp;
}
printf("%d\n",sum);
}
int main()
{
//freopen("dd.txt","r",stdin);
scanf("%s",str);
int i,k,l=strlen(str);
init();
for(i=0;i<l;i++)
{
add(str[i]-'a',len++,i);
}
char tmp[26];
scanf("%s",tmp);
for(i=0;i<26;i++)
vis[i]=1-(tmp[i]-'0');
scanf("%d",&k);
solve(k);
return 0;
}
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define maxn 3010
#define Smaxn 26
using namespace std;
struct node
{
node *par,*go[Smaxn];
int po;
int sum;
int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
int vis[26];
void add(int c,int l,int po)
{
node *p=tail,*np=&que[tot++];
np->val=l;
np->po=po;
while(p&&p->go[c]==NULL)
p->go[c]=np,p=p->par;
if(p==NULL) np->par=root;
else
{
node *q=p->go[c];
if(p->val+1==q->val) np->pa
补充:软件开发 , C++ ,