]> git.kaiwu.me - klib.git/commitdiff
let the main thread help the unfinished workers
authorHeng Li <lh3@me.com>
Wed, 9 Oct 2013 15:26:20 +0000 (11:26 -0400)
committerHeng Li <lh3@me.com>
Wed, 9 Oct 2013 15:26:20 +0000 (11:26 -0400)
kthread.c

index 37e515a3f1253761386a4385874300837fdf451a..d2c8cad738cf9ef1bd4756f5909158ca9191bea2 100644 (file)
--- a/kthread.c
+++ b/kthread.c
@@ -81,32 +81,36 @@ typedef struct ktf_worker_t {
        int i;
 } ktf_worker_t;
 
+static inline uint64_t steal_work(kt_for_t *f) // steal work from the worker with the highest load
+{
+       int i, max = -1, max_i = -1, k = -1;
+       for (i = 0; i < f->n; ++i)
+               if (max < dq_size(f->w[i].q)) // max is not accurate as other workers may steal from the same queue, but it does not matter.
+                       max = dq_size(f->w[i].q), max_i = i;
+       if (dq_deq(f->w[max_i].q, 0, &k) < 0) k = -1;
+       return k;
+}
+
 static void *ktf_worker(void *data)
 {
        ktf_worker_t *w = (ktf_worker_t*)data;
        for (;;) {
                int k = -1;
-               if (dq_deq(w->q, 1, &k) < 0) { // if the queue associated with the worker is full, steal
-                       int i, max, max_i;
-                       for (i = 0, max = -1, max_i = -1; i < w->f->n; ++i) // find the worker with most pending jobs
-                               if (max < dq_size(w->f->w[i].q))
-                                       max = dq_size(w->f->w[i].q), max_i = i;
-                       if (dq_deq(w->f->w[max_i].q, 0, &k) < 0) k = -1; // steal a job
-               }
+               if (dq_deq(w->q, 1, &k) < 0) k = steal_work(w->f);
                if (k >= 0) w->f->func(w->f->global, (uint8_t*)w->f->local + w->f->size * k);
                else if (w->f->finished) break;
        }
        return 0;
 }
 
-void kt_for(int n, int (*func)(void*,void*), void *global, int m, int size, void *local)
+void kt_for(int n_threads, int (*func)(void*,void*), void *global, int m, int size, void *local)
 {
        kt_for_t *f;
        pthread_t *tid;
        int i, k, dq_bits = HT_DQ_BITS;
 
        f = (kt_for_t*)calloc(1, sizeof(kt_for_t));
-       f->n = n - 1, f->size = size;
+       f->n = n_threads - 1, f->size = size;
        f->global = global, f->local = local;
        f->func = func;
 
@@ -122,12 +126,13 @@ void kt_for(int n, int (*func)(void*,void*), void *global, int m, int size, void
 
        for (k = 0; k < m; ++k) {
                int min, min_i;
-               for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i)
+               for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i) // find the worker with the lowest load
                        if (min > dq_size(f->w[i].q)) min = dq_size(f->w[i].q), min_i = i;
                if (min < 1<<dq_bits) dq_enq(f->w[min_i].q, 0, &k);
                else func(global, (uint8_t*)f->local + f->size * k);
        }
        f->finished = 1;
+       while ((k = steal_work(f)) >= 0) func(global, (uint8_t*)f->local + f->size * k); // help the unfinished workers
 
        for (i = 0; i < f->n; ++i) pthread_join(tid[i], 0);
        for (i = 0; i < f->n; ++i) dq_destroy(f->w[i].q);