BZOJ 4543 题解

题面

给出一棵 $n$ 个点的无根树,请在这棵树上选三个互不相同的节点,使得这个三个节点两两之间距离相等,输出方案数即可。

输入格式

每个测试点包含多组测试数据。

对于每组数据,第一行一个整数 $n$,表示节点数。

接下来 $n-1$ 行每行两个整数 $x,y$,描述一条边。

输入以一行一个 $0$ 作为结束。

输出格式

对于每组数据,输出一个整数表示答案。

样例输入

7
1 2
5 7
2 5
2 3
5 6
4 5
7
1 2
2 3
1 4
4 5
1 6
6 7
0

样例输出

5
2

数据范围

Subtask1 20pts: $sum n \leq 500$

Subtask2 35pts: $sum n \leq 5000$

Subtask3 45pts: $sum n \leq 50000$

题解

DP部分(此部分不是这篇题解的重点)

在每个子树中存储一些信息

  • f[i] 表示此子树中距离根节点距离为i的节点有多少个
  • g[i] 表示此子树内的节点再加入子树外 一个 距离根节点距离为i的 节点就能组成满足题目条件的对数

最后我们用一个全局的 ans 来表示全局满足这样条件的三个节点有多少对

所以在遍历子树后,可以有以下式子来转移至父节点:

nd[u].g[0]+=nd[v].g[1]; // 和下面的式子一样
for(int i=1;i<=nd[v].mxdep;++i){
    ans += nd[u].f[i] * nd[v].g[i+1] // 子树中选取两个 子树外选取一个
         + nd[v].f[i-1] * nd[u].g[i]; // 子树中选取一个 子树外选取两个
    nd[u].g[i] += nd[v].g[i+1] // 继承子树
                + nd[v].f[i-1] * nd[u].f[i]; // 子树中选一个 子树外选一个
    nd[u].f[i] += nd[v].f[i-1]; // 显然
}
ans+=nd[u].g[0]; // 将上面漏加的加上 (v子树中两个节点到v的距离与u-v距离相等)

可以拿到 55pts

长链剖分优化DP(重点部分)

这里可以发现,当加入第一个子树时,其实没有根节点中原有信息的干扰,而是相当于直接将子树信息拷贝到根节点。

所以使用长链剖分。类似于重链剖分,我们将最长/最重的那个子树直接移到根节点 $u$,这里运用指针可以将这个过程优化到O(1),相当于让子树直接在根节点 $f$ $g$ 数组的偏移量上进行计算;而对于其他子树则直接暴力合并,这样总体的复杂度为$\Theta (n)$。

可以看到第一个子树的f数组相对于根节点向右偏移了1位,而g数组则向左偏移了1位。

所以可以用指针完成这个偏移操作。

具体实现:f[id][i] 等同于原来编号id节点上的 f[i]

空间开辟

int *f[50050],*g[50050];
int mem[20000000],*alloc=mem; //2e7 可根据实测值进行调整

f、g 的开辟(用于1号节点及后续子树暴力合并)

g[v]=alloc+nd[v].mxdep+5; //偏移到中间位置方便后续再向左偏移
alloc+=nd[v].mxdep*2+10; //多预留一些空间免得出问题
f[v]=alloc;
alloc+=nd[v].mxdep*2+10;

f、g 的偏移

if(son){
    f[son]=f[u]+1;
    g[son]=g[u]-1;
    dfs2(son,u);
}

总体代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n;
struct edge{
    int to,nxt;
}ed[50050*2];int edcnt=0;
int hd[50050];
struct node{
    int dep,mxdep,son;
}nd[50050];
inline void adde(int u,int v){
    ed[++edcnt].to=v;
    ed[edcnt].nxt=hd[u];
    hd[u]=edcnt;
}
ll ans=0;
void dfs(int u,int fa){
    nd[u].dep=nd[fa].dep+1;
    nd[u].mxdep=1;
    for(int id=hd[u];id;id=ed[id].nxt){
        int v=ed[id].to;
        if(v==fa)continue;
        dfs(v,u);
        if(nd[v].mxdep+1>nd[u].mxdep){
            nd[u].mxdep=nd[v].mxdep+1;
            nd[u].son=v;
        }
    }
}
int *f[50050],*g[50050];
int mem[20000000],*alloc=mem;
void dfs2(int u,int fa){
    f[u][0]=1;
    int son = nd[u].son;
    if(son){
        f[son]=f[u]+1;
        g[son]=g[u]-1;
        dfs2(son,u);
    }
    for(int id=hd[u];id;id=ed[id].nxt){
        int v=ed[id].to;
        if(v==fa||v==son)continue;
        g[v]=alloc+nd[v].mxdep+5;
        alloc+=nd[v].mxdep*2+10;
        f[v]=alloc;
        alloc+=nd[v].mxdep*2+10;
        dfs2(v,u);
        g[u][0]+=g[v][1];
        for(int i=1;i<=nd[v].mxdep;++i){
            ans+=(long long)f[u][i]*g[v][i+1]+(long long)f[v][i-1]*g[u][i];
            g[u][i]+=g[v][i+1]+f[v][i-1]*f[u][i];
            f[u][i]+=f[v][i-1];
        }
        alloc-=nd[v].mxdep*4+20;
        memset(alloc,0,nd[v].mxdep*16+80);
    }
    ans+=g[u][0];
}
int main(){
    scanf("%d",&n);
    while(n){
        memset(mem,0,4*(alloc-mem));
        alloc=mem;
        memset(nd,0,sizeof nd);
        edcnt=0;
        memset(hd,0,sizeof hd);
        ans=0;
        for(int i=1,u,v;i<n;++i){
            scanf("%d%d",&u,&v);
            adde(u,v);
            adde(v,u);
        }
        dfs(1,1);
        g[1]=alloc+nd[1].mxdep+5;
        alloc+=2*nd[1].mxdep+10;
        f[1]=alloc;
        alloc+=2*nd[1].mxdep+10;
        dfs2(1,1);
        printf("%lld\n",ans);
        scanf("%d",&n);
    }
}

应该是g数组保持正序的写法中较快、空间占用较小的了(320ms 5800kb)

总结

长链剖分跑起来很快,通过奇妙的长链转移和短链暴力,只需要修改少量代码就能从 $\Theta (n^2)$ 变成 $\Theta (n)$

指针在思路清晰的情况下用起来很爽

  • 可以让你几乎完全跳到内存的层面上进行思考
  • 可以方便地回收空间以达到减少内存占用