bfprt

摘要: BFPRT 算法:1973 年, Blum 、 Floyd 、 Pratt 、 Rivest 、 Tarjan 集体出动,合写了一篇题为 “Time bounds for selection” 的论文,给出了一种在数组中选出第 k 大元素的算法,俗称"中位数之中位数算法"。依靠一种精心设计的 pivot 选取方法,该算法从理论上保证了最坏情形下的线性时间复杂度,打败了平均线性、最坏 O(n^2) 复杂度的传统算法。一群大牛把递归算法的复杂度分析玩弄于股掌之间,构造出了一个当之无愧的来自圣经的算法。

参考: http://en.wikipedia.org/wiki/Median_of_medians

算法步骤:

  1. 将 n 个元素每 5 个一组,分成 n/5 (上界)组。

  2. 取出每一组的中位数,任意排序方法,比如插入排序。

  3. 递归的调用 select 算法查找上一步中所有中位数的中位数,设为 x,偶数个中位数的情况下设定为选取中间小的一个。

  4. 用 x 来分割数组,设小于等于 x 的个数为 m ,大于 x 的个数即为 n-m。

  5. 若 k==m,返回 x;若 k<m,在小于 x 的元素中递归查找第 k 小的元素;若 k>m,在大于 x 的元素中递归查找第 k-m 小的元素。

终止条件:n=1 时,返回的即是 k 小元素。

性能分析:

划分时以5个元素为一组求取中位数,共得到n/5个中位数,再递归求取中位数,复杂度为T(n/5)。

得到的中位数x作为主元进行划分,在n/5个中位数中,主元x大于其中1/2*n/5=n/10的中位数,而每个中位数在其本来的5个数的小组中又大于或等于其中的3个数,所以主元x至少大于所有数中的n/10*3=3/10*n个。同理,主元x至少小于所有数中的3/10*n个。即划分之后,任意一边的长度至少为3/10,在最坏情况下,每次选择都选到了7/10的那一部分,则递归的复杂度为T(7/10*n)。

在每5个数求中位数和划分的函数中,进行若干个次线性的扫描,其时间复杂度为c*n,其中c为常数。其总的时间复杂度满足 T(n) <= T(n/5) + T(7/10*n) + c * n

我们假设T(n)=x*n,其中x不一定是常数(比如x可以为n的倍数,则对应的T(n)=O(n^2))。则有 x*n <= x*n/5 + x*7/10*n + c*n,得到 x<=10*c。于是可以知道x与n无关,T(n)<=10*c*n,为线性时间复杂度算法。而这又是最坏情况下的分析,故BFPRT可以在最坏情况下以线性时间求得n个数中的第k个数。

我参考别人的代码,加了一点自己理解的注释

#include <iostream>
using namespace std;

void swap(int *array, int s,int e)
{
    int temp;
    temp = array[s];
    array[s] = array[e];
    array[e] = temp;
}

// BFPRT: 在数组中寻找第 k 小元素(Top k)
/************************************************************************
FindKthSmallest(Array, k)
pivot = some pivot element of the array.
L = Set of all elements smaller than pivot in Array
R = Set of all elements greater than pivot in Array
if |L| > k FindKthSmalles(L, k)
else if(|L|+1 == k) return pivot
else return FindKthSmallest(R, k-|L|+1)
************************************************************************/

void insertsort(int *array_t, int start, int end)
{
    for (int i = start; i <= end; i++) {
        int inserted_data = array_t[i];
        int j = i;
        for (; j > start && inserted_data < array_t[j - 1]; j--) {
                array_t[j] = array_t[j - 1];
        }
        if (j != i) {
            array_t[j] = inserted_data;        
        }
    }
}

int partition(int *array_t, int low, int high, int pivot_index)
{
    int pivot_value = array_t[pivot_index];
    swap(array_t, low, pivot_index);
    while (low < high) {
        while (low < high && array_t[high] >= pivot_value) {
            high--;
        }
        if (low < high) {
            array_t[low++] = array_t[high];
        }
        while (low < high && array_t[low] <= pivot_value) {
            low++;
        }
        if (low < high) {
            array_t[high--] = array_t[low];
        }
    }
    array_t[low] = pivot_value;
    return low;
}

// 五划分中项:中位数的中位数(the median of medians algorithm)
// Return the kth value
int select(int *array_t, int left, int right, int k)
{
    const int k_group_size = 5;
    int size = right - left + 1;
    if (size <= k_group_size) {
        insertsort(array_t, left, right);
        return array_t[k + left - 1];
    }
    // (right - left) / 2 + left
    const int num_group = (size % k_group_size) > 0 ? (size / k_group_size) + 1 : (size / k_group_size);

    for (int i = 0; i < num_group; i++) {
        int sub_left = left + i * k_group_size;
        int sub_right = sub_left + k_group_size - 1;
        if (sub_right > right) {
            sub_right = right;
        }
        insertsort(array_t, sub_left, sub_right);
        // IMPORTANT !!
        // Place these median in front of array_t, so as to recurse to find the median of median
        int median = sub_left + ((sub_right - sub_left) >> 1);
        swap(array_t, left + i, median);
    }

    // Get the index of median
    int pivot_index = left + ((num_group - 1) >> 1);

    // Recurse to call and place the median on the pivot_index, without care about the median value
    // Because the value of pivot_index must be the median after select function recursive call.
    /*
        //k = (num_group + 1) >> 1 中位数 防止k=0,所以num_group+1
    //int select(int *array_t, int left, int right, int k)
    {
    const int k_group_size = 5;
    int size = right - left + 1;
    if (size <= k_group_size) {
        insertsort(array_t, left, right);
        return array_t[k + left - 1];                //k + left - 1=(num_group+1)/2 + left -1=num_group/2 + left - 1/2 中位数
    }*/
    select(array_t, left, left + num_group - 1, (num_group + 1) >> 1);
    int mid_index = partition(array_t, left, right, pivot_index);
    int _ith = mid_index - left + 1;
    // _ith_element == array_t[_ith]
    if (k == _ith) {
        return array_t[mid_index];
    } else if (k < _ith) {
        return select(array_t, left, mid_index - 1, k);
    } else {
        return select(array_t, mid_index + 1, right, k - _ith);
    }
}

int main()
{
    int k = 8; // 1 <= k <= array.size
    int array[20] = { 11,9,10,1,13,8,15,0,16,2,17,5,14,3,6,18,12,7,19,4 };

    cout << "原数组:";
    for (int i = 0; i < 20; i++)
        cout << array[i] << " ";
    cout << endl;

    // 因为是以 k 为划分,所以还可以求出第 k 小值
    cout << "第 " << k << " 小值为:" << array[select(array, 0, 19, k)] << endl;

    cout << "变换后的数组:";
    for (int i = 0; i < 20; i++)
        cout << array[i] << " ";
    cout << endl;

    return 0;
}


猜你喜欢

转载自blog.csdn.net/wojiuguowei/article/details/84107261