最近点对问题

    最近遇到一个最近点对的问题,但是用老师讲的方法去做超时了,只好采用网上的方法去做。可是在本地生成数据测试的结果却显示老师讲的方法更快,让我很是不解,所以记录下来,等以后解决。
    题目链接https://www.nowcoder.com/acm/contest/59/E
题目描述

给你一个长为n的序列a。定义f(i,j)=(i-j)2+g(i,j)2
g是这样的一个函数
这里写图片描述
求最小的f(i,j)的值,i!=j

输入描述:

第一行一个数n
之后一行n个数表示序列a

输出描述:

输出一行一个数表示答案

示例1

输入
4
1 0 0 -1
输出
1

备注:

对于100%的数据,2 <= n <= 100000 , |ai| <= 10000

解题思路:
    把下标i当成x轴坐标,前i个数字之和当成y轴坐标,就可以把这道题看成最近点对问题(准确来说是最近点对距离的平方)。
    解决最近点对用的是分治的思想,首先将所有点按x轴坐标大小排序(本题已经有序),然后根据x轴坐标的中位数将点分成左右两部分(即对半分),分别求最近点对的距离。如果点的数量小于等于3,就枚举求解。(即递归出口)
    于是最后的答案就是左边的最近点对距离dl、右边最近点对距离dr和左边点和右边点组成点对的最小距离dm。我们可以先得到dl和dr中的最小值d,然后将(x轴坐标中位数-d, x轴坐标中位数+d)范围内的点纳入左右之间产生最近点对的考虑范围。
    两种方法(一种在OJ上TLE,另一种可以AC)在如何处理左右之间形成的点对的做法上有所不同。具体见下面的代码。


AC代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <map>
#include <algorithm>
#include <iomanip>

using namespace std;

const int MAXN = 100003;
const long long INF = (1LL<<60) - 1;
int sum[MAXN], idx[MAXN];

bool cmp(int i, int j)
{
    return sum[i] < sum[j];
}

long long dist(int i, int j)
{
    int di = j - i;
    long long dx = sum[j] - sum[i];

    return di*di + dx*dx;
}

long long findSPP(long long lhs, long long rhs)
{
    long long res;
    if(rhs - lhs <=2)
    {
        res = INF;
        for(int i=lhs; i<rhs; ++i)
        {
            for(int j=i+1; j<=rhs; ++j)
            {
                res = min(res, dist(i, j));
            }
        }
    }else{
        long long mid = (lhs+rhs) / 2;
        long long dl = findSPP(lhs, mid);
        long long dr = findSPP(mid+1, rhs);
        res = min(dl, dr);

        long long d = sqrt(res);
        int k = 0;
        for(int i=max(mid-d, lhs); i<=min(mid+d, rhs); ++i)     
            idx[k++] = i;      
        sort(idx, idx+k, cmp);

        for(int i=0; i<k; ++i)
        {
            for(int j=i+1; j<k; ++j)
            {
                if(sum[idx[j]]-sum[idx[i]] >= d) break;   //因为按y值大小排过序
                res = min(res, dist(idx[j], idx[i]));
            }
        }
    }

    return res;
}

int main()
{
    int n;
    while(cin>>n)
    {
        memset(sum, 0, sizeof(sum));
        for(int i=1; i<=n; ++i)
        {
            scanf("%d", &sum[i]);
            sum[i] += sum[i-1];
        }
        long long ans = findSPP(1, n);
        cout<<ans<<endl;
    }
    return 0;
}

TLE代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <map>
#include <algorithm>
#include <iomanip>

using namespace std;

const int MAXN = 100003;
const long long INF = (1LL<<63) - 1;
int sum[MAXN];

long long dist(int i, int j)
{
    int di = j - i;
    long long dx = sum[j] - sum[i];

    return di*di + dx*dx;
}

long long findSPP(long long lhs, long long rhs)
{
    long long res;
    if(rhs - lhs <=2)
    {
        res = INF;
        for(int i=lhs; i<rhs; ++i)
        {
            for(int j=i+1; j<=rhs; ++j)
            {
                res = min(res, dist(i, j));
            }
        }
    }else{
        long long mid = (lhs+rhs) / 2;
        long long dl = findSPP(lhs, mid);
        long long dr = findSPP(mid+1, rhs);
        res = min(dl, dr);

        long long d = sqrt(res);
        for(int i=max(mid-d, lhs); i<=mid; ++i)    //左边需要考虑的点
        {
            for(int j=mid+1; j<=min(mid+d, rhs); ++j)     //右边需要考虑的点
            {
                if(abs(sum[j]-sum[i]) < d)
                {
                    res = min(res, dist(i, j));    //根据鸽巢原理可证明对于每个i,执行不会超过6次
                }
            }
        }
    }

    return res;
}

