问题描述:

给你一张有n个节点的图,其中有(n-1)条“主要边”构成一棵树,和m条“附加边”。你需要把这张图分为两个不连通的部分,你可以切除一条“主要边”和一条“附加边”,问你有多少种方案。

输入:

第一行两个整数n, m $(1<=m, n<=100000)$
接下来n-1行表示n-1条主要边
接下来m行表示附加边

输出:

输出方案总数

思路:

这道题目的思路比较巧妙。首先考虑当我们切除一条主要边时,把树分为了两部分,但这两部分可能有附加边让它们保持联通。显然,如果没有附加边让它们保持联通,那么就有m种方案(随便切一条就OK了)。如果有一条附加边让它们联通,那么就只有1种方案,如果有大于一条附加边时,就没有符合要求的方案。

我们可以记录每一条附加边两端之间的树上路径,那么路径上的每一条边都有一条附加边让它们联通。这时我们就可以利用树上差分+LCA统计出每一条主要边有多少条附加边让它们保持联通,最后算出答案。

代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#define MAXN 1000005
using namespace std;

int n, m, ans, cnt, id, head[MAXN], num[MAXN], tag[MAXN], dfn[MAXN], ver[MAXN], dep[MAXN], logn[MAXN], f[MAXN][30];

struct Edge {int next, to;} edge[MAXN*2];

void addedge(int from, int to)
{
    edge[++cnt].next=head[from];
    edge[cnt].to=to;
    head[from]=cnt;
}

void solve(int x, int fa)
{
    num[x]+=tag[x];
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa) continue;
        solve(to, x);
        num[x]+=num[to];
    }
    if(x==1) return;
    if(num[x]==0) ans+=m;
    if(num[x]==1) ans++;
}

void dfs(int x, int fa, int deep)
{
    dfn[x]=++id; ver[id]=x; dep[id]=deep;
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa) continue;
        dfs(to, x, deep+1);
        ver[++id]=x;
        dep[id]=deep;
    }
}

void init()
{
    logn[1]=0;
    for(int i=1; i<=n*2-1; i++) f[i][0]=i;
    for(int i=2; i<=n*2-1; i++) logn[i]=logn[i/2]+1;
    for(int i=1; (1<<i)<=n; i++)
    {
        for(int j=1; j+(1<<i)-1<=n*2-1; j++)
        {
            int x=f[j][i-1], y=f[j+(1<<(i-1))][i-1];
            if(dep[x]<dep[y]) f[j][i]=x;
            else f[j][i]=y;
        }
    }
}

int rmq(int l, int r)
{
    int k=logn[r-l+1];
    int x=f[l][k], y=f[r-(1<<k)+1][k];
    if(dep[x]<dep[y]) return ver[x];
    else return ver[y];
}

int lca(int x, int y)
{
    int l=dfn[x], r=dfn[y];
    if(l>r) swap(l, r);
    return rmq(l, r);
}

int main()
{
    int x, y;
    scanf("%d%d", &n, &m);
    for(int i=1; i<n; i++)
    {
        scanf("%d%d", &x, &y);
        addedge(x, y);
        addedge(y, x);
    }
    dfs(1, 0, 0);
    init();
    for(int i=1; i<=m; i++)
    {
        scanf("%d%d", &x, &y);
        tag[x]++; tag[y]++;
        tag[lca(x, y)]-=2;
    }
    solve(1, 0);
    printf("%d\n", ans);
}

发表评论

邮箱地址不会被公开。 必填项已用*标注