]> git.kaiwu.me - klib.git/commitdiff
wait-free heap-free implementation of kt_for()
authorHeng Li <lh3@me.com>
Thu, 10 Oct 2013 16:39:12 +0000 (12:39 -0400)
committerHeng Li <lh3@me.com>
Thu, 10 Oct 2013 16:39:12 +0000 (12:39 -0400)
and also much simpler

kthread.c

index bc94a7044c63c6808b9163b8af7842136c252996..53be2c4b57062759382d035ce5fe447fd06f08da 100644 (file)
--- a/kthread.c
+++ b/kthread.c
 #include <pthread.h>
 #include <stdlib.h>
-#include <stdint.h>
-#include <stdio.h>
 
-#define HT_DQ_BITS 5 // 1<<HT_DQ_BITS is size of deque associated with each worker
-
-/*************************
- *** Fixed-sized deque ***
- *************************/
-
-typedef int dqval_t;
-
-typedef struct { // a ring buffer
-       int lock;
-       int n_bits;
-       int first, count;
-       unsigned mask;
-       dqval_t *a;
-} deque_t;
-
-#define dq_is_full(q) ((uint32_t)(q)->count == 1U<<(q)->n_bits)
-#define dq_size(q) ((q)->count)
-
-deque_t *dq_init(int n_bits)
-{
-       deque_t *d;
-       d = (deque_t*)calloc(1, sizeof(deque_t));
-       d->n_bits = n_bits;
-       d->mask = (1U<<n_bits) - 1;
-       d->a = (dqval_t*)calloc(1<<n_bits, sizeof(dqval_t));
-       return d;
-}
-
-void dq_destroy(deque_t *d) { free(d->a); free(d); }
-
-int dq_enq(deque_t *q, int is_back, const dqval_t *v) // put to the deque
-{
-       int ret = 0;
-       while (__sync_lock_test_and_set(&q->lock, 1)); // this mimics a spin lock
-       if (!dq_is_full(q)) {
-               q->a[(is_back? q->first + q->count : q->first + q->mask) & q->mask] = *v;
-               q->first = is_back? q->first : q->first? q->first - 1 : q->mask;
-               ++q->count;
-       } else ret = -1; // the queue is full
-       __sync_lock_release(&q->lock);
-       return ret;
-}
-
-int dq_deq(deque_t *q, int is_back, dqval_t *v) // get from the queue
-{
-       int ret = 0;
-       while (__sync_lock_test_and_set(&q->lock, 1));
-       if (dq_size(q)) {
-               *v = q->a[is_back? (q->first + q->count + q->mask) & q->mask : q->first];
-               q->first = is_back? q->first : q->first == q->mask? 0 : q->first + 1;
-               --q->count;
-       } else ret = -1; // the queue is empty
-       __sync_lock_release(&q->lock);
-       return ret;
-}
-
-/**********************************
- *** Paralelize simple for loop ***
- **********************************/
-
-struct ktf_worker_t;
+typedef long long ktint64_t;
+struct kt_for_t;
 
 typedef struct {
-       int n; // n: number of workers
-       void *data;
+       struct kt_for_t *t;
+       int tid, i;
+} ktf_worker_t;
+
+typedef struct kt_for_t {
+       int n_threads, n;
+       ktf_worker_t *w;
        void (*func)(void*,int,int);
-       struct ktf_worker_t *w;
-       int finished;
+       void *data;
 } kt_for_t;
 
