「题解」「SDOI2016」模式字符串

一道毒瘤的 点分治 题目。

题目链接:Luogu P4075/BZOJ 4598/LibreOJ 2065/SDOI2016 R2D1T1。

题目

题意简述

一棵树,每个点上有一个大写字母,给定一个仅包含大写字母的字符串,问有多少对 $u,v$(有序数对)满足从 $u$ 到 $v$ 的简单路径拼成的字符串是模式串重复整数倍的结果。

输入有多组数据,数据组数用 $C$ 表示。

数据范围

变量数据范围
$C$$1\leq C\leq 10$
$n$$1\leq n\leq 10^6$
$m$$1\leq m\leq 10^6$

时空限制

题目时间限制空间限制
Luogu P4075$1\text{s}$$$125\text{MiB}$
BZOJ 4598$10\text{s}$$128\text{MiB}$
LibreOJ 2065$1\text{s}$$256\text{MiB}$
SDOI2016 R2D1T1$?\text{s}$$?\text{MiB}$

题解

思路

考虑点分治。

将整个复制得到的字符串切成三部分。左边部分,中间字符,右边部分。将中间字符放在根节点部位,然后用桶统计答案(左、右半部分恰好在哪个位置)。然后就是常规的点分治。

另外,如果字符串长度不足 $n$,应当补齐,用哈希处理。左边顺序处理,右边倒序处理。

下面讲一些细节。

细节

字符串 Hash

inline void fix(reg char *val){
    while(*val){
        *val=*val-'A'+1; //提前预处理好字符串,减少运算量
        ++val;
    }
    return;
}

inline void copy(reg char *str,const int& m,const int& n){ //补足字符串到 n 位
    for(reg int i=m+1;i<=n;++i)
        str[i]=str[i-m];
    return;
}

const ull base=31;
ull basep[MAXN];
ull prehash[MAXN],sufhash[MAXN];

inline void gethash(reg char *str,reg ull *hash,const int& n){ //计算 Hash 值并储存好
    for(reg int i=1;i<=n;++i)
        hash[i]=hash[i-1]*base+str[i];
    return;
}

点分治 的 Calc 部分

inline void DFS(int ID,const int& father,const int& dep,reg int &Maxdep,const char& mid,reg ull Hash){
    Maxdep=max(Maxdep,dep); //记录最大深度,方便后期还原
    Hash+=val[ID]*basep[dep-1]; //统计当前 Hash 值
    if(Hash==prehash[dep]){ //左半部分匹配
        ++Tpre[dep%m]; //统计
        if(mid==pre[dep%m+1]) //如果恰好中间字符同时匹配
            ans+=sumsuf[m-dep%m-1]; //直接统计答案
    }
    if(Hash==sufhash[dep]){ //右半部分匹配
        ++Tsuf[dep%m]; //统计
        if(mid==suf[dep%m+1]) //中间匹配
            ans+=sumpre[m-dep%m-1]; //统计答案
    }
    for(reg int i=head[ID];i;i=Next[i])
        if(!del[to[i]]&&to[i]!=father)
            DFS(to[i],ID,dep+1,Maxdep,mid,Hash); //递归
    return;
}

inline void Calc(const int& ID){
    if(val[ID]==pre[1])
        ++sumpre[0]; //没有左半部分的特殊情况
    if(val[ID]==suf[1])
        ++sumsuf[0]; //没有右半部分的特殊情况
    int Tag=0;
    for(reg int i=head[ID];i;i=Next[i])
        if(!del[to[i]]){
            int Max=0;
            DFS(to[i],ID,1,Max,val[ID],0); //统计
            Tag=max(Tag,Max); //记录最大深度,方便还原
            for(reg int i=0;i<=Max;++i)
                sumpre[i]+=Tpre[i],sumsuf[i]+=Tsuf[i],Tpre[i]=Tsuf[i]=0; //累加统计数据
        }
    for(reg int i=0;i<=Tag;++i)
        sumpre[i]=sumsuf[i]=0; //还原
    return;
}

快速读入

另外这道题卡常数,所以要快速读入。

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) //用 fread() 加快读入速度
static char buf[100000],*p1=buf,*p2=buf;

inline int read(void){ //读入整数,比较简单
    reg bool f=false;
    reg char ch=getchar();
    reg int res=0;
    while(ch<'0'||'9'<ch)f|=(ch=='-'),ch=getchar();
    while('0'<=ch&&ch<='9')res=10*res+ch-'0',ch=getchar();
    return f?-res:res;
}

inline void read(char* str){ //读入仅包含大写字母的字符串
    reg char ch=getchar();
    while(ch<'A'||'Z'<ch)ch=getchar();
    while('A'<=ch&&ch<='Z')(*str++)=ch,ch=getchar();
    return;
}

初始化

记得初始化不要漏掉某些部分。

inline void Init(void){ //初始化
    cnt=0; //邻接表
    ans=0; //答案
    memset(del,0,sizeof(del)); //点分治标记
    memset(head,0,sizeof(head)); //邻接表表头
    return;
}

至于为什么会有这么多细节,你们可以看看提交记录中我被卡了多少次。

代码

算法的渐近时间复杂度为 $\Theta(m+n+n\log_2^2n)$。很明显会超时,但我们要相信评测机,是可以过的(实际上是数据太水了)

函数传递参数用 const int& 更快,防止被卡。

