【Codeforces Gym-101933-E-Explosion Exploit】记忆化搜索

Explosion Exploit

题目链接:

https://codeforces.com/gym/101933/problem/E

Description

在这里插入图片描述

Input

在这里插入图片描述

Output

在这里插入图片描述

Sample Input

1 2 2
2
1 1

Sample Output

0.33333333

题意

己方有n个士兵,敌方有m个士兵,每个士兵有h[i]的生命值,现在造成d次伤害,每次随机对一名存活的士兵造成一点伤害,问可以把敌方所有士兵杀死的概率是多少,

题解

首先考虑最暴力的做法,因为 n 5 , m 5 , d 100 , h [ i ] 6 n\leq5,m\leq5,d\leq100,h[i]\leq6 ,所以己方和敌方血量状态一共有 6 5 6^5 种,所以定义

dp方程 d p [ 105 ] [ 6 5 ] [ 6 5 ] dp[105][6^5][6^5] 表示剩余伤害为i,己方血量状态为j,敌方血量状态为k时消灭敌方的概率。

转移的时候可以根据状态计算出总人数,和敌方存活状态,直接转移即可,注意如果敌方人数=0那么概率为1,如果d=0时敌方人数仍不为0那么概率为0。之后暴力转移即可。但是这里 105 6 5 6 5 = 6 e 9 105* 6^5*6^5 = 6e9 ,显然是开不下的,所以我们需要优化这个DP数组。我们发现对于双方血量状态我们是不需要顺序的,也就是每个人都是等价的。所以我们可以强制让这个状态是有序的,这样我们发现己方的状态只有大概500种,所以dp数组开成 d p [ 105 ] [ 500 ] [ 500 ] dp[105][500][500] 即可,之后记忆化搜索即可解决这个问题,需要注意的是每次枚举血量-1时仍然要强行排序让状态时有序的。

代码

#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<string.h>
#include<map>
#include<vector>
#include<set>
#include<queue>
#include<time.h>
#include<math.h>
using namespace std;

//***********************IO**********************************
namespace fastIO
{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    bool IOerror=0;
    inline char nc()
    {
        static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
        if (p1==pend)
        {
            p1=buf;
            pend=buf+fread(buf,1,BUF_SIZE,stdin);
            if (pend==p1)
            {
                IOerror=1;
                return -1;
            }
        }
        return *p1++;
    }
    inline bool blank(char ch)
    {
        return ch==' '|ch=='\n'||ch=='\r'||ch=='\t';
    }
    inline void read(int &x)
    {
        bool sign=0;
        char ch=nc();
        x=0;
        for (; blank(ch); ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (; ch>='0'&&ch<='9'; ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    #undef OUT_SIZE
    #undef BUF_SIZE
};
using namespace fastIO;
//************************************************************************

#define ok cout<<"OK"<<endl;
#define dbg(x) cout<<#x<<" = "<<x<<endl;
#define dbg2(x1,x2) cout<<#x1<<" = "<<x1<<" "<<#x2<<" = "<<x2<<endl;
#define dbg3(x1,x2,x3) cout<<#x1<<" = "<<x1<<" "<<#x2<<" = "<<x2<<" "<<#x3<<" = "<<x3<<endl;
#define print(a,n) for(int i=1;i<=n;i++) cout<<a[i]<<" ";cout<<endl;
#define pb push_back
#define Fi first
#define Se second
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pil pair<int,ll>
#define pll pair<ll,ll>

const double eps = 1e-8;
const double PI = acos(-1.0);
const int Mod = 1000000007;
const int INF = 0x3f3f3f3f;
const ll LL_INF = 0x3f3f3f3f3f3f3f3f;
const int maxn = 2e5+10;
map<vector<int> ,int> code;
vector<int> v[500];
double dp[500][500][105];
double dfs(int sta,int stb,int d)
{
    if(dp[sta][stb][d]>-0.5) return dp[sta][stb][d];
    vector<int> a = v[sta];
    vector<int> b = v[stb];
    int sz1=0,sz2=0;
    for(int i=0;i<5;i++) if(a[i]>0) sz1++;
    for(int i=0;i<5;i++) if(b[i]>0) sz2++;
    if(sz2==0) return dp[sta][stb][d]=1;
    if(d==0) return dp[sta][stb][d]=0;
    double ans=0;
    for(int i=0;i<5;i++)
    {
        if(a[i]>0)
        {
            a[i]--;
            vector<int> tt=a;
            sort(tt.begin(),tt.end());
            int sta=code[tt];
            ans=ans+1.0*dfs(sta,stb,d-1)/(sz1+sz2);
            a[i]++;
        }
    }
    for(int i=0;i<5;i++)
    {
        if(b[i]>0)
        {
            b[i]--;
            vector<int> tt=b;
            sort(tt.begin(),tt.end());
            int stb=code[tt];
            ans=ans+1.0*dfs(sta,stb,d-1)/(sz1+sz2);
            b[i]++;
        }
    }
    return dp[sta][stb][d]=ans;
}
int main()
{
    //freopen(".in","r",stdin);
    int cnt=0;
    vector<int> tt;
    for(int i=0;i<500;i++) for(int j=0;j<500;j++) for(int k=0;k<105;k++) dp[i][j][k]=-1;
    for(int i=0;i<=6;i++)
    {
        tt.push_back(i);
        for(int j=i;j<=6;j++)
        {
            tt.push_back(j);
            for(int k=j;k<=6;k++)
            {
                tt.push_back(k);
                for(int l=k;l<=6;l++)
                {
                    tt.push_back(l);
                    for(int m=l;m<=6;m++)
                    {
                        tt.push_back(m);
                        v[cnt]=tt;
                        code[tt]=cnt++;
                        tt.pop_back();
                    }
                    tt.pop_back();
                }
                tt.pop_back();
            }
            tt.pop_back();
        }
        tt.pop_back();
    }
    int n,m,d;
    scanf("%d%d%d",&n,&m,&d);
    vector<int> a(5),b(5);
    for(int i=0;i<5;i++) a[i]=b[i]=0;
    for(int i=0;i<n;i++) scanf("%d",&a[i]);
    for(int i=0;i<m;i++) scanf("%d",&b[i]);
    sort(a.begin(),a.end());
    sort(b.begin(),b.end());
    int sta=code[a],stb=code[b];
    double ans=dfs(sta,stb,d);
    printf("%.10f\n",ans);
    return 0;
}
发布了299 篇原创文章 · 获赞 117 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/qq_38891827/article/details/101442562