摘要: BFPRT 算法:1973 年, Blum 、 Floyd 、 Pratt 、 Rivest 、 Tarjan 集体出动,合写了一篇题为 “Time bounds for selection” 的论文,给出了一种在数组中选出第 k 大元素的算法,俗称"中位数之中位数算法"。依靠一种精心设计的 pivot 选取方法,该算法从理论上保证了最坏情形下的线性时间复杂度,打败了平均线性、最坏 O(n^2) 复杂度的传统算法。一群大牛把递归算法的复杂度分析玩弄于股掌之间,构造出了一个当之无愧的来自圣经的算法。
参考: http://en.wikipedia.org/wiki/Median_of_medians
算法步骤:
-
将 n 个元素每 5 个一组,分成 n/5 (上界)组。
-
取出每一组的中位数,任意排序方法,比如插入排序。
-
递归的调用 select 算法查找上一步中所有中位数的中位数,设为 x,偶数个中位数的情况下设定为选取中间小的一个。
-
用 x 来分割数组,设小于等于 x 的个数为 m ,大于 x 的个数即为 n-m。
-
若 k==m,返回 x;若 k<m,在小于 x 的元素中递归查找第 k 小的元素;若 k>m,在大于 x 的元素中递归查找第 k-m 小的元素。
终止条件:n=1 时,返回的即是 k 小元素。
性能分析:
划分时以5个元素为一组求取中位数,共得到n/5个中位数,再递归求取中位数,复杂度为T(n/5)。
得到的中位数x作为主元进行划分,在n/5个中位数中,主元x大于其中1/2*n/5=n/10的中位数,而每个中位数在其本来的5个数的小组中又大于或等于其中的3个数,所以主元x至少大于所有数中的n/10*3=3/10*n个。同理,主元x至少小于所有数中的3/10*n个。即划分之后,任意一边的长度至少为3/10,在最坏情况下,每次选择都选到了7/10的那一部分,则递归的复杂度为T(7/10*n)。
在每5个数求中位数和划分的函数中,进行若干个次线性的扫描,其时间复杂度为c*n,其中c为常数。其总的时间复杂度满足 T(n) <= T(n/5) + T(7/10*n) + c * n。
我们假设T(n)=x*n,其中x不一定是常数(比如x可以为n的倍数,则对应的T(n)=O(n^2))。则有 x*n <= x*n/5 + x*7/10*n + c*n,得到 x<=10*c。于是可以知道x与n无关,T(n)<=10*c*n,为线性时间复杂度算法。而这又是最坏情况下的分析,故BFPRT可以在最坏情况下以线性时间求得n个数中的第k个数。
我参考别人的代码,加了一点自己理解的注释
#include <iostream>
using namespace std;
void swap(int *array, int s,int e)
{
int temp;
temp = array[s];
array[s] = array[e];
array[e] = temp;
}
// BFPRT: 在数组中寻找第 k 小元素(Top k)
/************************************************************************
FindKthSmallest(Array, k)
pivot = some pivot element of the array.
L = Set of all elements smaller than pivot in Array
R = Set of all elements greater than pivot in Array
if |L| > k FindKthSmalles(L, k)
else if(|L|+1 == k) return pivot
else return FindKthSmallest(R, k-|L|+1)
************************************************************************/
void insertsort(int *array_t, int start, int end)
{
for (int i = start; i <= end; i++) {
int inserted_data = array_t[i];
int j = i;
for (; j > start && inserted_data < array_t[j - 1]; j--) {
array_t[j] = array_t[j - 1];
}
if (j != i) {
array_t[j] = inserted_data;
}
}
}
int partition(int *array_t, int low, int high, int pivot_index)
{
int pivot_value = array_t[pivot_index];
swap(array_t, low, pivot_index);
while (low < high) {
while (low < high && array_t[high] >= pivot_value) {
high--;
}
if (low < high) {
array_t[low++] = array_t[high];
}
while (low < high && array_t[low] <= pivot_value) {
low++;
}
if (low < high) {
array_t[high--] = array_t[low];
}
}
array_t[low] = pivot_value;
return low;
}
// 五划分中项:中位数的中位数(the median of medians algorithm)
// Return the kth value
int select(int *array_t, int left, int right, int k)
{
const int k_group_size = 5;
int size = right - left + 1;
if (size <= k_group_size) {
insertsort(array_t, left, right);
return array_t[k + left - 1];
}
// (right - left) / 2 + left
const int num_group = (size % k_group_size) > 0 ? (size / k_group_size) + 1 : (size / k_group_size);
for (int i = 0; i < num_group; i++) {
int sub_left = left + i * k_group_size;
int sub_right = sub_left + k_group_size - 1;
if (sub_right > right) {
sub_right = right;
}
insertsort(array_t, sub_left, sub_right);
// IMPORTANT !!
// Place these median in front of array_t, so as to recurse to find the median of median
int median = sub_left + ((sub_right - sub_left) >> 1);
swap(array_t, left + i, median);
}
// Get the index of median
int pivot_index = left + ((num_group - 1) >> 1);
// Recurse to call and place the median on the pivot_index, without care about the median value
// Because the value of pivot_index must be the median after select function recursive call.
/*
//k = (num_group + 1) >> 1 中位数 防止k=0,所以num_group+1
//int select(int *array_t, int left, int right, int k)
{
const int k_group_size = 5;
int size = right - left + 1;
if (size <= k_group_size) {
insertsort(array_t, left, right);
return array_t[k + left - 1]; //k + left - 1=(num_group+1)/2 + left -1=num_group/2 + left - 1/2 中位数
}*/
select(array_t, left, left + num_group - 1, (num_group + 1) >> 1);
int mid_index = partition(array_t, left, right, pivot_index);
int _ith = mid_index - left + 1;
// _ith_element == array_t[_ith]
if (k == _ith) {
return array_t[mid_index];
} else if (k < _ith) {
return select(array_t, left, mid_index - 1, k);
} else {
return select(array_t, mid_index + 1, right, k - _ith);
}
}
int main()
{
int k = 8; // 1 <= k <= array.size
int array[20] = { 11,9,10,1,13,8,15,0,16,2,17,5,14,3,6,18,12,7,19,4 };
cout << "原数组:";
for (int i = 0; i < 20; i++)
cout << array[i] << " ";
cout << endl;
// 因为是以 k 为划分,所以还可以求出第 k 小值
cout << "第 " << k << " 小值为:" << array[select(array, 0, 19, k)] << endl;
cout << "变换后的数组:";
for (int i = 0; i < 20; i++)
cout << array[i] << " ";
cout << endl;
return 0;
}