问题描述

给定一棵 $n$ 个结点的树,你从点 $x$ 出发,每次等概率随机选择一条与所在点相邻的边走过去。

有 $q$ 次询问,每次询问给定一个集合 $S$,求如果从 $x$ 出发一直随机游走,直到点集 $S$ 中所有点都至少经过一次的话,期望游走几步。

特别地,点 $x$(即起点)视为一开始就被经过了一次。

答案对 $998244353$ 取模。

输入

第一行三个正整数 $n,q,x$。

接下来 $n-1$ 行,每行两个正整数 $(u,v)$ 描述一条树边。

接下来 $q$ 行,每行第一个数 $k$ 表示集合大小,接下来 $k$ 个互不相同的数表示集合 $S$。

输出

输出 $q$ 行,每行一个非负整数表示答案。

思路

首先要知道 $min-max$ 容斥,然后根据 $min-max$ 容斥的常规套路就可以把原问题转换为求 $f(root)$,表示从 $root$ 开始到集合 $T$ 中元素的最小期望步数。

考虑求 $f(x)$。路径上概率期望的套路一般是根据期望的递推关系,对于每一个点列一个方程,最后解方程组得出期望。这题也可以按照这种方法,得出方程:
$$f[x]=\frac{1}{deg[x]} \sum f[to] +1$$
用 $O(n^3)$ 解这个方程组,然后套 $min-max$ 容斥的复杂度是 $O(n^32^n+n^2q)$,但这个复杂度是满足不了要求的。我们可以考虑优化如何解这个方程组。下面将式子进行化简:
$$f[x]=\frac{1}{deg[x]} \sum_{to \in ch[x]} f[to] + \frac{1}{deg[x]} f[fa]+ 1 $$
设 $$A_x=\frac{1}{deg[x]}, B_x=\frac{\sum_{to \in ch[x]} f[to]} {deg[x] }+1$$那么就可以将上式写成:
$$f[x]=A_x * f[fa]+B_x$$

考虑:
$$\sum f[to]=\sum A_{to} f[x]+\sum B_{to}$$
将 $\sum f[to]$ 代入原先式子,得到:
$$f[x]=\frac{f[fa]}{deg[x]}+\frac{\sum A_{to}*f[x]+\sum B_{to}}{deg[x]}+1$$
$$(deg[x]-\sum A_{to}) f[x]=f[fa]+\sum B_{to}+deg[x]$$
$$f[x]=\frac{f[fa]}{deg[x]-\sum A_{to}} + \frac{\sum B_{to}+deg[x]}{deg[x]-\sum A_{to}}$$

得到:
$$A_x=\frac{1}{deg[x]-\sum A_{to}}, B_{x}=\frac{\sum B_{to}+deg[x]}{deg[x]-\sum A_{to}}$$

然后我们设属于集合 $T$ 的点 $A_x=0, B_x=0$,这样我们就可以从下往上递推出 $A_x, B_x$,然后因为 $root$ 没有父亲,所以我们要求的 $f[root]$ 就是 $B_{root}$。

先预处理出所有集合 $T$ 的 $f[root]$,然后对于每一个询问,枚举子集,用 $min-max$ 容斥得出答案。这样总复杂度就是 $O(n*2^n+qn^2)$。

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN
#define p 998244353
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, q, x, cnt, head[20], deg[20], a[20], b[20], f[1<<19], siz[1<<19];

struct Edge {int next, to;} edge[40];

void addedge(int from, int to)
{
    edge[++cnt].next=head[from];
    edge[cnt].to=to;
    deg[to]++;
    head[from]=cnt;
}

int ksm(int x, int y)
{
    int sum=1;
    while(y)
    {
        if(y&1) sum=(1LL*x*sum)%p;
        x=(1LL*x*x)%p; y>>=1;
    }
    return sum;
}

void dp(int sta, int x, int fa)
{

    int suma=0, sumb=0;
    if((sta>>(x-1))&1) {a[x]=b[x]=0; return;}
    for(rint i=head[x]; i; i=edge[i].next)
    {
        int to=edge[i].to;
        if(to==fa) continue;
        dp(sta, to, x);
        suma=(suma+a[to])%p;
        sumb=(sumb+b[to])%p;
    }
    int temp=ksm((deg[x]+p-suma)%p, p-2);
    a[x]=1LL*temp;
    b[x]=1LL*(sumb+deg[x])%p*temp%p;
}

int main()
{
    scanf("%d%d%d", &n, &q, &x);
    for(rint i=1, x, y; i<n; ++i)
    {
        scanf("%d%d", &x, &y);
        addedge(x, y);
        addedge(y, x);
    }
    for(rint sta=0; sta<(1<<n); ++sta)
    {
        dp(sta, x, 0);
        f[sta]=b[x];
        siz[sta]=siz[sta>>1]+(sta&1);
    }
    while(q--)
    {
        int sta=0, ans=0, k;
        scanf("%d", &k);
        for(rint i=1, x; i<=k; ++i)
        {
            scanf("%d", &x);
            sta|=(1<<(x-1));
        }
        for(rint i=sta; i; i=(i-1)&sta)
        {
            if(siz[i]&1) ans=(ans+f[i])%p;
            else ans=(ans+p-f[i])%p;
        }
        printf("%d\n", ans);
    }
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注