BZOJ5252 [多省联考2018] 林克卡特树

2018 年 04 月 9 日发布.

Description

有一棵 $n$ 个结点的树,每条边有边权。给定 $0\leqslant k\lt n$ ,要求 $k+1$ 条路径(可以仅包含单个点),使得它们不在结点处相交(包括两条路径端点相同的情况),最大化路径上的边权和。

$0\leqslant k \lt n \leqslant 3\times10^5$

Solution

以下的所有 $k$ 均指题目中的 $k+1$ ( $1\leqslant k\leqslant n$ )。

首先,可以 DP 。令 $f_{i,j,0/1/2}$ 表示在 $i$ 的子树中选择 $j$ 条点不相交的路径的最大价值,其中 $0/1/2$ 表示结点 $i$ 不在任意一条路径上 / 是某条路径的端点 / 在某条路径上( $1$ 和 $2$ 可能会重复,不过无所谓的)。

转移的时候考虑 $i$ 是否连向当前子树,跑几遍背包即可。

这样就能拿到 60 分的好成绩啦!

考虑如何优化。

严格的证明(我不会)打表找规律之后我们发现答案关于 $k$ 是凸的。也就是说如果我令 $g_k$ 表示答案的话 $g$ 的差分序列是单调不增的。( $g_k-g_{k-1}\ge g_{k+1}-g_k$ )。

那么就可以上凸优化了。

凸优化是啥?

如果我们忽略掉 $k$ 的限制而只去最大化路径边权和,显然可以 $O(n)$ DP 。这时候,由于答案是凸的,所以它一定是先增后减的,我们求出的就是从增到减中间的那里的答案,也即差分序列从正到负那里的答案。

那么考虑如果我们能把差分序列整体加/减一个值,我们就能把这个“转折点”移动。如果我们把“转折点”恰好移动到 $k$ 这个位置,那我们就能 $O(n)$ DP 出 $k$ 这个位置的答案。

如何把差分序列整体加/减一个值?很简单,只需要把 $g_i$ 变成 $g_i - is$ ,就可以把差分序列整体减 $s$ 了。

那么既然差分序列是单调不增的,那 $s$ 越大,差分序列的“零点”也即原答案的“转折点”越靠左。既然我们要把“转折点”卡到第 $k$ 个位置,就只需要二分 $s$ 即可。

每次二分之后 $O(n)$ DP 出最大的 $g_i - is$ ,也就相当于不限制 k 但是每多选一条路径就会损失 $s$ 的权值( $s$ 有可能是负的);顺便还要求出对应的 $i$ ,即选了多少条路径(由于答案不是严格凸的,可能会有多种选法。这时候求出路径条数最少的即可)。

最后二分出 $s$ 再 DP 一遍,答案加上 $ks$ 即可。

初始二分区间,由于差分不会超过直径长度也即不会超过正边权之和,所以上界为正边权之和即可。差分也不会小于最大正边权的相反数,所以下界为最大正边权的相反数即可。 (这两个性质可以通过求第一个差分和最后一个差分,因为差分是单调不增的)

复杂度 $O(n\log(nW))$ ,其中 $W$ 是边权最大值。

Code

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

inline int readInt() {
  int ans = 0, c, f = 1;
  while (!isdigit(c = getchar())) if (c == '-') f *= -1;
  do ans = ans * 10 + c - '0';
  while (isdigit(c = getchar()));
  return ans * f;
}

typedef long long LL;
const int N = 300050;

int n, k;
int pre[N], nxt[N * 2], to[N * 2], val[N * 2], cnt;

inline void addEdge(int x, int y, int v) {
  nxt[cnt] = pre[x];
  val[cnt] = v;
  to[pre[x] = cnt++] = y;
  nxt[cnt] = pre[y];
  val[cnt] = v;
  to[pre[y] = cnt++] = x;
}

LL s, f[N][3], g[N][3];

inline void Check(LL &a, LL &b, LL c, LL d) {
  if (c > a || (c == a && d < b)) a = c, b = d;
}

void dfs(int x, int fa) {
  f[x][0] = 0; f[x][1] = -s; f[x][2] = -s;
  g[x][0] = 0; g[x][1] = 1; g[x][2] = 1;
  for (int i = pre[x]; i != -1; i = nxt[i])
    if (to[i] != fa) {
      int u = to[i];
      dfs(u, x);
      LL tf[3], tg[3];
      memcpy(tf, f[x], sizeof(f[x]));
      memcpy(tg, g[x], sizeof(g[x]));
      LL fu = f[u][0], gu = g[u][0];
      Check(fu, gu, f[u][1], g[u][1]); Check(fu, gu, f[u][2], g[u][2]);
      Check(f[x][2], g[x][2], tf[2] + fu, tg[2] + gu);
      Check(f[x][2], g[x][2], tf[1] + f[u][1] + val[i] + s, tg[1] + g[u][1] - 1);
      Check(f[x][1], g[x][1], tf[1] + fu, tg[1] + gu);
      Check(f[x][1], g[x][1], tf[0] + f[u][1] + val[i], tg[0] + g[u][1]);
      Check(f[x][0], g[x][0], tf[0] + fu, tg[0] + gu);
    }
}

int main() {
  n = readInt();
  k = readInt() + 1;
  memset(pre, -1, sizeof pre);
  LL l = 0, r = 0;
  for (int i = 1, x, y, z; i < n; ++i) {
    x = readInt();
    y = readInt();
    z = readInt();
    addEdge(x, y, z);
    r += std::max(z, 0); l = std::min(l, (LL)-z);
  }
  while (l < r) {
    s = (l + r) >> 1;
    dfs(1, 0);
    LL fu = f[1][0], gu = g[1][0];
    Check(fu, gu, f[1][1], g[1][1]); Check(fu, gu, f[1][2], g[1][2]);
    if (gu <= k) r = s;
    else l = s + 1;
  }
  s = l;
  dfs(1, 0);
  LL fu = f[1][0], gu = g[1][0];
  Check(fu, gu, f[1][1], g[1][1]); Check(fu, gu, f[1][2], g[1][2]);
  printf("%lld\n", fu + s * k);
  return 0;
}