#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef unsigned long long ull;
#define INF 0X3F3F3F3F
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
static char buf[100000],*p1=buf,*p2=buf;
inline int read(void){
    reg bool f=false;
    reg char ch=getchar();
    reg int res=0;
    while(ch<'0'||'9'<ch)f|=(ch=='-'),ch=getchar();
    while('0'<=ch&&ch<='9')res=10*res+ch-'0',ch=getchar();
    return f?-res:res;
}
inline void read(char* str){
    reg char ch=getchar();
    while(ch<'A'||'Z'<ch)ch=getchar();
    while('A'<=ch&&ch<='Z')(*str++)=ch,ch=getchar();
    return;
}

const int MAXN=1000000+5;
const int MAXM=1000000+5;

inline void Init(void);
inline void Read(void);
inline void Work(void);

int C;

int main(void){
    reg int C=read();
    while(C--){
        Init();
        Read();
        Work();
    }
    return 0;
}

int n,m;
int cnt,head[MAXN],to[MAXN<<1],Next[MAXN<<1];
char val[MAXN];
char pre[MAXN],suf[MAXN];
ull ans;

inline void Add_Edge(const int& u,const int& v){
    Next[++cnt]=head[u];
    to[cnt]=v;
    head[u]=cnt;
    return;
}

inline void Add_Tube(const int& u,const int& v){
    Add_Edge(u,v);
    Add_Edge(v,u);
    return;
}

inline void fix(reg char *val){
    while(*val){
        *val=*val-'A'+1;
        ++val;
    }
    return;
}

inline void copy(reg char *str,const int& m,const int& n){
    for(reg int i=m+1;i<=n;++i)
        str[i]=str[i-m];
    return;
}

const ull base=31;
ull basep[MAXN];
ull prehash[MAXN],sufhash[MAXN];

inline void gethash(reg char *str,reg ull *hash,const int& n){
    for(reg int i=1;i<=n;++i)
        hash[i]=hash[i-1]*base+str[i];
    return;
}

inline void Read(void){
    n=read(),m=read();
    read(val+1);
    fix(val+1);
    basep[0]=1;
    for(reg int i=1;i<=n;++i)
        basep[i]=basep[i-1]*base;
    for(reg int i=1;i<n;++i){
        static int u,v;
        u=read(),v=read();
        Add_Tube(u,v);
    }
    read(pre+1);
    fix(pre+1);
    for(reg int i=1;i<=m;++i)
        suf[i]=pre[m-i+1];
    copy(pre,m,n);
    copy(suf,m,n);
    gethash(pre,prehash,n);
    gethash(suf,sufhash,n);
    return;
}

bool del[MAXN];
int size[MAXN],Max[MAXN];

inline void GetRoot(int ID,const int& father,const int& sum,reg int& root){
    size[ID]=1,Max[ID]=0;
    for(reg int i=head[ID];i;i=Next[i])
        if(!del[to[i]]&&to[i]!=father){
            GetRoot(to[i],ID,sum,root);
            size[ID]+=size[to[i]];
            Max[ID]=max(Max[ID],size[to[i]]);
        }
    Max[ID]=max(Max[ID],sum-size[ID]);
    if(Max[ID]<=Max[root])
        root=ID;
    return;
}

int Tpre[MAXN],Tsuf[MAXN];
int sumpre[MAXN],sumsuf[MAXN];

inline void DFS(int ID,const int& father,const int& dep,reg int &Maxdep,const char& mid,reg ull Hash){
    Maxdep=max(Maxdep,dep);
    Hash+=val[ID]*basep[dep-1];
    if(Hash==prehash[dep]){
        ++Tpre[dep%m];
        if(mid==pre[dep%m+1])
            ans+=sumsuf[m-dep%m-1];
    }
    if(Hash==sufhash[dep]){
        ++Tsuf[dep%m];
        if(mid==suf[dep%m+1])
            ans+=sumpre[m-dep%m-1];
    }
    for(reg int i=head[ID];i;i=Next[i])
        if(!del[to[i]]&&to[i]!=father)
            DFS(to[i],ID,dep+1,Maxdep,mid,Hash);
    return;
}

inline void Calc(const int& ID){
    if(val[ID]==pre[1])
        ++sumpre[0];
    if(val[ID]==suf[1])
        ++sumsuf[0];
    int Tag=0;
    for(reg int i=head[ID];i;i=Next[i])
        if(!del[to[i]]){
            int Max=0;
            DFS(to[i],ID,1,Max,val[ID],0);
            Tag=max(Tag,Max);
            for(reg int i=0;i<=Max;++i)
                sumpre[i]+=Tpre[i],sumsuf[i]+=Tsuf[i],Tpre[i]=Tsuf[i]=0;
        }
    for(reg int i=0;i<=Tag;++i)
        sumpre[i]=sumsuf[i]=0;
    return;
}

inline void Solve(int ID,const int& sum){
    if(sum<m)
        return;
    int root=0;
    Max[root=0]=INF;
    GetRoot(ID,-1,sum,root);
    del[root]=true;
    Calc(root);
    for(reg int i=head[root];i;i=Next[i])
        if(!del[to[i]])
            Solve(to[i],size[to[i]]<size[root]?size[to[i]]:sum-size[root]);
    return;
}

inline void Work(void){
    Solve(1,n);
    printf("%llu\n",ans);
    return;
}

inline void Init(void){
    cnt=0;
    ans=0;
    memset(del,0,sizeof(del));
    memset(head,0,sizeof(head));
    return;
}