[JZOJ5684]【GDSOI2018模拟4.22】Tree

题目描述

这里写图片描述
这里写图片描述

分析

一道简单的虚树加dp题。
显然拉出虚树之后对每条边二分出最优点然后给答案取min即可。
dp的设法是,f[x][012]表示x子树所有点到x的距离的0,1,2次幂。up[f][012]表示x子树外所有点到x。
虚树怎么建呢?
很显然虚树的点就是点集里所有点以及他们按dfn排序后,相邻两个的lca。
为了建出虚树,我们要维护一个深度递增的单调栈。
给出的点按dfn排序后,我们逐个加入虚树。
每次把一个点x和栈顶元素y求lca,然后把单调栈里深度大于lca的全部弹掉,加入lca,加入x。注意lca如果原本就有了就不需要加了。可以看出这个lca实际上就是x和点集的上一个元素的lca。
每次弹栈的时候,我们连边,即栈顶元素和底下一个元素连边,而如果lca的深度在他们之间,则连到lca上。
最后把栈清空一下就行啦。
接下来就是dp随便搞搞。

代码

#include<cstdio> 
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a<b)?a:b)
typedef long long ll;
typedef double db;
const int N=2e5+5,mo=998244353;
int dfn[N],td,f[N][25],g[N][25],Log[N],dis[N],pd[N],q,n,m,z,y,x,i,j,K,a[N],sta[N],st,lca,d[N],lst,rt;
db med;
ll upf[N][3],F[N][3],go[N][3],C,val,ans,tmp0,tmp1,len,X;
bool cmp(int x,int y) {return dfn[x]<dfn[y];}
int read()
{
    int x=0;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while ('0'<=ch&&ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x;
}
int buf[20];
void Print(ll x)
{
    buf[0]=0;
    while (x) buf[++buf[0]]=x%10,x/=10;
    if (!buf[0]) putchar('0');
    while (buf[0]) putchar('0'+buf[buf[0]--]);
    putchar('\n');
}
int tt,b[N],c[N],nxt[N],fst[N];
void cr(int x,int y,int z)
{
    tt++;
    b[tt]=y;
    c[tt]=z;
    nxt[tt]=fst[x];
    fst[x]=tt;
}
int t1,b1[N],c1[N],nxt1[N],fst1[N];
void cr1(int x,int y,int z)
{
    t1++;
    b1[t1]=y;
    c1[t1]=z;
    nxt1[t1]=fst1[x];
    fst1[x]=t1;
}
void dfs(int x,int y)
{
    dfn[x]=++td;
    f[x][0]=y;
    dis[x]=dis[y]+1;
    int i;
    fo(i,1,20) f[x][i]=f[f[x][i-1]][i-1],g[x][i]=g[x][i-1]+g[f[x][i-1]][i-1];
    for(int p=fst[x];p;p=nxt[p])
        if (b[p]!=y)
        {
            g[b[p]][0]=c[p];
            dfs(b[p],x);
        }
}
int Lca(int x,int y)
{
    if (dis[x]<dis[y]) swap(x,y);
    int i;
    fd(i,20,0) if (dis[f[x][i]]>=dis[y]) x=f[x][i];
    if (x==y) return x;
    fd(i,20,0) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
int Len(int x,int y)
{
    if (dis[x]<dis[y]) swap(x,y);
    int ret=0,i;
    fd(i,20,0) if (dis[f[x][i]]>=dis[y]) ret+=g[x][i],x=f[x][i];
    return ret;
}
void thr(int x)
{
    d[++d[0]]=x;
    F[x][0]=(pd[x]==q);
    int y;
    for(int p=fst1[x];p;p=nxt1[p])
    {
        y=b1[p];
        thr(y);
        go[y][0]=F[y][0];
        F[x][0]+=F[y][0];
        go[y][1]=F[y][1]+F[y][0]*c1[p];
        F[x][1]+=go[y][1];
        go[y][2]=F[y][2]+2*F[y][1]*c1[p]+F[y][0]*c1[p]*c1[p];
        F[x][2]+=go[y][2];
    }
}

void dp(int x)
{
    for (int p=fst1[x];p;p=nxt1[p])
    {
        y=b1[p];
        upf[y][0]=upf[x][0]+F[x][0]-go[y][0];
        upf[y][1]=upf[x][1]+F[x][1]-go[y][1];
        upf[y][2]=upf[x][2]+F[x][2]-go[y][2]+upf[y][1]*c1[p]*2+upf[y][0]*c1[p]*c1[p];
        upf[y][1]+=upf[y][0]*c1[p];
        dp(y);
    }
}
void solve(int x)
{
    for (int p=fst1[x];p;p=nxt1[p])
    {
        y=b1[p];
        len=c1[p];
        tmp0=upf[y][0];
        tmp1=upf[x][1]+F[x][1]-go[y][1];
        med=(tmp1+len*tmp0-F[y][1])/db(tmp0+F[y][0]);
        z=y;
        X=0;
        fd(i,20,0) if (X+g[z][i]<=med&&dis[f[z][i]]>=dis[x]) X+=g[z][i],z=f[z][i];
        if (dis[f[z][0]]>=dis[x]&&med-db(X)>db(X+g[z][0])-med) X+=g[z][0],z=f[z][0];
        C=F[y][2]+upf[x][2]+F[x][2]-go[y][2];
        val=C+X*X*F[y][0]+2*X*F[y][1]+(len-X)*(len-X)*tmp0+2*(len-X)*tmp1;
        cmin(ans,val);
        solve(b1[p]);
    }
}
int main()
{
    freopen("t8.in","r",stdin);
    freopen("tree.out","w",stdout);
    n=read();m=read();
    fo(i,1,n-1)
    {
        x=read();y=read();z=read();
        cr(x,y,z);
        cr(y,x,z);
    }
    fo(i,1,n) Log[i]=trunc(log(i)/log(2));
    dfs(1,0);
    fo(q,1,m)
    {
        fo(i,1,d[0]) 
        {
            fst1[d[i]]=0;
            fo(j,0,2) upf[d[i]][j]=F[d[i]][j]=0;
        }
        d[0]=0;
        t1=0;
        K=read();
        fo(i,1,K) a[i]=read(),pd[a[i]]=q;
        sort(a+1,a+1+K,cmp);
        sta[st=1]=a[1];
        fo(i,2,K)
        {
            lca=Lca(a[i],sta[st]);
            lst=0;
            while (st&&dis[lca]<dis[sta[st]])
            {
                if (dis[sta[st-1]]>=dis[lca]) cr1(sta[st-1],sta[st],Len(sta[st-1],sta[st]));
                lst=sta[st--];
            } 
            if (lca!=sta[st]) 
            {
                if (dis[lca]<dis[lst]) cr1(lca,lst,Len(lca,lst));
                sta[++st]=lca;
            }
            sta[++st]=a[i];
        }
        while (st>1) cr1(sta[st-1],sta[st],Len(sta[st-1],sta[st])),st--;
        rt=sta[1];
        st--;
        ans=1e18;
        d[0]=0;
        thr(rt);
        dp(rt);
        solve(rt);
        if (K==1) ans=0;
        Print(ans);
    }
}

猜你喜欢

转载自blog.csdn.net/zltjohn/article/details/80191008