问题描述

设 $T$ 为一棵有根树,我们做如下的定义:

• 设 $a$ 和 $b$ 为 $T$ 中的两个不同节点。如果 $a$ 是 $b$ 的祖先,那么称「$a$ 比 $b$ 不知道高明到哪里去了」。

• 设 $a$ 和 $b$ 为 $T$ 中的两个不同节点。如果 $a$ 与 $b$ 在树上的距离不超过某个给定常数 $x$,那么称「$a$ 与 $b$ 谈笑风生」。

给定一棵 $n$ 个节点的有根树 $T$,节点的编号为 $1∼n$,根节点为 $1$ 号节点。你需要回答 $q$ 个询问,询问给定两个整数 $p $ 和 $k$,问有多少个有序三元组 $(a, b, c)$ 满足:

• $a$,$b$ 和 $c$ 为 $T$ 中三个不同的点,且 $a$ 为 $p$ 号节点;

• $a$ 和 $b$ 都比 $c$ 不知道高明到哪里去了;

• $a$ 和 $b$ 谈笑风生。这里谈笑风生中的常数为给定的 $k$。

输入

输入文件的第一行含有两个正整数 $n$ 和 $q$,分别代表有根树的点数与询问的个数。

接下来 $n−1$ 行,每行描述一条树上的边。每行含有两个整数 $u$ 和 $v$,代表在节点 $u$ 和 $v$ 之间有一条边。

接下来 $q$ 行,每行描述一个操作。第 $i$ 行含有两个整数,分别表示第 $i$ 个询问的 $p$ 和 $k$。

输出

输出 $q$ 行,每行对应一个询问,代表询问的答案。

思路

特别经典的一道题,大部分树上的数据结构都能够在这题使用。除了下面讲到都主席树和线段树合并的做法之外,还可以用树上启发式合并,长链剖分等做法去做。

我们将答案分成两个部分。第一部分是 $b$ 是 $a$ 的祖先,答案就是: $$min(dep[a]-1, k)*(siz[a]-1)$$

第二部分是 $a$ 是 $b$ 的祖先,答案就是:
$$\sum_{dep[b] \in [dep[a]+1, dep[a]+k]} (siz[b]-1)$$

于是我们不难想到以 $dep[x]$ 为关键字,$siz[x]-1$ 为权值建立权值线段树。对于每一个询问的答案就是在对应线段树上询问区间 $[dep[x]+1, dep[x]+k]$。实现时就可以用主席树或线段树合并来节省空间。

代码

主席树:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 300005
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, q, cnt, id, tot, head[MAXN], dep[MAXN], siz[MAXN], in[MAXN], out[MAXN], ver[MAXN], root[MAXN];

struct Edge {int next, to;} edge[MAXN*2];
struct Node {int ls, rs; LL val;} t[MAXN*40];

void addedge(int from, int to)
{
    edge[++cnt].next=head[from];
    edge[cnt].to=to;
    head[from]=cnt;
}

void dfs(int x, int fa)
{
    siz[x]=1; in[x]=++id; ver[id]=x;
    for(rint i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa) continue;
        dep[to]=dep[x]+1;
        dfs(to, x);
        siz[x]+=siz[to];
    }
    out[x]=id;
}

void update(int &rt1, int rt2, int l, int r, int x, int k)
{
    rt1=++tot; t[rt1]=t[rt2]; t[rt1].val+=k;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) update(t[rt1].ls, t[rt2].ls, l, mid, x, k);
    else update(t[rt1].rs, t[rt2].rs, mid+1, r, x, k);
}

LL query(int rt, int l, int r, int x, int y)
{
    if(l>y || r<x) return 0;
    if(l>=x && r<=y) return t[rt].val;
    int mid=(l+r)>>1;
    return query(t[rt].ls, l, mid, x, y)+query(t[rt].rs, mid+1, r, x, y);
}

int main()
{
    scanf("%d%d", &n, &q);
    for(rint i=1, x, y; i<n; ++i)
    {
        scanf("%d%d", &x, &y);
        addedge(x, y);
        addedge(y, x);
    }
    dfs(1, 0);
    for(rint i=1; i<=n; ++i)
        update(root[i], root[i-1], 1, n, dep[ver[i]], siz[ver[i]]-1);
    while(q--)
    {
        int x, k;
        scanf("%d%d", &x, &k);
        LL a=1LL*min(dep[x], k)*(siz[x]-1);
        LL b=query(root[in[x]], 1, n, dep[x], dep[x]+k);
        LL c=query(root[out[x]], 1, n, dep[x], dep[x]+k);
        printf("%lld\n", a+c-b);
    }
    return 0;
}

线段树合并:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 300005
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, q, cnt, tot, head[MAXN], dep[MAXN], siz[MAXN], root[MAXN];
LL ans[MAXN];

struct Edge {int next, to;} edge[MAXN*2];
struct Node {int ls, rs; LL val;} t[MAXN*40];

void addedge(int from, int to)
{
    edge[++cnt].next=head[from];
    edge[cnt].to=to;
    head[from]=cnt;
}

void update(int &rt, int l, int r, int x, int k)
{
    if(!rt) rt=++tot; t[rt].val+=k;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) update(t[rt].ls, l, mid, x, k);
    else update(t[rt].rs, mid+1, r, x, k);
}

int merge(int x, int y, int l, int r)
{
    if((!x) || (!y)) return x+y;
    int mid=(l+r)>>1, rt=++tot;
    t[rt].val=t[x].val+t[y].val;
    t[rt].ls=merge(t[x].ls, t[y].ls, l, mid);
    t[rt].rs=merge(t[x].rs, t[y].rs, mid+1, r);
    return rt;
}

LL query(int rt, int l, int r, int x, int y)
{
    if(l>y || r<x || !rt) return 0;
    if(l>=x && r<=y) return t[rt].val;
    int mid=(l+r)>>1;
    return query(t[rt].ls, l, mid, x, y)+query(t[rt].rs, mid+1, r, x, y);
}

void dfs(int x, int fa)
{
    siz[x]=1; dep[x]=dep[fa]+1;
    for(rint i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa) continue;
        dfs(to, x);
        siz[x]+=siz[to];
        root[x]=merge(root[x], root[to], 1, n);
    }
    update(root[x], 1, n, dep[x], siz[x]-1);
}

int main()
{
    scanf("%d%d", &n, &q);
    for(rint i=1, x, y; i<n; ++i)
    {
        scanf("%d%d", &x, &y);
        addedge(x, y);
        addedge(y, x);
    }
    dfs(1, 0);
    for(rint i=1; i<=q; ++i)
    {
        int x, k;
        scanf("%d%d", &x, &k);
        LL a=1LL*min(dep[x]-1, k)*(siz[x]-1);
        LL b=query(root[x], 1, n, dep[x]+1, dep[x]+k);
        printf("%lld\n", a+b);
    }
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注