True Liars[扩展域并查集]

传送门

每个点拆成两个,表示好人或坏人

我们合并集合后,发现存在几组对立的集合

也就是说这个集合和与它对立的集合只能选一个

我们用rt1[i] , rt2[i] 表示第i个集合 和 与第i个集合对立的集合

cnt1,cnt2表示该集合好人的个数

用f[i][j]表示到第i个集合,好人为j的方案数

f[i][j]+=f[i-1][j-cnt1[i]] (f[i-1][cnt1[i-1]!=0)

f[i][j]+=f[i-1][j-cnt2[i]](f[i-1][j-cnt2[i]!=0)

同时记录from[i][j] 表示f[i][j]选的是第i组集合的哪一个集合


#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 305*2*2
using namespace std;
int fa[N],n,p1,p2,rt1[N],rt2[N],tot;
int cnt1[N],cnt2[N],f[N/2][N/2],from[N/2][N/2];
int pre[N],vis[N],ans[N];
void dfs(int i,int j){
	if(i==0 && j==0) return;
	if(from[i][j]==1) vis[pre[i]]=1,dfs(i-1,j-cnt1[i]);
	else vis[pre[i]+p1+p2]=1,dfs(i-1,j-cnt2[i]);
}
int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);}
void init(){
	memset(f,0,sizeof(f));
	memset(from,0,sizeof(from));
	memset(vis,0,sizeof(vis));
	memset(ans,0,sizeof(ans));
	memset(cnt1,0,sizeof(cnt1));
	memset(cnt2,0,sizeof(cnt2));
	memset(rt1,0,sizeof(rt1)); tot=0;
	memset(rt2,0,sizeof(rt2));
}
int main(){
	while(scanf("%d%d%d",&n,&p1,&p2) && (n+p1+p2)){
		init();
		for(int i=1;i<=2*(p1+p2);i++) fa[i]=i;
		for(int i=1;i<=n;i++){
			int x,y,x1,y1; char s[5];
			scanf("%d%d%s",&x,&y,s);
			x1=find(x+p1+p2),y1=find(y+p1+p2),x=find(x),y=find(y);
			if(s[0]=='y') {
				if(x!=y) fa[x]=y;
				if(x1!=y1) fa[x1]=y1;
			}
			if(s[0]=='n'){
				if(x!=y1) fa[x]=y1;
				if(x1!=y) fa[x1]=y;
			}
		}
		for(int i=1;i<=p1+p2;i++){
			if(find(i)==i) rt1[i]=++tot,pre[tot]=i;
		}
		for(int i=1;i<=tot;i++) rt2[pre[i]+p1+p2]=i;
		for(int i=1;i<=p1+p2;i++){
			cnt1[rt1[find(i)]]++ , cnt2[rt2[find(i)]]++;
		}
		f[0][0]=1;
		for(int i=1;i<=tot;i++)
			for(int j=min(cnt1[i],cnt2[i]);j<=p1;j++){ 
				if(f[i-1][j-cnt1[i]]) f[i][j]+=f[i-1][j-cnt1[i]] , from[i][j]=1;
				if(f[i-1][j-cnt2[i]]) f[i][j]+=f[i-1][j-cnt2[i]] , from[i][j]=2; 
			}
		if(f[tot][p1]!=1) {printf("no\n"); continue;}
		dfs(tot,p1);
		for(int i=1;i<=tot;i++) {
			if(vis[pre[i]]) for(int j=1;j<=p1+p2;j++) if(find(j)==pre[i]) ans[j]=1;
			if(vis[pre[i]+p1+p2]) for(int j=1;j<=p1+p2;j++) if(find(j)==pre[i]+p1+p2) ans[j]=1;
		}
		for(int i=1;i<=p1+p2;i++) if(ans[i]) printf("%d\n",i);
		printf("end\n");
	}return 0;
}
扫描二维码关注公众号,回复: 4135834 查看本文章

猜你喜欢

转载自blog.csdn.net/sslz_fsy/article/details/84196532