int main()
{
    int n;
    while(cin>>n)
    {
        memset(sum, 0, sizeof(sum));
        for(int i=1; i<=n; ++i)
        {
            scanf("%d", &sum[i]);
            sum[i] += sum[i-1];
        }
        long long ans = findSPP(1, n);
        cout<<ans<<endl;
    }
    return 0;
}

    虽然后一种方法没有在OJ上通过,但本地测试却显示后一种方法更快,下面附上测试代码,如果有人知道是为什么,请不吝赐教!
本地测试代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <map>
#include <algorithm>
#include <iomanip>
#include <ctime>

using namespace std;

const int MAXN = 100003;
const long long INF = (1LL<<60) - 1;
int sum[MAXN], idx[MAXN];

bool cmp(int i, int j)
{
    return sum[i] < sum[j];
}

long long dist(int i, int j)
{
    int di = j - i;
    long long dx = sum[j] - sum[i];

    return di*di + dx*dx;
}

long long tle(long long lhs, long long rhs)
{
    long long res;
    if(rhs - lhs <=2)
    {
        res = INF;
        for(int i=lhs; i<rhs; ++i)
        {
            for(int j=i+1; j<=rhs; ++j)
            {
                res = min(res, dist(i, j));
            }
        }
    }else{
        long long mid = (lhs+rhs) / 2;
        long long dl = tle(lhs, mid);
        long long dr = tle(mid+1, rhs);
        res = min(dl, dr);

        long long d = sqrt(res);
        for(int i=max(mid-d, lhs); i<=mid; ++i)
        {
            int c = 0;
            for(int j=mid+1; j<=min(mid+d, rhs); ++j)
            {
                if(abs(sum[j]-sum[i]) < d)
                {
                    res = min(res, dist(i, j));
                    ++c;
                }
            }
            if(c > 6) cout<<"理论有误!"<<endl;
        }
    }

    return res;
}

long long ac(long long lhs, long long rhs)
{
    long long res;
    if(rhs - lhs <=2)
    {
        res = INF;
        for(int i=lhs; i<rhs; ++i)
        {
            for(int j=i+1; j<=rhs; ++j)
            {
                res = min(res, dist(i, j));
            }
        }
    }else{
        long long mid = (lhs+rhs) / 2;
        long long dl = ac(lhs, mid);
        long long dr = ac(mid+1, rhs);
        res = min(dl, dr);

        long long d = sqrt(res);
        int k = 0;
        for(int i=max(mid-d, lhs); i<=min(mid+d, rhs); ++i)     
            idx[k++] = i;       
        sort(idx, idx+k, cmp);

        for(int i=0; i<k; ++i)
        {
            for(int j=i+1; j<k; ++j)
            {
                if(sum[idx[j]]-sum[idx[i]] >= d) break;
                res = min(res, dist(idx[j], idx[i]));
            }
        }
    }

    return res;
}

int main()
{
    int n;
    srand(time(NULL));
    for(int i=0; i<100; ++i)
    {
        cout<<"第"<<i<<"轮测试"<<'\t';
        n = rand()*1.0/RAND_MAX * 100000;
        cout<<"n = "<<n<<endl;

        memset(sum, 0, sizeof(sum));
        for(int i=1; i<=n; ++i)
        {
            sum[i] = rand()*1.0/RAND_MAX * 10000 - (rand()*1.0/RAND_MAX * 10000);
            sum[i] += sum[i-1];
        }
        clock_t start, end;
        double time1, time2;

        start = clock();
        long long ans1 = ac(1, n);
        end = clock();
        time1 = static_cast<double>(end - start) / CLOCKS_PER_SEC;

        start = clock();
        long long ans2 = tle(1, n);
        end = clock();
        time2 = static_cast<double>(end - start) / CLOCKS_PER_SEC;

        if(ans1 != ans2) cout<<"答案不一致\t"<<ans1<<'\t'<<ans2<<endl;

        if(time1 > time2) cout<<"ac程序时间更长\t"<<"ac time: "<<time1<<"\ttle time: "<<time2<<endl;
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/u011008379/article/details/79198757