Top-K 问题

Table of Contents


下面分别介绍三种方法,时间复杂度分别为 O(nlogn), O(nlogk), O(n)。

快速排序 O(nlogn)

这种方法,没什么好说的,就是简单的排序,然后取 k 个最大的。

    #include <iostream>
    #include <algorithm>
    using namespace std;

    int partition(int a[], int l, int r) {
    int i = l, j = r;
    int pivot = a[l];
    while (i < j) {
        while (a[i] < pivot && i < j)
            i ++;
        while (a[j] > pivot && i < j)
            j --;
        swap(a[i], a[j]);
    }
    swap(a[j], pivot);
    return j;
}
void qs(int a[], int l, int r) {
    if (l < r) {
        int p = partition2(a, l, r);
        qs2(a, l, p-1);
        qs2(a, p+1, r);
    }
}

使用堆排 O(nlogk)

使用堆的思想来解决这个问题。如果是 topk 最小,我们只是需要维护一个大小为 k 的大顶堆就可以了。如果是解决 topk 最大,我们需要维护一个大小为 k 的小顶堆。

下面是 topk 最小的方法,topk 最大类似,只是在调整堆方法需要改一下。

class Solution {
public:
    vector<int> GetLeastNumbers_Solution(vector<int> input, int k) {
        vector<int> heap_k;
        if (input.size()==0 || input.size() < k || k <= 0) {
            return heap_k;
        }
        for (int i = 0; i < input.size(); i ++) {
            if (heap_k.size() < k) {
                heap_k.push_back(input[i]);
                    max_heapify(heap_k, 0, heap_k.size()-1);
                continue;
            }
            max_heapify(heap_k, 0, heap_k.size()-1);
            if (input[i] >= heap_k[0]){
                continue;
            } else {
                heap_k[0] = input[i];
            }
        }
        // heap_sort(heap_k);
        return heap_k;
    }

    void max_heapify(vector<int> &input, int start, int end) {
        int root = start;

        while (1) {
            int child = root * 2 + 1;
            if (child > end) {
                break;
            }
            if (child + 1 <= end && input[child+1] > input[child]) {
                child += 1;
            }
            if (input[child] > input[root]) {
                swap(input[child], input[root]);
                root = child;
            } else {
                break;
            }
        }
    }

    void heap_sort(vector<int> &input) {
        int l = input.size();
        for (int i = (l - 1)/2; i > 0; i --) {
            max_heapify(input, i, l-1);
        }
        for (int i = l - 1; i > 0; i --) {
            swap(input[i], input[0]);
            max_heapify(input, 0, i-1);
        }
    }
};

BFPRT 算法 O(n)

基于快排的思想,算法的步骤如下:

  1. 将 n 个元素划为 ⌊ n/5⌋ 组,每组 5 个,至多只有一组由 n\bmod5 个元素组成。
  1. 寻找这 ⌈ n/5⌉ 个组中每一个组的中位数,这个过程可以用插入排序。
  2. 对步骤 2 中的 ⌈ n/5⌉ 个中位数,重复步骤 1 和步骤 2,递归下去,直到剩下一个数字。
  3. 最终剩下的数字即为 pivot,把大于它的数全放左边,小于等于它的数全放右边。
  4. 判断 pivot 的位置与 k 的大小,有选择的对左边或右边递归。

        int BFPRT(int a[], int l, int r, int k)
    {
        int p = FindMid(a, l, r);    //寻找中位数的中位数
        int i = Partion(a, l, r, p);
    
        int m = i - l + 1;
        if(m == k) return a[i];
        if(m > k)  return BFPRT(a, l, i - 1, k);
        return BFPRT(a, i + 1, r, k - m);
    }
    

    这个算法的最坏时间复杂度为 O(n)。