最坏情况为线性时间的选择算法

《算法导论》第3版9.3讲解了最坏情况为线性时间的选择算法

步骤如下

1: 将输入数组的n个元素划分为 n/5 组,每组5个元素,且至多只有一组由剩下的 n%5 个元素组成。

2: 寻找 n/5 组中每一组的中位数:首先对每组元素(至多为5个)进行插入排序,然后确定每组有序元素的中位数。

3: 对第2步中找出的 n/5 个中位数,递归调用 select 以找出其中位数 num(如果有偶数个中位数,为了方便,约定 num 是较小的中位数)

4: 利用修改过的partition版本,按中位数的中位数 num 对输入数组进行划分。让 count 比划分的低区中的元素数目多1,因此 num 是第 count 小的元素,并且有 n - count 个元素在划分的高区。

5: 如果 k == count,则返回 num。如果 k < count,则在低区递归调用 select 以找出第 k 小的元素。如果 k > count,则在高区递归查找第 k - count 小的元素。

代码如下

int select(int A[], int p, int r, int k);

int insert_sort(int A[], int len)
{
    for (int j = 1; j < len; ++j)
    {
        int key = A[j];
        // insert A[j] into the sorted sequence A[0..j-1]
        int i = j - 1;
        // 注意是 i >= 0 而不是 i > 0
        while (i >= 0 && A[i] > key)
        {
            A[i + 1] = A[i];
            --i;
        }

        A[i + 1] = key;
    }

    return A[len / 2];
}

int find_median(int A[], int p, int r)
{
    int len = r - p + 1;

    int *temp = new int[len / 5 + 1];

    int start = p;
    int end = p;
    int j = 0;

    for (int i = 0; i < len; ++i)
    {
        if (i % 5 == 0)
            start = start + i;

        if ((i + 1) % 5 == 0 || i == len - 1)
        {
            end = end + i;
            int small_median = insert_sort(A, end - start + 1);

            temp[j] = small_median;
            ++j;
        }
    }

    int total_median = select(temp, 0, j - 1, (j - 1) / 2);

    delete[] temp;

    return total_median;
}

int partition(int A[], int p, int r, int num)
{
    for (int i = p; i <= r; i++)
    {
        if (A[i] == num)
        {
            swap(A[i], A[r]);
            break;
        }
    }

    int x = A[r];
    int i = p - 1;

    for (int j = p; j < r; ++j)
    {
        if (A[j] <= x)
        {
            ++i;
            swap(A[i], A[j]);
        }
    }

    swap(A[i + 1], A[r]);

    return i + 1;
}

int select(int A[], int p, int r, int k)
{
    assert(p <= r);
    assert(k <= r - p + 1);

    if (p == r)
        return A[p];

    int num = find_median(A, p, r);
    int mid = partition(A, p, r, num);

    int count = mid - p + 1;
    if (k == count)
        return A[mid];
    else if (k < count)
        return select(A, p, mid - 1, k);
    else
        return select(A, mid + 1, r, k - count);
}

猜你喜欢

转载自blog.csdn.net/tao_ba/article/details/81173949
今日推荐