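题意：给定长度为 n 的数组 a 和 b（元素可正可负），选一个区间 [l,r]，最大化 min(a[l..r]) × (b[l]+…+b[r])。（题面原文缺失，以上由代码推断；代码中答案初值为 0，即输出至少为 0，需对照原题确认。）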
分析
首先，对 a 的每个位置 i，以 a[i] 为区间最小值的极大区间是确定的；这些区间两两之间要么互相包含、要么不相交（按最小值位置递归划分，形成树形结构）。
然后，对每个位置 i，需要在它对应的区间内，找出包含 i 的子区间的 b 之和。由于 a 有正有负，b 的子区间和要同时维护最大值和最小值：a[i] > 0 时取最大和，a[i] < 0 时取最小和。
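问题于是转化为：在一个区间内，求包含指定位置的子区间和的最大值与最小值。
做法是用线段树找出当前区间 [l,r] 中 a 的最小值位置 mid，以 mid 为界把区间分成左右两部分，分别递归求解，再合并。每个区间记录：最小/最大前缀和、最小/最大后缀和、区间总和。包含 mid 的子区间和的最小值 = b[mid] + min(0, 左部分最小后缀和) + min(0, 右部分最小前缀和)，最大值同理。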
代码
#include <bits/stdc++.h>
// Competitive-programming shorthands. `mp` is used by qry(); `pb` and `pii`
// are not referenced anywhere in the visible code.
#define mp make_pair
#define pb push_back
// Sentinel "infinity" (~4.56e18, still representable in long long).
// find() recognises an empty range by a node whose bmin equals inf.
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pii;
// Capacity of the 1-indexed input arrays; the segment tree uses N<<2 slots.
const ll N = 3e6 + 10;
// Fast reader: returns the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If EOF is reached before any digit,
// returns 0 instead of spinning forever (EOF compares below '0', so the old
// skip loop never terminated, and storing getchar() in a char could also
// lose the EOF value).
inline long long rd()
{
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    long long p = 0;
    long long f = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        p = p * 10 + (ch - '0');
        ch = getchar();
    }
    return p * f;
}
// Input arrays, 1-indexed: a[1..n] and b[1..n]; n is read first in main().
ll a[N];
ll b[N];
ll n;
// Summary of a range [l, r] of b, as built by find().
// Field order matters: it is filled with brace initialisers elsewhere.
struct node
{
    ll lmin;  // minimum prefix sum b[l] + ... + b[k], l <= k <= r
    ll rmin;  // minimum suffix sum b[k] + ... + b[r], l <= k <= r
    ll lmax;  // maximum prefix sum
    ll rmax;  // maximum suffix sum
    ll bmin;  // minimum sum of a subarray of [l, r] that contains the split position
    ll bmax;  // maximum sum of a subarray of [l, r] that contains the split position
    ll s;     // total sum b[l] + ... + b[r]
};
// Summary returned for an empty range (l > r). main() sets its bmin to inf,
// which find() uses to detect empty children.
node ii;
// Pointer-style segment tree over a[1..n] for range-minimum-with-position.
ll mn[N << 2];   // minimum of a over the node's range
ll mnp[N << 2];  // a position where that minimum occurs
ll lc[N << 2];   // left child id (0 = not yet allocated)
ll rc[N << 2];   // right child id (0 = not yet allocated)
ll rt;           // root id
ll tot;          // number of allocated nodes
// Builds the segment-tree node for a[l..r] into u, allocating an id if u is 0.
// Each node stores the minimum of its range and a position of that minimum;
// when both halves tie, the right half's position is kept.
void build(ll &u,ll l,ll r)
{
    if (u == 0) u = ++tot;
    if (l == r) {
        mn[u] = a[l];
        mnp[u] = l;
        return;
    }
    const ll m = (l + r) / 2;
    build(lc[u], l, m);
    build(rc[u], m + 1, r);
    // Strict comparison: ties go to the right child.
    const ll src = (mn[lc[u]] < mn[rc[u]]) ? lc[u] : rc[u];
    mn[u] = mn[src];
    mnp[u] = mnp[src];
}
// Returns (min of a over [l, r], a position of it) using node u covering [L, R].
// Requires L <= l <= r <= R. When the query straddles the midpoint, the
// smaller (value, position) pair wins, so ties prefer the left half.
pair<ll,ll> qry(ll u,ll L,ll R,ll l,ll r)
{
    if (L == l && R == r) return mp(mn[u], mnp[u]);  // exact cover: stored answer
    const ll m = (L + R) >> 1;
    if (r <= m) return qry(lc[u], L, m, l, r);          // only the left half
    if (l > m) return qry(rc[u], m + 1, R, l, r);       // only the right half
    const pair<ll,ll> fromLeft = qry(lc[u], L, m, l, m);
    const pair<ll,ll> fromRight = qry(rc[u], m + 1, R, m + 1, r);
    return (fromRight < fromLeft) ? fromRight : fromLeft;
}
// Best value of min(a[l..r]) * sum(b[l..r]) found so far (never below 0);
// updated by find() and printed by main().
ll ans{0};
// Merges the summaries of [l, mid-1] (L) and [mid+1, r] (R) around the pivot
// value v = b[mid] into the summary of [l, r]. A part whose bmin equals inf
// (the `ii` sentinel) is treated as empty and contributes nothing. The result's
// bmin/bmax are the min/max sums over subarrays of [l, r] that contain mid.
static node combine_around_pivot(const node &L, ll v, const node &R)
{
    const bool hasL = L.bmin != inf;
    const bool hasR = R.bmin != inf;
    // Best way to extend a sum across the pivot into each side (0 = don't extend).
    const ll lExtMin = hasL ? std::min<ll>(0, L.rmin) : 0;
    const ll lExtMax = hasL ? std::max<ll>(0, L.rmax) : 0;
    const ll rExtMin = hasR ? std::min<ll>(0, R.lmin) : 0;
    const ll rExtMax = hasR ? std::max<ll>(0, R.lmax) : 0;
    const ll ls = hasL ? L.s : 0;
    const ll rs = hasR ? R.s : 0;

    node res;
    res.bmin = v + lExtMin + rExtMin;
    res.bmax = v + lExtMax + rExtMax;
    // A prefix of [l, r] either stays inside L, or covers all of L, the pivot
    // and a (possibly empty) prefix of R.
    res.lmin = ls + v + rExtMin;
    res.lmax = ls + v + rExtMax;
    if (hasL) {
        res.lmin = std::min(res.lmin, L.lmin);
        res.lmax = std::max(res.lmax, L.lmax);
    }
    // Symmetric for suffixes of [l, r].
    res.rmin = rs + v + lExtMin;
    res.rmax = rs + v + lExtMax;
    if (hasR) {
        res.rmin = std::min(res.rmin, R.rmin);
        res.rmax = std::max(res.rmax, R.rmax);
    }
    res.s = ls + v + rs;
    return res;
}

// Splits [l, r] at the position of min(a[l..r]) (via qry), recursively
// summarises both sides, and returns the summary of [l, r]. For every split
// position mid with minimum value x, it updates the global `ans` with
// x * (best b-sum of a subarray containing mid): the minimum sum when x < 0,
// the maximum when x > 0; x == 0 contributes 0, which is ans's initial value.
// Returns `ii` when l > r.
// Preconditions: build() has been run over [1, n]; ii.bmin == inf.
//
// Implemented iteratively with an explicit stack. The split tree can be n
// levels deep (e.g. ascending a always splits at l), which overflowed the
// call stack for n around 3e6 in the recursive version.
node find(ll l,ll r)
{
    if (l > r) return ii;

    struct Frame {
        ll lo, hi;   // segment handled by this frame
        ll mid;      // position of min(a[lo..hi]) once known
        ll amin;     // that minimum value
        int stage;   // 0: not split yet, 1: left child pending, 2: right child pending
        node left;   // summary of [lo, mid-1] once finished
    };
    std::vector<Frame> stk;
    stk.push_back(Frame{l, r, 0, 0, 0, ii});
    node done = ii;  // summary produced by the most recently finished segment

    while (!stk.empty()) {
        Frame &f = stk.back();
        if (f.stage == 0) {
            const std::pair<ll, ll> x = qry(rt, 1, n, f.lo, f.hi);
            f.amin = x.first;
            f.mid = x.second;
            f.stage = 1;
            const ll cl = f.lo, cr = f.mid - 1;
            // push_back may reallocate: `f` must not be used after it.
            if (cl <= cr) stk.push_back(Frame{cl, cr, 0, 0, 0, ii});
            else done = ii;
            continue;
        }
        if (f.stage == 1) {
            f.left = done;  // left child has just finished
            f.stage = 2;
            const ll cl = f.mid + 1, cr = f.hi;
            if (cl <= cr) stk.push_back(Frame{cl, cr, 0, 0, 0, ii});
            else done = ii;
            continue;
        }
        // Stage 2: `done` holds the right child's summary.
        const node nw = combine_around_pivot(f.left, b[f.mid], done);
        if (f.amin < 0) ans = std::max(ans, nw.bmin * f.amin);
        else if (f.amin > 0) ans = std::max(ans, nw.bmax * f.amin);
        stk.pop_back();  // `f` dangles from here on
        done = nw;
    }
    return done;
}
// Reads n, then a[1..n] and b[1..n], from stdin and prints the maximum of
// min(a[l..r]) * sum(b[l..r]) (never below 0, since ans starts at 0).
int main()
{
    n = rd();
    for (ll i = 1; i <= n; i++) a[i] = rd();
    for (ll i = 1; i <= n; i++) b[i] = rd();
    ans = 0;
    // build(rt, 1, 0) would recurse forever (l == r never holds), so an empty
    // array is answered directly with the initial value.
    if (n >= 1) {
        rt = tot = 0;
        build(rt, 1, n);
        ii = {0,0,0,0,inf,-inf,0};  // empty-range sentinel: bmin == inf
        find(1, n);
    }
    printf("%lld\n", ans);
    return 0;
}