题面
给出一棵 $n$ 个点的无根树,请在这棵树上选三个互不相同的节点,使得这个三个节点两两之间距离相等,输出方案数即可。
输入格式
每个测试点包含多组测试数据。
对于每组数据,第一行一个整数 $n$,表示节点数。
接下来 $n-1$ 行每行两个整数 $x,y$,描述一条边。
输入以一行一个 $0$ 作为结束。
输出格式
对于每组数据,输出一个整数表示答案。
样例输入
7
1 2
5 7
2 5
2 3
5 6
4 5
7
1 2
2 3
1 4
4 5
1 6
6 7
0
样例输出
5
2
数据范围
Subtask1 20pts: $sum n \leq 500$
Subtask2 35pts: $sum n \leq 5000$
Subtask3 45pts: $sum n \leq 50000$
题解
DP部分(此部分不是这篇题解的重点)
在每个子树中存储一些信息
- f[i] 表示此子树中距离根节点距离为i的节点有多少个
- g[i] 表示此子树内的节点再加入子树外 一个 距离根节点距离为i的 节点就能组成满足题目条件的对数
最后我们用一个全局的 ans 来表示全局满足这样条件的三个节点有多少对
所以在遍历子树后,可以有以下式子来转移至父节点:
nd[u].g[0]+=nd[v].g[1]; // 和下面的式子一样
for(int i=1;i<=nd[v].mxdep;++i){
ans += nd[u].f[i] * nd[v].g[i+1] // 子树中选取两个 子树外选取一个
+ nd[v].f[i-1] * nd[u].g[i]; // 子树中选取一个 子树外选取两个
nd[u].g[i] += nd[v].g[i+1] // 继承子树
+ nd[v].f[i-1] * nd[u].f[i]; // 子树中选一个 子树外选一个
nd[u].f[i] += nd[v].f[i-1]; // 显然
}
ans+=nd[u].g[0]; // 将上面漏加的加上 (v子树中两个节点到v的距离与u-v距离相等)
可以拿到 55pts
长链剖分优化DP(重点部分)
这里可以发现,当加入第一个子树时,其实没有根节点中原有信息的干扰,而是相当于直接将子树信息拷贝到根节点。
所以使用长链剖分。类似于重链剖分,我们将最长/最重的那个子树直接移到根节点 $u$,这里运用指针可以将这个过程优化到O(1),相当于让子树直接在根节点 $f$ $g$ 数组的偏移量上进行计算;而对于其他子树则直接暴力合并,这样总体的复杂度为$\Theta (n)$。
可以看到第一个子树的f数组相对于根节点向右偏移了1位,而g数组则向左偏移了1位。
所以可以用指针完成这个偏移操作。
具体实现:f[id][i] 等同于原来编号id节点上的 f[i]
空间开辟
int *f[50050],*g[50050];
int mem[20000000],*alloc=mem; //2e7 可根据实测值进行调整
f、g 的开辟(用于1号节点及后续子树暴力合并)
g[v]=alloc+nd[v].mxdep+5; //偏移到中间位置方便后续再向左偏移
alloc+=nd[v].mxdep*2+10; //多预留一些空间免得出问题
f[v]=alloc;
alloc+=nd[v].mxdep*2+10;
f、g 的偏移
if(son){
f[son]=f[u]+1;
g[son]=g[u]-1;
dfs2(son,u);
}
总体代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n;
struct edge{
int to,nxt;
}ed[50050*2];int edcnt=0;
int hd[50050];
struct node{
int dep,mxdep,son;
}nd[50050];
inline void adde(int u,int v){
ed[++edcnt].to=v;
ed[edcnt].nxt=hd[u];
hd[u]=edcnt;
}
ll ans=0;
void dfs(int u,int fa){
nd[u].dep=nd[fa].dep+1;
nd[u].mxdep=1;
for(int id=hd[u];id;id=ed[id].nxt){
int v=ed[id].to;
if(v==fa)continue;
dfs(v,u);
if(nd[v].mxdep+1>nd[u].mxdep){
nd[u].mxdep=nd[v].mxdep+1;
nd[u].son=v;
}
}
}
int *f[50050],*g[50050];
int mem[20000000],*alloc=mem;
void dfs2(int u,int fa){
f[u][0]=1;
int son = nd[u].son;
if(son){
f[son]=f[u]+1;
g[son]=g[u]-1;
dfs2(son,u);
}
for(int id=hd[u];id;id=ed[id].nxt){
int v=ed[id].to;
if(v==fa||v==son)continue;
g[v]=alloc+nd[v].mxdep+5;
alloc+=nd[v].mxdep*2+10;
f[v]=alloc;
alloc+=nd[v].mxdep*2+10;
dfs2(v,u);
g[u][0]+=g[v][1];
for(int i=1;i<=nd[v].mxdep;++i){
ans+=(long long)f[u][i]*g[v][i+1]+(long long)f[v][i-1]*g[u][i];
g[u][i]+=g[v][i+1]+f[v][i-1]*f[u][i];
f[u][i]+=f[v][i-1];
}
alloc-=nd[v].mxdep*4+20;
memset(alloc,0,nd[v].mxdep*16+80);
}
ans+=g[u][0];
}
int main(){
scanf("%d",&n);
while(n){
memset(mem,0,4*(alloc-mem));
alloc=mem;
memset(nd,0,sizeof nd);
edcnt=0;
memset(hd,0,sizeof hd);
ans=0;
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);
adde(u,v);
adde(v,u);
}
dfs(1,1);
g[1]=alloc+nd[1].mxdep+5;
alloc+=2*nd[1].mxdep+10;
f[1]=alloc;
alloc+=2*nd[1].mxdep+10;
dfs2(1,1);
printf("%lld\n",ans);
scanf("%d",&n);
}
}
应该是g数组保持正序的写法中较快、空间占用较小的了(320ms 5800kb)
总结
长链剖分跑起来很快,通过奇妙的长链转移和短链暴力,只需要修改少量代码就能从 $\Theta (n^2)$ 变成 $\Theta (n)$
指针在思路清晰的情况下用起来很爽
- 可以让你几乎完全跳到内存的层面上进行思考
- 可以方便地回收空间以达到减少内存占用