题目链接
Unique Occurrences
题意
给定一棵含有 n n n个结点的树,树上的每条边都有一个权值。 f ( v , u ) f(v,u) f(v,u)表示 v v v到 u u u的简单路径上边权只出现一次的边权的个数。求 ∑ f ( v , u ) ( 1 ≤ v < u ≤ n ) \sum f(v,u)(1 \leq v < u \leq n) ∑f(v,u)(1≤v<u≤n)。
分析
思路一
可以发现边权之间互不影响,可以按边权讨论。对于边权 w w w,现在要解决的问题是树上只经过边权为 w w w的边一次的简单路径的数量,此问题可以用树形DP求解。 f [ x ] [ 0 ] f[x][0] f[x][0]表示从 x x x开始向下有 f [ x ] [ 0 ] f[x][0] f[x][0]条边不经过边权为 w w w的边, f [ x ] [ 1 ] f[x][1] f[x][1]表示从 x x x开始向下有 f [ x ] [ 1 ] f[x][1] f[x][1]条路径只经过边权为 w w w的边一次,状态转移方程如下,其中 W ( x , y ) W(x,y) W(x,y) 表示 x x x 到 y y y 的边权, x x x 是 y y y 的父结点。
{ f [ x ] [ 0 ] = ∑ ( f [ y ] [ 0 ] + 1 ) , W ( x , y ) ≠ w f [ x ] [ 1 ] + = ∑ ( f [ y ] [ 0 ] + 1 ) , W ( x , y ) = w f [ x ] [ 1 ] + = ∑ f [ y ] [ 1 ] , W ( x , y ) ≠ w \left\{\begin{array}{l}f\lbrack x\rbrack\lbrack0\rbrack={\textstyle\sum(f\lbrack y\rbrack\lbrack0\rbrack+1),W(x,y)\neq w}\\f\lbrack x\rbrack\lbrack1\rbrack+={\textstyle\sum(f\lbrack y\rbrack\lbrack0\rbrack+1),W(x,y)=w}\\f\lbrack x\rbrack\lbrack1\rbrack+={\textstyle\sum f\lbrack y\rbrack\lbrack1\rbrack,W(x,y)\neq w}\end{array}\right. ⎩ ⎨ ⎧f[x][0]=∑(f[y][0]+1),W(x,y)=wf[x][1]+=∑(f[y][0]+1),W(x,y)=wf[x][1]+=∑f[y][1],W(x,y)=w
在DP的同时统计所有符合要求的边的数量,每次DP的时间复杂度是 O ( n ) O(n) O(n),这样总的时间复杂度是 O ( n 2 ) O(n^2) O(n2)。考虑进行优化,对于一种边权,涉及到的点是有限的,不用建出完整的树,可以只对边权为 w w w的边包含的顶点建立虚树, n n n棵树的总点数是 O ( n ) O(n) O(n)的,建虚树需要用到倍增求LCA,总的时间复杂度为 O ( n l o g ( n ) ) O(nlog(n)) O(nlog(n))。
思路二
思路一是按边权讨论,保留边权为 w w w的边,其实也可以只删去边权为 w w w的边。当删去边权为 w w w的边时,原树分成了若干个连通块,假设连通块之间以边权为 w w w的虚边连接,那么答案就是所有虚边两侧连通块大小乘积再求和。具体实现时有两种方法,一种是分治+可撤销并查集,另一种是Link Cut Tree,前者时间复杂度是 O ( n l o g ( n ) ) O(nlog(n)) O(nlog(n)),后者时间复杂度是 O ( n l o g ( n ) ) O(nlog(n)) O(nlog(n))。
AC代码
虚树
typedef long long ll;
const int N=5e5+10;
const int M=2*N;
int head[N],e[M],ne[M],w[M],tot;
int a[N],d[N],sz[N],dfn[N],stk[N],f[N][21];
vector<pair<int,int>> vec[N];
map<int,int> col[N];
map<int,set<int>> mp;
int n,t,cnt,num,top;
ll g[N][2];
ll ans;
void add(int x,int y,int z)
{
e[++tot]=y,ne[tot]=head[x],w[tot]=z,head[x]=tot;
}
void dfs(int x)
{
dfn[x]=++num;
sz[x]=1;
for(int i=head[x];i;i=ne[i])
{
int y=e[i];
if(!d[y])
{
d[y]=d[x]+1;
f[y][0]=x;
for(int j=1;j<=t;j++) f[y][j]=f[f[y][j-1]][j-1];
dfs(y);
sz[x]+=sz[y];
}
}
}
int getlca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=t;i>=0;i--)
if(d[f[y][i]]>=d[x])
y=f[y][i];
if(x==y) return x;
for(int i=t;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
bool cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
void build(int k)
{
cnt=0;
for(auto x:mp[k]) if(x!=1) a[++cnt]=x;
sort(a+1,a+cnt+1,cmp);
top=0; stk[++top]=1;
vec[1].clear();
for(int i=1;i<=cnt;i++)
{
int lca=getlca(a[i],stk[top]);
if(lca!=stk[top])
{
while(dfn[lca]<dfn[stk[top-1]])
{
int x=stk[top-1],y=stk[top];
vec[x].push_back({y,col[x].count(y)?col[x][y]:0});
top--;
}
if(dfn[lca]>dfn[stk[top-1]])
{
vec[lca].clear();
vec[lca].push_back({stk[top],col[lca].count(stk[top])?col[lca][stk[top]]:0});
top--;
stk[++top]=lca;
}
else
{
int x=stk[top-1],y=stk[top];
vec[x].push_back({y,col[x].count(y)?col[x][y]:0});
top--;
}
}
vec[a[i]].clear();
stk[++top]=a[i];
}
while(top>1)
{
int x=stk[top-1],y=stk[top];
vec[x].push_back({y,col[x].count(y)?col[x][y]:0});
top--;
}
}
void dp(int x,int p,int k)
{
g[x][1]=0; g[x][0]=sz[x]-1;
for(auto it:vec[x]) g[x][0]-=sz[it.first];
for(auto it:vec[x])
{
int y=it.first,z=it.second;
if(y==p) continue;
dp(y,x,k);
if(z==k)
{
ans+=(g[y][0]+1)*g[x][0];
g[x][1]+=(g[y][0]+1);
}
else
{
ans+=(g[y][0]+1)*g[x][1];
ans+=g[y][1]*g[x][0];
g[x][0]+=(g[y][0]+1);
g[x][1]+=g[y][1];
}
}
ans+=g[x][1];
}
int main()
{
cin>>n;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
add(x,y,z);
add(y,x,z);
col[x][y]=col[y][x]=z;
mp[z].insert(x);
mp[z].insert(y);
}
while((1<<t)<n) t++;
d[1]=1; dfs(1);
for(int k=1;k<=n;k++)
{
if(mp[k].size())
{
build(k);
dp(1,0,k);
}
}
cout<<ans<<endl;
return 0;
}
分治+可撤销并查集
map<int,vector<pair<int,int>>> mp;
ll ans;
int n;
void dfs(int l,int r)
{
if(l==r)
{
for(auto it:mp[l])
{
ans+=(ll)dsu.siz(it.first)*dsu.siz(it.second);
}
return ;
}
int mid=(l+r)>>1;
int h=dsu.histroy();
for(int i=l;i<=mid;i++)
{
for(auto it:mp[i])
{
dsu.merge(it.first,it.second);
}
}
dfs(mid+1,r);
dsu.roll(h);
for(int i=mid+1;i<=r;i++)
{
for(auto it:mp[i])
{
dsu.merge(it.first,it.second);
}
}
dfs(l,mid);
dsu.roll(h);
}
int main()
{
cin>>n;
dsu.init(n);
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
mp[z].push_back({x,y});
}
dfs(1,n);
cout<<ans<<endl;
return 0;
}
LCT
const int N=5e5+10;
vector<pair<int,int>> vec[N];
int main()
{
int n; cin>>n;
for(int i=1;i<=n;i++) lct.tr[i].sz=1;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
lct.link(x,y);
vec[z].push_back({x,y});
}
ll ans=0;
for(int i=1;i<=n;i++)
{
for(auto it:vec[i])
{
int x=it.first,y=it.second;
lct.cut(x,y);
}
for(auto it:vec[i])
{
int x=it.first,y=it.second;
lct.makeroot(x),lct.makeroot(y);
ans+=(ll)lct.tr[x].sz*lct.tr[y].sz;
}
for(auto it:vec[i])
{
int x=it.first,y=it.second;
lct.link(x,y);
}
}
cout<<ans<<endl;
return 0;
}