求海量数据中第 K 位元素 & 求 top K 的数据,各解决方法示例和总结

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/afei__/article/details/83117152

一、题目

在一个由 n 个元素组成的集合中,按从小到大顺序排序的话,第 K 个顺序统计位即指第 K 个数,当 K = n 时即最大值,当 K = 1 时即最小值。先给定一个无序的元素集合,求集合中第 K 统计位的值是多少?

同理,若求 top K 的数据的话,即求集合中最大的前 K 个数分别是多少?

例如,给定数组 [ 0, 9, 3, 6, 8, 2, 1, 5, 7, 4 ] ,则第 4 统计位的数字是 3,top 3 大的数是 [ 9, 8, 7 ]。

 

二、最大值和最小值

1. 简介

先说一个很特殊的场景,碰巧 K = 1 或者 K = n 的时候,即我们常说的最小值和最大值。解法很简单,即遍历一遍集合就可以找出最大值和最小值了,遍历一次集合找出最大值或最小值需要比较 n - 1 次,时间复杂度为 O(n)。

假如有种场景我们要同时得到最大值和最小值,最直观的解法就是遍历一次,通过比较 2 * (n - 1) 次就可以得到最大值和最小值了。但是其实我们只需要比较 n * 3 / 2 次就可以了。

2. 代码

#include <iostream>
#include <algorithm>
 
using namespace std;
 
int main() {
    int arr[] = { 12, 2, 32, 64, -10, 33, -6, 0, 2, 10 };
    int length = sizeof(arr) / sizeof(arr[0]);
    int max;
    int min;
    int *p = arr;
    if (length & 1) { // 数组长度为奇数
        max = min = arr[0];
        p++;
    }
    else { // 数组长度为偶数
        if (arr[0] > arr[1]) {
            max = arr[0];
            min = arr[1];
        }
        else {
            max = arr[1];
            min = arr[0];
        }
        p += 2;
    }
    while (p < &arr[length]) {
        int first = *(p++);
        int second = *(p++);
        if (first > second) {
            max = std::max(max, first);
            min = std::min(min, second);
        }
        else {
            max = std::max(max, second);
            min = std::min(min, first);
        }
    }
    cout << "max: " << max << ", min: " << min << endl;
    return 0;
}

3. 评价

该方法只适用于求最大值和最小值的情况。我们每次取两个数,先比较一次得到他们的大小关系,然后我们用较大者和当前最大值比较,用较小者和当前最小值比较,这样我们就只需要比较 3 次就可以比较完 2 个新的元素了。

 

三、排序法

1. 简介

回到主题,K 往往并不是 0 或者 n,那么怎么求解呢?很多人最直接的想法就是,我把集合排个序,然后再取第 K 位的数或者 top K 那还不是轻而易举了。

当然,排序也是可以偷懒的,例如我们求第 10 大的元素的话,我们只需要进行 10 次冒泡排序或者选择排序即可,即遍历 10 次集合就可以达到我们的目的。

2. 代码

冒泡排序或选择排序的代码都很简单,这里不做示例了。

扫描二维码关注公众号,回复: 3744737 查看本文章

3. 评价

值得说明的是,该方法只适合数据集不大或者 K 值很小的情况下使用,假如我们要求海量数据的中位数时,排序法的效率都远不如其它方法了。

 

四、最小堆和最大堆法

1. 简介

最小堆和最大堆的性质还不知道的赶紧去百度了。

假如我们要求集合中最大的前 K 个数的话,我们可以创建一个大小为 K 的最小堆,它的根结点一定是这 K 个元素中的最小值,然后我们只需要遍历 1 遍集合即可。分别和堆中的最小值(即根结点)比较,如果大于它,则我们使用这个新的值替代根结点并刷新堆,使得堆的根结点依旧是这 K 个元素中的最小值。最后遍历完集合后,堆里的这 K 个值就是整个集合中的 top K 了,堆的根结点就是第 K 大的值了。

同理,求最小的前 K 个数的话,就使用一个大小为 K 的最大堆。

2. 代码

public class Main {
 
    public static void main(String[] args) {
        int[] arr = new int[] { 12, 0, 88, -36, 24, 256, 4, -2, 64, 56, 88,
                72, 100, 6, 12, 32, 96, 54, 48, 36 };
        int[] heap = getTopK(arr, 5);
        printArray(heap);
    }
 
