Top-K 问题
Table of Contents
下面分别介绍三种方法,时间复杂度分别为 O(nlogn), O(nlogk), O(n)。
快速排序 O(nlogn)
这种方法,没什么好说的,就是简单的排序,然后取 k 个最大的。
#include <iostream>
#include <algorithm>
using namespace std;
int partition(int a[], int l, int r) {
int i = l, j = r;
int pivot = a[l];
while (i < j) {
while (a[i] < pivot && i < j)
i ++;
while (a[j] > pivot && i < j)
j --;
swap(a[i], a[j]);
}
swap(a[j], pivot);
return j;
}
void qs(int a[], int l, int r) {
if (l < r) {
int p = partition2(a, l, r);
qs2(a, l, p-1);
qs2(a, p+1, r);
}
}
使用堆排 O(nlogk)
使用堆的思想来解决这个问题。如果是 topk 最小,我们只是需要维护一个大小为 k 的大顶堆就可以了。如果是解决 topk 最大,我们需要维护一个大小为 k 的小顶堆。
下面是 topk 最小的方法,topk 最大类似,只是在调整堆方法需要改一下。
class Solution {
public:
vector<int> GetLeastNumbers_Solution(vector<int> input, int k) {
vector<int> heap_k;
if (input.size()==0 || input.size() < k || k <= 0) {
return heap_k;
}
for (int i = 0; i < input.size(); i ++) {
if (heap_k.size() < k) {
heap_k.push_back(input[i]);
max_heapify(heap_k, 0, heap_k.size()-1);
continue;
}
max_heapify(heap_k, 0, heap_k.size()-1);
if (input[i] >= heap_k[0]){
continue;
} else {
heap_k[0] = input[i];
}
}
// heap_sort(heap_k);
return heap_k;
}
void max_heapify(vector<int> &input, int start, int end) {
int root = start;
while (1) {
int child = root * 2 + 1;
if (child > end) {
break;
}
if (child + 1 <= end && input[child+1] > input[child]) {
child += 1;
}
if (input[child] > input[root]) {
swap(input[child], input[root]);
root = child;
} else {
break;
}
}
}
void heap_sort(vector<int> &input) {
int l = input.size();
for (int i = (l - 1)/2; i > 0; i --) {
max_heapify(input, i, l-1);
}
for (int i = l - 1; i > 0; i --) {
swap(input[i], input[0]);
max_heapify(input, 0, i-1);
}
}
};
BFPRT 算法 O(n)
基于快排的思想,算法的步骤如下:
- 将 n 个元素划为 ⌊ n/5⌋ 组,每组 5 个,至多只有一组由 n\bmod5 个元素组成。
- 寻找这 ⌈ n/5⌉ 个组中每一个组的中位数,这个过程可以用插入排序。
- 对步骤 2 中的 ⌈ n/5⌉ 个中位数,重复步骤 1 和步骤 2,递归下去,直到剩下一个数字。
- 最终剩下的数字即为 pivot,把大于它的数全放左边,小于等于它的数全放右边。
判断 pivot 的位置与 k 的大小,有选择的对左边或右边递归。
int BFPRT(int a[], int l, int r, int k) { int p = FindMid(a, l, r); //寻找中位数的中位数 int i = Partion(a, l, r, p); int m = i - l + 1; if(m == k) return a[i]; if(m > k) return BFPRT(a, l, i - 1, k); return BFPRT(a, i + 1, r, k - m); }
这个算法的最坏时间复杂度为 O(n)。