BZOJ2159 Crash的文明世界 【树形dp + Stirling数】
给出一棵$n$个点的树,求对于每个点$i$的$d(i)$值。
$d(i) = \sum_{1\leq x \leq n}^{i \not= x}dist(x, i)^{k}$
数据范围:$1 \leq n \leq 50000, 1\leq k \leq 150$
【题解】
这题非常的神……
首先我们发现,$x^n$能用Stirling数通过一些神奇的方法表示出来……
$x^n=\sum_{1<=k<=n}S(n, k) \times F(x, k)$
其中$S(n, k)$为第二类Stirling数,$S(n, k) = S(n-1, k-1) + k \times S(n-1, k)$
$F(n, k) = n \times (n-1) \times ... \times (n-k+1)$
那么$d(i) = \sum_{j\leq k}S(k, j) \times f(i, j)$
其中$f(i,j) = \sum_{1\leq x \leq n}^{x \not= i} F(dist(i, x), j)$
那么我们就能推出来啦。
但是我们发现,f数组很难通过树形dp求出。考虑到$C(n, k) = \frac{F(n, k)}{k!}$,我们更改f数组表示的内容,改为
$f'(i,j) = \sum_{1\leq x \leq n}^{x \not= i} C(dist(i, x), j)$。
那么我们就有$f(i, j) = f'(i, j) \times j!$,我们可以通过pascal定理转移出f',从而转移出f。
pascal定理:$C(n, k) = C(n-1, k) + C(n-1, k-1)$
然后我们就可以用树形dp统计啦。
注意对于一个节点,树形dp中的f表示的是其以下的子树的和,还需要统计它的兄弟以及父亲上面的值,需要重新计算。
# include <stdio.h> using namespace std; int n, k; int tot=0, head[50010], to[200010], next[200010]; int f[50010][160]; int S[160][160]; int fac[160], ans[50010]; // f[i, j] = sigma(C(dist(i, x), j)) // ans[x] = sigma(S(k, j)*g[x, j]) (1<=j<=k) inline void mad(int &x, int delta) { delta %= 10007; x = (x+delta) % 10007; } inline void msu(int &x, int delta) { delta %= 10007; x = x-delta; x = (x+10007) % 10007; } inline void add(int u, int v) { ++tot; next[tot]=head[u]; head[u]=tot; to[tot]=v; } inline void dp(int x, int fa) { f[x][0]=1; for (int i=head[x]; i; i=next[i]) { if(to[i] == fa) continue; dp(to[i], x); mad(f[x][0], f[to[i]][0]); for (int j=1; j<=k; ++j) mad(f[x][j], f[to[i]][j] + f[to[i]][j-1]); } } int g[160], h[160]; inline void calc(int x, int fa) { for (int i=head[x]; i; i=next[i]) { if(to[i] == fa) continue; g[0] = f[x][0]; msu(g[0], f[to[i]][0]); h[0] = n; for (int j=1; j<=k; ++j) { g[j] = f[x][j], msu(g[j], f[to[i]][j] + f[to[i]][j-1]); h[j] = (f[to[i]][j] + g[j] + g[j-1]) % 10007; } for (int j=0; j<=k; ++j) { f[to[i]][j] = h[j]; mad(ans[to[i]], S[k][j]*fac[j]%10007*h[j]%10007); } calc(to[i], x); } } int main() { int tL, tNOW, tA, tB, tC; scanf("%d%d%d", &n, &k, &tL); scanf("%d%d%d%d", &tNOW, &tA, &tB, &tC); for (int i=1; i<=n-1; ++i) { int u, v; tNOW=(tNOW*tA+tB)%tC; u=i-tNOW%(i<tL?i:tL), v=i+1; //scanf("%d%d", &u, &v); add(u, v); add(v, u); } int s=1; fac[0] = 1; for (int i=1; i<=k; ++i) { s = s * i % 10007; fac[i] = s; } S[0][0] = 1; for (int i=1; i<=k; ++i) for (int j=1; j<=k; ++j) mad(S[i][j], S[i-1][j-1] + S[i-1][j]*j); /* for (int i=1; i<=k; ++i, printf("\n")) for (int j=1; j<=k; ++j) printf("%d ", S[i][j]); */ dp(1, 0); for (int i=0; i<=k; ++i) mad(ans[1], S[k][i]*fac[i]%10007*f[1][i]%10007); calc(1, 0); for (int i=1; i<=n; ++i) printf("%d\n", ans[i]); return 0; }