-typedef struct ktf_worker_t {
-       kt_for_t *f;
-       deque_t *q;
-       int i;
-} ktf_worker_t;
-
-static inline int steal_work(kt_for_t *f) // steal work from the worker with the highest load
+static inline int64_t steal_work(kt_for_t *t)
 {
-       int i, max = -1, max_i = -1, k = -1;
-       for (i = 0; i < f->n; ++i)
-               if (max < dq_size(f->w[i].q)) // max is not accurate as other workers may steal from the same queue, but it does not matter.
-                       max = dq_size(f->w[i].q), max_i = i;
-       if (max_i < 0 || dq_deq(f->w[max_i].q, 0, &k) < 0) k = -1;
-       return k;
+       int i, k, min = t->n, min_i = -1;
+       for (i = 0; i < t->n_threads; ++i)
+               if (min > t->w[i].i) min = t->w[i].i, min_i = i;
+       if (min_i < 0) return -1;
+       k = __sync_fetch_and_add(&t->w[min_i].i, t->n_threads);
+       return k >= t->n? -1 : (ktint64_t)min_i<<32 | k;
 }
 
 static void *ktf_worker(void *data)
 {
        ktf_worker_t *w = (ktf_worker_t*)data;
+       int64_t x;
        for (;;) {
-               int k = -1;
-               if (dq_deq(w->q, 1, &k) < 0) k = steal_work(w->f);
-               if (k >= 0) w->f->func(w->f->data, k, w->i + 1);
-               else if (w->f->finished) break;
+               int i = __sync_fetch_and_add(&w->i, w->t->n_threads);
+               if (i >= w->t->n) break;
+               w->t->func(w->t->data, i, w->tid);
        }
-       return 0;
+       while ((x = steal_work(w->t)) >= 0)
+               w->t->func(w->t->data, (unsigned)x, x>>32);
+       pthread_exit(0);
 }
 
-/**
- * Parallelize a simple "for" loop
- *
- * @param n_threads    total number of threads
- * @param func         function in the form of func(void *data, int item_id, void *item);
- * @param data         data used by $func
- * @param n_items      number of items to process
- *
- * This function parallelizes such a "for" loop:
- *
- *   data_type *data;
- *   for (int i = 0; i < n_items; ++i)
- *     func(data, &items[i], 0);
- *
- * with:
- *
- *   ht_for(n_threads, func, data, n_items);
- */
-void kt_for(int n_threads, void (*func)(void*,int,int), void *data, int n_items)
+void kt_for(int n_threads, void (*func)(void*,int,int), void *data, int n)
 {
-       kt_for_t *f;
+       int i;
+       kt_for_t t;
        pthread_t *tid;
-       int i, k, dq_bits = HT_DQ_BITS;
-
-       f = (kt_for_t*)calloc(1, sizeof(kt_for_t));
-       f->n = n_threads - 1;
-       f->data = data;
-       f->func = func;
-
-       f->w = (ktf_worker_t*)calloc(f->n, sizeof(ktf_worker_t));
-       for (i = 0; i < f->n; ++i)
-               f->w[i].f = f, f->w[i].i = i, f->w[i].q = dq_init(dq_bits);
-
-       tid = (pthread_t*)calloc(f->n, sizeof(pthread_t));
-       for (i = 0; i < f->n; ++i) pthread_create(&tid[i], 0, ktf_worker, &f->w[i]);
-
-       for (k = 0; k < n_items; ++k) {
-               int min, min_i;
-               for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i) // find the worker with the lowest load
-                       if (min > dq_size(f->w[i].q)) min = dq_size(f->w[i].q), min_i = i;
-               if (min < 1<<dq_bits) dq_enq(f->w[min_i].q, 0, &k);
-               else f->func(data, k, 0);
-       }
-       f->finished = 1;
-       while ((k = steal_work(f)) >= 0) func(data, k, 0); // help the unfinished workers
-
-       for (i = 0; i < f->n; ++i) pthread_join(tid[i], 0);
-       for (i = 0; i < f->n; ++i) dq_destroy(f->w[i].q);
-       free(tid); free(f->w); free(f);
+       t.func = func, t.data = data, t.n_threads = n_threads, t.n = n;
+       t.w = (ktf_worker_t*)alloca(n_threads * sizeof(ktf_worker_t));
+       tid = (pthread_t*)alloca(n_threads * sizeof(pthread_t));
+       for (i = 0; i < n_threads; ++i)
+               t.w[i].t = &t, t.w[i].tid = t.w[i].i = i;
+       for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktf_worker, &t.w[i]);
+       for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0);
 }