Saturday, November 12, 2011

Top-K Selection

Ok, so I need to select the top-K values from a list of size N where K is much smaller than N. Two approaches that immediately come to mind are:

  1. Sort the list and select the first K elements. Running time O(N*lg(N)).
  2. Insert the elements of the list into a heap and pop the top K elements. Running time O(N + K*lg(N)).

As a variant on option 2, a colleague proposed using a small heap that holds at most K elements at any given time. Once the heap is full, an element has to be popped off before a new one can be added. So if I want the K smallest values, I use a max-heap: if the next value from the input is smaller than the largest value on the heap, I pop off the largest value and insert the smaller one. The running time for this approach is O(N*lg(K)).
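
The small-heap idea can be sketched with std::priority_queue, which is a max-heap by default. This is only a minimal illustration, not the benchmark code at the bottom of the post, and the function name smallest_k is made up for the example.

```cpp
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical helper: returns the k smallest values of data in ascending order.
std::vector<int> smallest_k(const std::vector<int> &data, std::size_t k) {
    std::priority_queue<int> heap;  // max-heap: top() is the largest value kept
    if (k == 0) return std::vector<int>();  // nothing to select, avoid top() on an empty heap
    for (std::size_t i = 0; i < data.size(); ++i) {
        if (heap.size() < k) {
            heap.push(data[i]);            // still filling the heap
        } else if (data[i] < heap.top()) {
            heap.pop();                    // evict the current largest
            heap.push(data[i]);
        }
    }
    // Drain the heap from largest to smallest, filling the result back to front.
    std::vector<int> result(heap.size());
    for (std::size_t i = result.size(); i > 0; --i) {
        result[i - 1] = heap.top();
        heap.pop();
    }
    return result;
}
```

For the input 5, 1, 9, 3, 7, 2, 8 and k = 3, the heap ends up holding 1, 2, and 3. If k is larger than the input, the whole input comes back sorted.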

For my use case both N and K are fairly small. N will be approximately 10,000 elements and K will typically be 10, but can be set anywhere from 1 to 100. A C++ test program that implements all three approaches and tests them for various sizes of K is provided at the bottom of this post. The table below shows the results for K equal to 10, 50, and 100; the test program reports times in milliseconds. You can see that the three approaches differ widely in running time. Increasing K has little impact on the sort and full-heap times, while the small-heap time roughly doubles from K=10 to K=100 but stays well below the other two.

method       K=10       K=50       K=100
sort         0.468977   0.466918   0.467177
fullheap     0.099325   0.102839   0.106735
smallheap    0.018063   0.028948   0.040435
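
One likely reason the small heap does so well in the table above, even though O(N*lg(K)) is not a better worst-case bound than O(N + K*lg(N)) when K is tiny: on randomly ordered input, most elements fail the comparison against the top of the heap and never touch the heap at all. The sketch below counts how many elements actually enter the heap. The helper name count_heap_inserts is made up for this illustration. For distinct values in random order the expected count is roughly K + K*ln(N/K), which is about 79 for N = 10,000 and K = 10.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: runs the small-heap selection of the k smallest values
// and returns how many elements were pushed onto the heap.
std::size_t count_heap_inserts(const std::vector<int> &data, std::size_t k) {
    std::vector<int> top;  // max-heap of the k smallest values seen so far
    std::size_t inserts = 0;
    if (k == 0) return 0;  // nothing is ever kept
    for (std::size_t i = 0; i < data.size(); ++i) {
        if (top.size() < k) {
            top.push_back(data[i]);
            std::push_heap(top.begin(), top.end());
            ++inserts;
        } else if (data[i] < top.front()) {
            std::pop_heap(top.begin(), top.end());  // move the largest to the back
            top.back() = data[i];                   // overwrite it with the new value
            std::push_heap(top.begin(), top.end());
            ++inserts;
        }
    }
    return inserts;
}
```

On ascending input only the first K elements enter the heap. On descending input every element does, which is the O(N*lg(K)) worst case.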

Here is the chart for all values of K:


So clearly the third approach with the small heap is the winner. With a small K it is essentially linear time and an order of magnitude faster than the naive sort. The graph shows both the average time and the 99th-percentile so you can see the variation in times is fairly small. This first test covers the sizes for my use case, but out of curiosity, I also tested the more interesting case with a fixed size K and varying the size of N. The graph for N from 100,000 to 1,000,000 in increments of 100,000 tells the whole story:


Source code:

#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

using namespace std;

typedef void (*topk_func)(vector<int> &, vector<int> &, int);

bool greater_than(const int &v1, const int &v2) {
    return v1 > v2;
}

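// Approach 1: sort everything in ascending order and copy the first k values.
// Assumes k <= data.size().
void topk_sort(vector<int> &data, vector<int> &top, int k) {
    sort(data.begin(), data.end());
    vector<int>::iterator i = data.begin();
    int j = 0;
    for (; j < k; ++i, ++j) {
        top.push_back(*i);
    }
}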

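// Approach 2: turn the whole input into a min-heap (greater_than puts the
// smallest value at the front) and pop the k smallest values off it.
// Assumes k <= data.size().
void topk_fullheap(vector<int> &data, vector<int> &top, int k) {
    make_heap(data.begin(), data.end(), greater_than);
    for (int j = 0; j < k; ++j) {
        top.push_back(*data.begin());
        pop_heap(data.begin(), data.end(), greater_than);
        data.pop_back();
    }
}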

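// Approach 3: keep a max-heap of at most k elements in top. Once it is full,
// a new value replaces the current largest only if it is smaller.
void topk_smallheap(vector<int> &data, vector<int> &top, int k) {
    for (vector<int>::iterator i = data.begin(); i != data.end(); ++i) {
        if (top.size() < static_cast<vector<int>::size_type>(k)) {
            top.push_back(*i);
            push_heap(top.begin(), top.end());
        } else if (*i < *top.begin()) {
            pop_heap(top.begin(), top.end());
            top.pop_back();
            top.push_back(*i);
            push_heap(top.begin(), top.end());
        }
    }
    // Leave the k smallest values in ascending order.
    sort_heap(top.begin(), top.end());
}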

void run_test(const string &name, topk_func f, int trials, int n, int k) {
    vector<double> times;

    for (int t = 0; t < trials; ++t) {
        vector<int> data;
        for (int i = 0; i < n; ++i) {
            data.push_back(rand());
        }

        clock_t start = clock();
        vector<int> top;
        f(data, top, k);
        clock_t end = clock();
        times.push_back(1000.0 * (end - start) / CLOCKS_PER_SEC);
    }

    sort(times.begin(), times.end());
    double sum = accumulate(times.begin(), times.end(), 0.0);
    double avg = sum / times.size();
    double pct90 = times[static_cast<int>(trials * 0.90)];
    double pct99 = times[static_cast<int>(trials * 0.99)];
    cout << name << " "
         << k << " "
         << avg << " "
         << pct90 << " "
         << pct99 << endl;
}

int main(int argc, char **argv) {

    cout << "method k avg 90-%tile 99-%tile" << endl;

    int trials = 1000;
    int n = 10000;
    for (int k = 1; k <= 100; ++k) {
        run_test("sort", topk_sort, trials, n, k);
        run_test("fullheap", topk_fullheap, trials, n, k);
        run_test("smallheap", topk_smallheap, trials, n, k);
    }

    return 0;
}