    public static int[] getTopK(int[] arr, int k) {
        if (k >= arr.length) {
            return arr;
        }
        int[] heap = new int[k];
        System.arraycopy(arr, 0, heap, 0, heap.length);
        buildMinHeap(heap);
        for (int i = k; i < arr.length; i++) {
            if (arr[i] > heap[0]) {
                heap[0] = arr[i];
                minHeapify(heap, 0, heap.length);
            }
        }
        // 如果需要 Top K 按照从大到小的顺序排序的话
        int heapSize = heap.length;
        for (int i = heap.length - 1; i > 0; i--) {
            int min = heap[0];
            heap[0] = heap[i];
            heap[i] = min;
            minHeapify(heap, 0, --heapSize);
        }
        return heap;
    }
 
    public static void buildMinHeap(int[] heap) {
        // 堆的最后一个分支结点索引为 arr.length / 2 - 1
        for (int i = heap.length / 2 - 1; i >= 0; i--) {
            minHeapify(heap, i, heap.length);
        }
    }
 
    /**
     * 调整堆,使其满足最小堆的性质
     */
    public static void minHeapify(int[] heap, int index, int heapSize) {
        int leftIndex = index * 2 + 1; // 左子节点对应数组中的索引
        int rightIndex = index * 2 + 2; // 右子节点对应数组中的索引
        int minIndex = index;
        // 如果左子结点较小,则将最小值索引设为左子节点
        if (leftIndex < heapSize && heap[leftIndex] < heap[index]) {
            minIndex = leftIndex;
        }
        // 如果右子结点比 min(this, left)还小,则将最小值索引设为右子节点
        if (rightIndex < heapSize && heap[rightIndex] < heap[minIndex]) {
            minIndex = rightIndex;
        }
        // 如果当前结点的值不是最小的,则需要交换最小值,并继续遍历交换后的子结点
        if (minIndex != index) {
            int temp = heap[minIndex];
            heap[minIndex] = heap[index];
            heap[index] = temp;
            minHeapify(heap, minIndex, heapSize);
        }
    }
 
