BZOJ5252 [多省联考2018] 林克卡特树
Description
有一棵 $n$ 个结点的树,每条边有边权。给定 $0\leqslant k\lt n$ ,要求 $k+1$ 条路径(可以仅包含单个点),使得它们不在结点处相交(包括两条路径端点相同的情况),最大化路径上的边权和。
$0\leqslant k \lt n \leqslant 3\times10^5$
Solution
以下的所有 $k$ 均指题目中的 $k+1$ ( $1\leqslant k\leqslant n$ )。
首先,可以 DP 。令 $f_{i,j,0/1/2}$ 表示在 $i$ 的子树中选择 $j$ 条点不相交的路径的最大价值,其中 $0/1/2$ 表示结点 $i$ 不在任意一条路径上 / 是某条路径的端点 / 在某条路径上( $1$ 和 $2$ 可能会重复,不过无所谓的)。
转移的时候考虑 $i$ 是否连向当前子树,跑几遍背包即可。
这样就能拿到 60 分的好成绩啦!
考虑如何优化。
严格的证明(我不会)打表找规律之后我们发现答案关于 $k$ 是凸的。也就是说如果我令 $g_k$ 表示答案的话 $g$ 的差分序列是单调不增的。( $g_k-g_{k-1}\ge g_{k+1}-g_k$ )。
那么就可以上凸优化了。
凸优化是啥?
如果我们忽略掉 $k$ 的限制而只去最大化路径边权和,显然可以 $O(n)$ DP 。这时候,由于答案是凸的,所以它一定是先增后减的,我们求出的就是从增到减中间的那里的答案,也即差分序列从正到负那里的答案。
那么考虑如果我们能把差分序列整体加/减一个值,我们就能把这个“转折点”移动。如果我们把“转折点”恰好移动到 $k$ 这个位置,那我们就能 $O(n)$ DP 出 $k$ 这个位置的答案。
如何把差分序列整体加/减一个值?很简单,只需要把 $g_i$ 变成 $g_i - is$ ,就可以把差分序列整体减 $s$ 了。
那么既然差分序列是单调不增的,那 $s$ 越大,差分序列的“零点”也即原答案的“转折点”越靠左。既然我们要把“转折点”卡到第 $k$ 个位置,就只需要二分 $s$ 即可。
每次二分之后 $O(n)$ DP 出最大的 $g_i - is$ ,也就相当于不限制 k 但是每多选一条路径就会损失 $s$ 的权值( $s$ 有可能是负的);顺便还要求出对应的 $i$ ,即选了多少条路径(由于答案不是严格凸的,可能会有多种选法。这时候求出路径条数最少的即可)。
最后二分出 $s$ 再 DP 一遍,答案加上 $ks$ 即可。
初始二分区间,由于差分不会超过直径长度也即不会超过正边权之和,所以上界为正边权之和即可。差分也不会小于最大正边权的相反数,所以下界为最大正边权的相反数即可。 (这两个性质可以通过求第一个差分和最后一个差分,因为差分是单调不增的)
复杂度 $O(n\log(nW))$ ,其中 $W$ 是边权最大值。
Code
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
inline int readInt() {
int ans = 0, c, f = 1;
while (!isdigit(c = getchar())) if (c == '-') f *= -1;
do ans = ans * 10 + c - '0';
while (isdigit(c = getchar()));
return ans * f;
}
typedef long long LL;
const int N = 300050;
int n, k;
int pre[N], nxt[N * 2], to[N * 2], val[N * 2], cnt;
inline void addEdge(int x, int y, int v) {
[cnt] = pre[x];
nxt[cnt] = v;
val[pre[x] = cnt++] = y;
to[cnt] = pre[y];
nxt[cnt] = v;
val[pre[y] = cnt++] = x;
to}
, f[N][3], g[N][3];
LL s
inline void Check(LL &a, LL &b, LL c, LL d) {
if (c > a || (c == a && d < b)) a = c, b = d;
}
void dfs(int x, int fa) {
[x][0] = 0; f[x][1] = -s; f[x][2] = -s;
f[x][0] = 0; g[x][1] = 1; g[x][2] = 1;
gfor (int i = pre[x]; i != -1; i = nxt[i])
if (to[i] != fa) {
int u = to[i];
(u, x);
dfs[3], tg[3];
LL tf(tf, f[x], sizeof(f[x]));
memcpy(tg, g[x], sizeof(g[x]));
memcpy= f[u][0], gu = g[u][0];
LL fu (fu, gu, f[u][1], g[u][1]); Check(fu, gu, f[u][2], g[u][2]);
Check(f[x][2], g[x][2], tf[2] + fu, tg[2] + gu);
Check(f[x][2], g[x][2], tf[1] + f[u][1] + val[i] + s, tg[1] + g[u][1] - 1);
Check(f[x][1], g[x][1], tf[1] + fu, tg[1] + gu);
Check(f[x][1], g[x][1], tf[0] + f[u][1] + val[i], tg[0] + g[u][1]);
Check(f[x][0], g[x][0], tf[0] + fu, tg[0] + gu);
Check}
}
int main() {
= readInt();
n = readInt() + 1;
k (pre, -1, sizeof pre);
memset= 0, r = 0;
LL l for (int i = 1, x, y, z; i < n; ++i) {
= readInt();
x = readInt();
y = readInt();
z (x, y, z);
addEdge+= std::max(z, 0); l = std::min(l, (LL)-z);
r }
while (l < r) {
= (l + r) >> 1;
s (1, 0);
dfs= f[1][0], gu = g[1][0];
LL fu (fu, gu, f[1][1], g[1][1]); Check(fu, gu, f[1][2], g[1][2]);
Checkif (gu <= k) r = s;
else l = s + 1;
}
= l;
s (1, 0);
dfs= f[1][0], gu = g[1][0];
LL fu (fu, gu, f[1][1], g[1][1]); Check(fu, gu, f[1][2], g[1][2]);
Check("%lld\n", fu + s * k);
printfreturn 0;
}