bzoj 4182 shopping - 树dp - 点分治

考虑一个 O ( n 2 m ) O(n^2m) 的暴力,枚举一个点当根,然后按照dfs序的最后一次访问为阶段做dp,那么每个点要么其子树完全不选(等价于之考虑了Lvis[x]-1),要么就直接在Rvis[x]-1上选x(把x得子树合并起来)。用点分治优化上述过程即可。实际上也可以用dsu on tree做,常数更小。
点分治:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define gc getchar()
#define N 505
#define M 4005
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
inline int inn()
{
	int x,ch;while((ch=gc)<'0'||ch>'9');
	x=ch^'0';while((ch=gc)>='0'&&ch<='9')
		x=(x<<1)+(x<<3)+(ch^'0');return x;
}
struct edges{
	int to,pre;
}e[N<<1];int h[N],etop,n,m,q[M],w[N],c[N],d[N],t[M],tmp[M],dp[N][M],ans,qv[M];
inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; }
inline int ins(int *a,int x,int *b)
{
	if(c[x]>m) return memset(b,0,sizeof(int)*(m+1)),0;
	rep(y,0,c[x]-1)
	{
		int fp=1,rp=0;
		for(int i=0,j=y;j<=m;i++,j+=c[x])
		{
			if(fp<=rp&&q[fp]==i-d[x]-1) fp++;
			if(fp<=rp) t[i]=qv[fp]+i*w[x];else t[i]=0;
			int v=a[j]-i*w[x];while(fp<=rp&&qv[rp]<=v) rp--;
			q[++rp]=i,qv[rp]=v,tmp[j]=t[i];
		}
	}
	return memcpy(b,tmp,sizeof(int)*(m+1)),0;
}
int lst[N],Lcnt,sz[N],tms[N],out[N],vis[N],dfc;
int getsz(int x,int fa=0)
{
	lst[++Lcnt]=x,sz[x]=1;
	for(int i=h[x],y;i;i=e[i].pre)
		if((e[i].to^fa)&&!vis[y=e[i].to]) sz[x]+=getsz(y,x);
	return sz[x];
}
inline int getrt(int &x)
{
	for(int i=1,fsz=sz[x],t=fsz;i<=Lcnt;i++)
	{
		int y=lst[i],ysz=fsz-sz[y];
		for(int j=h[y];j;j=e[j].pre)
			if(!vis[e[j].to]&&sz[e[j].to]<sz[y])
				ysz=max(ysz,sz[e[j].to]);
		if(ysz<t) t=ysz,x=y;
	}
	return 0;
}
int gettms(int x,int fa=0)
{
	tms[++dfc]=x;
	for(int i=h[x],y;i;i=e[i].pre)
		if((e[i].to^fa)&&!vis[y=e[i].to]) gettms(y,x);
	return out[x]=dfc;
}
int solve(int x)
{
	Lcnt=0,getsz(x),getrt(x),vis[x]=1,dfc=0,gettms(x);
	memset(dp[dfc+1],0,sizeof(int)*(m+1));
	for(int i=dfc;i;i--)
	{
		ins(dp[i+1],tms[i],dp[i]);int y=out[tms[i]];
		rep(j,0,m) dp[i][j]=max(dp[i][j],dp[y+1][j]);
	}
	rep(i,0,m) ans=max(ans,dp[1][i]);
	for(int i=h[x],y;i;i=e[i].pre) if(!vis[y=e[i].to]) solve(y);
	return 0;
}
int main()
{
	for(int T=inn();T;T--)
	{
		n=inn(),m=inn(),ans=0;rep(i,1,n) w[i]=inn();
		rep(i,1,n) c[i]=inn();rep(i,1,n) d[i]=inn();
		memset(h,0,sizeof(int)*(n+1)),etop=0;int u,v;
		rep(i,1,n-1) u=inn(),v=inn(),add_edge(u,v),add_edge(v,u);
		memset(vis,0,sizeof(int)*(n+1)),solve(1),printf("%d\n",ans);
	}
	return 0;
}

dsu on tree:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<climits>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define gc getchar()
#define N 505
#define M 4005
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
#define inf (INT_MIN/10)
using namespace std;
inline int inn()
{
	int x,ch;while((ch=gc)<'0'||ch>'9');
	x=ch^'0';while((ch=gc)>='0'&&ch<='9')
		x=(x<<1)+(x<<3)+(ch^'0');return x;
}
struct edges{
	int to,pre;
}e[N<<1];int h[N],etop,n,m,q[M],w[N],c[N],d[N],t[M],tmp[M],dp[N][M],ans,qv[M],son[N],sz[N],infarr[M];
inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; }
inline int ins(int *a,int x,int *b)
{
	if(c[x]>m) return memcpy(dp[x],infarr,sizeof(int)*(m+1)),0;
	rep(y,0,c[x]-1)
	{
		int fp=1,rp=0;
		for(int i=0,j=y;j<=m;i++,j+=c[x])
		{
			if(fp<=rp&&q[fp]==i-d[x]-1) fp++;
			if(fp<=rp) t[i]=qv[fp]+i*w[x];else t[i]=inf;
			int v=a[j]-i*w[x];while(fp<=rp&&qv[rp]<=v) rp--;
			q[++rp]=i,qv[rp]=v,tmp[j]=t[i];
		}
	}
	return memcpy(b,tmp,sizeof(int)*(m+1)),0;
}
int dfs(int x,int fa)
{
	memcpy(dp[x],dp[fa],sizeof(int)*(m+1)),ins(dp[x],x,dp[x]);
	for(int i=h[x],y;i;i=e[i].pre) if((y=e[i].to)^fa) dfs(y,x);
	rep(i,0,m) dp[fa][i]=max(dp[fa][i],dp[x][i]);return 0;
}
int solve(int x,int fa=0)
{
	sz[x]=1,son[x]=0;
	for(int i=h[x],y;i;i=e[i].pre)
		if((y=e[i].to)^fa)
		{
			sz[x]+=solve(y,x);
			if(sz[y]>sz[son[x]]) son[x]=y;
		}
	memset(dp[x],0,sizeof(int)*(m+1));
	if(son[x]) rep(i,0,m) dp[x][i]=max(dp[son[x]][i],0);
	ins(dp[x],x,dp[x]);
	for(int i=h[x],y;i;i=e[i].pre)
		if((e[i].to^fa)&&(y=e[i].to)!=son[x]) dfs(y,x);
	rep(i,0,m) ans=max(ans,dp[x][i]);return sz[x];
}
int main()
{
	for(int T=inn();T;T--)
	{
		n=inn(),m=inn(),ans=0;rep(i,1,n) w[i]=inn();
		rep(i,1,n) c[i]=inn();rep(i,1,n) d[i]=inn();
		rep(i,1,m) infarr[i]=inf;infarr[0]=0;
		memset(h,0,sizeof(int)*(n+1)),etop=0;int u,v;
		rep(i,1,n-1) u=inn(),v=inn(),add_edge(u,v),add_edge(v,u);
		solve(1),printf("%d\n",ans);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/82790451