问题描述:

给一棵树,每条边有权。求一条简单路径,权值和等于$K$,且边的数量最小。

输入:

第一行:两个整数$n,k$
第二至 $n$行:每行三个整数,表示一条无向边的两端和权值 (注意点的编号从$0$开始)

输出:

一个整数,表示最小边数量
如果不存在这样的路径,输出$-1$

思路:

对于静态统计树上路径的问题,点分治往往是一种可行的做法。这道题也就可以用点分治来做。

分治时对于根节点$p$,可以开一个数组$val$记录从$p$开始每个长度的链最少需要多少条边。需要注意的是,我们应该对每个子树先dfs一遍更新答案,再dfs一遍更新$val$数组。这是因为如果同时操作,计算时就会乱掉(可能两条链属于同一棵子树,这样两条链就有重复的边)。dfs时记录当前链的深度$dep$和边数$num$,就用$val[k-dep]+num$来更新答案。

代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <algorithm>
#define MAXN 200005
#define MAXK 1000006
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;

int n, k, cnt, tot, root, ans, head[MAXN], size[MAXN], val[MAXK], f[MAXN];
bool vis[MAXN];

struct Edge {int next, to, dis;} edge[MAXN*2];

void addedge(int from, int to, int dis)
{
    edge[++cnt].next=head[from];
    edge[cnt].to=to;
    edge[cnt].dis=dis;
    head[from]=cnt;
}

void getroot(int x, int fa)
{
    size[x]=1; f[x]=0; 
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa || vis[to]) continue;
        getroot(to, x);
        size[x]+=size[to];
        f[x]=max(f[x], size[to]);
    }
    f[x]=max(f[x], tot-size[x]);
    if(f[x]<f[root]) root=x;
}

void getdis(int x, int fa, int dep, int num, int op)
{
    if(op==0) val[dep]=min(val[dep], num);
    else val[dep]=INF;
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa || vis[to]) continue;
        if(dep+edge[i].dis<=k) getdis(to, x, dep+edge[i].dis, num+1, op); 
    }
}

void update(int x, int fa, int dep, int num)
{
    ans=min(ans, val[k-dep]+num);
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa || vis[to]) continue;
        if(dep+edge[i].dis<=k) update(to, x, dep+edge[i].dis, num+1); 
    }
}

void divide(int x)
{
    vis[x]=1; val[0]=0;
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(vis[to]) continue;
        if(edge[i].dis<=k) update(to, x, edge[i].dis, 1);
        if(edge[i].dis<=k) getdis(to, x, edge[i].dis, 1, 0);
    }
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(vis[to]) continue;
        if(edge[i].dis<=k) getdis(to, x, edge[i].dis, 1, 1);
    }
    for(int i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(vis[to]) continue;
        tot=size[to]; root=0;
        getroot(to, x);
        divide(root);
    }
}

int main()
{
    f[0]=INF; ans=INF;
    memset(val, 0x3f, sizeof(val));

    scanf("%d%d", &n, &k);
    for(int i=1; i<n; i++)
    {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        x++; y++;
        addedge(x, y, z);
        addedge(y, x, z);
    }
    tot=n; root=0;
    getroot(1, 0);
    divide(root);
    if(ans==INF) printf("-1\n");
    else printf("%d\n", ans);
 }

发表评论

邮箱地址不会被公开。 必填项已用*标注