    public static void printArray(int[] arr) {
        for (int i = 0; i < arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }
 
}

执行结果:

256 100 96 88 88 

3. 评价

由于最小堆或者最大堆的操作时间复杂度均为 O(lg n),n 是堆的大小,大小为 K 的堆即 O(lg K)。且只需要遍历一遍集合,则时间复杂度基本为 O(n * lg K)。可以说该方法的执行效率总是会高于排序法了。

且当 K 值较小时,该算法拥有一个较小的常数系数 lg K ,效率还是很高的。

 

五、快速选择法

1. 简介

快速选择法是一个期望为线性时间的选择算法。它是以快速排序算法为模型修改的。

简单介绍一下原理:首先我们知道快速排序的原理是根据一个 pivot 值,将集合中小于 pivot 的值放置在其左侧,将大于 pivot 的值放置在其右侧,然后再递归地处理左右两侧,最终完成整个集合的排序。快速选择则只处理其中一边,那么根据 pivot 的坐标判断, K 小于 pivot 的话那么我们的数据肯定在左侧,相反则在右侧,等于则直接返回,因为它就是我们要找的数。

如果是求 top K,那我们再以这个第 K 统计位的数为 pivot,划分一次集合即可。

2. 代码

#include <stdio.h>
 
int partition(int *arr, int start, int end) {
    if (start >= end) return arr[start];
    int pivot = arr[start];
    while (end > start) {
        while (end > start && arr[end] >= pivot) {
            end--;
        }
        arr[start] = arr[end]; // 将小于 pivot 的数放在低位
        while (end > start && arr[start] <= pivot) {
            start++;
        }
        arr[end] = arr[start]; // 将大于 pivot 的数放在高位
    }
    arr[start] = pivot;
    return start; // 返回当前轴点位置
}
 
int quickSelect(int *arr, int length, int k) {
    int start = 0;
    int end = length - 1;
    while (end >= start) {
        int p = partition(arr, start, end);
        if (p == k - 1) { // 数组的索引是0开始的,第k大的索引是1开始的
            return arr[p];
        } else if (p < k - 1) {
            start = p + 1;
        } else {
            end = p - 1;
        }
    }
    return 0;
}
 
int main() {
    int arr[] = { 5, 4, 8, 6, 3, 9, 10, 1, 7, 2 };
    printf("第5位的元素为:%d\n", quickSelect(arr, sizeof(arr) / sizeof(arr[0]), 5));
    return 0;
}

3. 评价

快速选择的效率比快速排序已经高了很多,平均时间复杂度通常为 O(n * lg n) 到 O(n)。然后,最坏情况下,它的时间复杂度仍然为 O(n ^ 2)。

即便如此,快速选择及其变种是实际应用中最常使用的高效选择算法,适用于求海量数据的中位数这种 K 也很大的情况。

 

六、BFPRT 算法

1. 简介

BFPRT 算法是一个最坏情况下仍为线性时间的选择算法,它是上述快速选择算法的变种,避免了最坏情况的产生。

BFPRT 算法又称为中位数的中位数算法,是由 5 位大牛 (Blum, Floyd, Pratt, Rivest, Tarjan) 提出,并以他们的名字命名的算法。和选择算法不同的是,它首先将集合以 5 个 5 个元素的划开,先求出每 5 个元素的中位数,再求出这些中位数的中位数,最后以这个中位数的中位数为 pivot 划分集合,以此避免了最坏情况的产生。

2. 代码

#include <stdio.h>
 
int BFPRT(int *arr, int start, int end, int k);
 
void swap(int *a, int *b) {
    if (a != b) {
        int temp = *a;
        *a = *b;
        *b = temp;
    }
}
 
int insertSort(int *arr, int start, int end) {
    for (int i = start + 1; i <= end; i++) {
        for (int j = i; j > start; j--) {
            if (arr[j] < arr[j - 1]) {
                swap(&arr[j], &arr[j - 1]);
            } else {
                break;
            }
        }
    }
    return ((end - start) >> 1) + start; // 返回中位数的下标
}
 
int getMedianIndex(int *arr, int start, int end) {
    int length = end - start + 1;
    if (length <= 5) {
        return insertSort(arr, start, end);
    }
    int subEnd = start; // 中位数的结束位置
    for (int i = start; i + 4 <= end; i += 5) {
        int index = insertSort(arr, i, i + 4);
        swap(&arr[subEnd++], &arr[index]);
    }
    int module = length % 5; // 不能被 5 整除的余数部分
    if (module != 0) {
        int index = insertSort(arr, end - module + 1, end);
        swap(&arr[subEnd++], &arr[index]);
    }
    return getMedianIndex(arr, start, subEnd - 1);
}
 
int partition(int *arr, int start, int end, int pivotIndex) {
    if (start >= end) return arr[start];
    swap(&arr[start], &arr[pivotIndex]);
    int pivot = arr[start];
    while (end > start) {
        while (end > start && arr[end] >= pivot) {
            end--;
        }
        arr[start] = arr[end]; // 将小于 pivot 的数放在低位
        while (end > start && arr[start] <= pivot) {
            start++;
        }
        arr[end] = arr[start]; // 将大于 pivot 的数放在高位
    }
    arr[start] = pivot;
    return start; // 返回当前轴点位置
}
 
int BFPRT(int *arr, int start, int end, int k) {
    if (end - start < 5) {
        insertSort(arr, start, end);
        return arr[k - 1]; // 数组长度太短的话直接处理
    }
    int medianIndex = getMedianIndex(arr, start, end); // 中位数的中位数下标
    int pivotIndex = partition(arr, start, end, medianIndex); // 划分后当前轴点的下标
    if (pivotIndex == k - 1) {
        return arr[pivotIndex];
    } else if (pivotIndex < k - 1) {
        return BFPRT(arr, pivotIndex + 1, end, k);
    } else {
        return BFPRT(arr, start, pivotIndex - 1, k);
    }
    return 0;
}
 
int main() {
    int arr[] = { 5, 4, 8, 6, 3, 9, 0, 1, 7, 2 };
    for (int i = 0; i < 10; i++) {
        printf("第%d位的元素是:%d\n", i + 1, arr[BFPRT(arr, 0, 9, i + 1)]);
    }
    return 0;
}

3. 评价

BFPRT 算法代码比较复杂,但它基本保证了 O(n) 时间下求出结果,只是它的常数系数并不小,所以适合海量数据下,同时 K 也很大的情况,典型的就是求海量数据的中位数了。

猜你喜欢

转载自blog.csdn.net/afei__/article/details/83117152