Codeforces 666E 题解

题意

给你一个串 $s$,和 $m$ 个串 $t_i$,要求你回答 $q$ 个询问,每个询问有四个参数 $l,r,pl,pr$,要求你在 $t_{[l,r]}$ 中找出 $s_{[pl,pr]}$ 的出现次数。

题解

类似于这种需要查询出现次数的,都应该联想到SAM。而我们学习SAM时,有一道类似的题,不过 $t$ 串只有一个。那道题我们是将 $t$ 串插入SAM,并将 $s_{[pl,pr]}$ 在SAM上找到,通过 $topo$ + 跳 $fail$ 来算出出现次数。

这道题的区别在于 $t$ 串有多个,且询问时需要在指定范围内查找。我们先解决一个简单版问题——假设 $l=1,r=m$。

那么就和那个简单题是一样的,不过我们要将所有串插入SAM。这里既可以拼成一个整串,然后用不同分隔符分开;也可以每次插入完一个串后将las跳回到root,在插入新字符时判断一下有没有现存的合适节点即可。然后 $topo$ 一下后,累加一个次数,最后只需要返回一个SAM中最短的包含了 $s_{[pl,pr]}$ 的节点上的次数即可。

那么没有了这个假设该怎么办呢?我们暴力的想,其实可以在每个SAM节点上存 $m$ 个次数,分别代表当前SAM节点代表的串在 $t_1,t_2,…,t_m$ 中的出现次数,$topo$ 后分别累加即可。

这样肯定是不能过的,所以换用权值线段树,然后将累加操作换为将当前节点merge到fail指向的节点。

所以就做完啦!

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
struct node{
    int ch[26],fa,len;
}d[555555];int las=1,tot=1;

struct segnode{
    int mx=0,mxid;
    segnode *lt,*rt;
    static segnode* niu();
    inline void pullup(){
        mx=0,mxid=0;
        if(lt){
            if(lt->mx>mx){mx=lt->mx,mxid=lt->mxid;}
            else if(lt->mx==mx&&lt->mxid<mxid){mxid=lt->mxid;}
        }
        if(rt){
            if(rt->mx>mx){mx=rt->mx,mxid=rt->mxid;}
            else if(rt->mx==mx&&rt->mxid<mxid){mxid=rt->mxid;}
        }
    }
    void add(int l,int r,int q){
        if(l==r){mxid=l;++mx;return;}
        int mid=(l+r)>>1;
        if(q<=mid){
            if(!lt)lt=niu();
            lt->add(l,mid,q);
        }else{
            if(!rt)rt=niu();
            rt->add(mid+1,r,q);
        }
        pullup();
    }
    segnode* merge(int l,int r,segnode *x, segnode *y){
        if(!x)return y;
        if(!y)return x;
        segnode *res=niu();
        if(l==r){
            res->mxid=x->mxid;res->mx=x->mx+y->mx;
            return res;
        }
        int mid=(l+r)>>1;
        res->lt=merge(l,mid,x->lt,y->lt);
        res->rt=merge(mid+1,r,x->rt,y->rt);
        res->pullup();
        return res;
    }
    pair<int,int> query(int l,int r,int ql,int qr){
        if(ql<=l&&r<=qr){return make_pair(mx,-mxid);}
        pair<int,int> res;
        int mid=(l+r)>>1;
        if(ql<=mid&&lt)res=max(res,lt->query(l,mid,ql,qr));
        if(qr>mid&&rt)res=max(res,rt->query(mid+1,r,ql,qr));
        return res;
    }
}nd[555555*30],*rt[555555];int allo=0;
inline segnode* segnode::niu(){return &nd[++allo];}
char s[505050];
int m;
char t[50505];
int pos[505050];
inline void add(char c,int i){
    if(d[las].ch[c]&&d[las].len+1==d[d[las].ch[c]].len){ // 存在和想要新增节点等价的节点则直接使用
        las=d[las].ch[c];
        if(!rt[las])rt[las]=segnode::niu();
        rt[las]->add(1,m,i);
        return;
    }
    int lastnode=las;int newnode=las=++tot;
    if(!rt[las])rt[las]=segnode::niu();
    rt[las]->add(1,m,i);
    d[newnode].len=d[lastnode].len+1;
    int p;
    for(p=lastnode;p&&!d[p].ch[c];p=d[p].fa)d[p].ch[c]=newnode;
    if(!p){
        d[newnode].fa=1;
        return;
    }
    int nearc=d[p].ch[c];
    if(d[nearc].len==d[p].len+1){
        d[newnode].fa=nearc;
        return;
    }
    int midnode=++tot;
    d[midnode]=d[nearc];d[midnode].len=d[p].len+1;
    d[newnode].fa=d[nearc].fa=midnode;
    for(int pp=p;pp&&d[pp].ch[c]==nearc;pp=d[pp].fa)d[pp].ch[c]=midnode;
}
int b[555555],topo[555555];
int maxlen=0;
inline void topoprep(){
    for(int i=1;i<=tot;++i)b[d[i].len]++;
    for(int i=1;i<=maxlen;++i)b[i]+=b[i-1];
    for(int i=1;i<=tot;++i)topo[b[d[i].len]--]=i;
}
int lenn[555555];
int main(){
    scanf("%s",s+1);
    scanf("%d",&m);
    for(int i=1;i<=m;++i){
        scanf("%s",t+1);
        int len=strlen(t+1);las=1;
        maxlen=max(maxlen,len);
        for(int j=1;j<=len;++j){
            add(t[j]-'a',i);    
            // cout<<las<<" ";   
        }
    }
    // cout<<endl;
    topoprep();
    for(int i=tot;i>=2;--i){ // 按topo序merge一下
        rt[d[topo[i]].fa]=rt[d[topo[i]].fa]->merge(1,m,rt[d[topo[i]].fa],rt[topo[i]]);
    }
    int slen=strlen(s+1);
    int cur=1,len=0;
    for(int i=1;i<=slen;++i){
        while(cur&&!d[cur].ch[s[i]-'a'])cur=d[cur].fa,len=d[cur].len;
        if(d[cur].ch[s[i]-'a']){
            ++len;
            cur=d[cur].ch[s[i]-'a'];
        }else{
            len=0;cur=1;
        }
        pos[i]=cur;lenn[i]=len;
    }
    int Q;
    scanf("%d",&Q);
    for(int i=1,l,r,pl,pr;i<=Q;++i){
        scanf("%d%d%d%d",&l,&r,&pl,&pr);
        int len=pr-pl+1;
        if(lenn[pr]<len){
            printf("%d 0\n",l);
            continue;
        }
        int cur=pos[pr];
        while(d[d[cur].fa].len>=len)cur=d[cur].fa;
        auto res = rt[cur]->query(1,m,l,r);
        if(res.first==0)printf("%d 0\n",l);
        else 
        printf("%d %d\n",-res.second,res.first);
    }
}