]> git.kaiwu.me - klib.git/commitdiff
more versatile APIs
authorHeng Li <lh3@me.com>
Thu, 10 Oct 2013 03:53:56 +0000 (23:53 -0400)
committerHeng Li <lh3@me.com>
Thu, 10 Oct 2013 03:53:56 +0000 (23:53 -0400)
kthread.c
kthread.h [new file with mode: 0644]

index 9ebdc18ba40014828eede5ddc800b69120bbbf30..40fb353b0d72106dfc7b6c04813485c87c2609e9 100644 (file)
--- a/kthread.c
+++ b/kthread.c
@@ -3,13 +3,13 @@
 #include <stdint.h>
 #include <stdio.h>
 
-#define HT_DQ_BITS 5 // 1<<HT_DQ_BITS is size of deque associated with each worker
+#define KT_DQ_BITS 5 // 1<<HT_DQ_BITS is size of deque associated with each worker
 
 /*************************
  *** Fixed-sized deque ***
  *************************/
 
-typedef int dqval_t;
+typedef uint64_t dqval_t;
 
 typedef struct { // a ring buffer
        int lock;
@@ -60,99 +60,169 @@ int dq_deq(deque_t *q, int is_back, dqval_t *v) // get from the queue
        return ret;
 }
 
-/**********************************
- *** Paralelize simple for loop ***
- **********************************/
+/****************************
+ *** Spawn/sync interface ***
+ ****************************/
 
-struct ktf_worker_t;
+#include "kthread.h"
 
 typedef struct {
-       int n, size; // n: number of workers; size: size of each items element
-       void *shared;
-       void *items;
+       int n_items, item_size, n_finished;
        int (*func)(void*,int,void*);
-       struct ktf_worker_t *w;
-       int finished;
-} kt_for_t;
+       void *shared, *items;
+} kt_task_t;
 
-typedef struct ktf_worker_t {
-       kt_for_t *f;
+typedef struct {
+       struct kthread_t *t;
        deque_t *q;
-       int i;
-} ktf_worker_t;
-
-static inline int steal_work(kt_for_t *f) // steal work from the worker with the highest load
+       int type;
+       pthread_t tid;
+       pthread_mutex_t lock;
+       pthread_cond_t cv;
+} kt_worker_t;
+
+struct kthread_t {
+       int n_threads;
+       kt_worker_t *w;
+       int n_tasks, max_tasks;
+       kt_task_t *tasks;
+       pthread_t self;
+       pthread_mutex_t lock;
+       pthread_cond_t cv;
+       int to_sync, done;
+};
+
+static inline int steal_task(kthread_t *t)
 {
-       int i, max = -1, max_i = -1, k = -1;
-       for (i = 0; i < f->n; ++i)
-               if (max < dq_size(f->w[i].q)) // max is not accurate as other workers may steal from the same queue, but it does not matter.
-                       max = dq_size(f->w[i].q), max_i = i;
-       if (max_i < 0 || dq_deq(f->w[max_i].q, 0, &k) < 0) k = -1;
+       int i, max = -1, max_i = -1;
+       uint64_t k = (uint64_t)-1;
+       for (i = 0; i < t->n_threads; ++i)
+               if (max < dq_size(t->w[i].q)) // max is not accurate as other workers may steal from the same queue, but it does not matter.
+                       max = dq_size(t->w[i].q), max_i = i;
+       if (max_i < 0 || dq_deq(t->w[max_i].q, 0, &k) < 0) k = (uint64_t)-1;
        return k;
 }
 
-static void *ktf_worker(void *data)
+static inline void do_task(kthread_t *t, uint64_t sid)
+{
+       kt_task_t *s = &t->tasks[sid>>32];
+       s->func(s->shared, (int)sid, (uint8_t*)s->items + s->item_size * (uint32_t)sid);
+}
+
+static void *slave(void *data)
 {
-       ktf_worker_t *w = (ktf_worker_t*)data;
+       kt_worker_t *w = (kt_worker_t*)data;
        for (;;) {
-               int k = -1;
-               if (dq_deq(w->q, 1, &k) < 0) k = steal_work(w->f);
-               if (k >= 0) w->f->func(w->f->shared, k, (uint8_t*)w->f->items + w->f->size * k);
-               else if (w->f->finished) break;
+               uint64_t sid;
+               if (dq_deq(w->q, 1, &sid) < 0)
+                       sid = steal_task(w->t);
+               if (sid == (uint64_t)-1) { // if still fail to find a task, sleep and wait for the signal
+                       if (w->type == 2) break;
+                       pthread_mutex_lock(&w->lock);
+                       w->type = 0; // wait
+                       while (w->type == 0) pthread_cond_wait(&w->cv, &w->lock);
+                       pthread_mutex_unlock(&w->lock);
+                       if (w->type == 2) break;
+               } else do_task(w->t, sid);
        }
        return 0;
 }
 
-/**
- * Parallelize a simple "for" loop
- *
- * @param n_threads    total number of threads
- * @param func         function in the form of func(void *shared, int item_id, void *item);
- * @param shared       shared data used by $func
- * @param n_items      number of items to process
- * @param item_size    size of each item
- * @param items        item
- *
- * This function parallelizes such a "for" loop:
- *
- *   shared_type *shared;
- *   item_type items[n_items];
- *   for (int i = 0; i < n_items; ++i)
- *     func(shared, &items[i]);
- *
- * with:
- *
- *   ht_for(n_threads, func, shared, n_items, sizeof(item_type), items);
- */
-void kt_for(int n_threads, int (*func)(void*,int,void*), void *shared, int n_items, int item_size, void *items)
+static void *master(void *data)
 {
-       kt_for_t *f;
-       pthread_t *tid;
-       int i, k, dq_bits = HT_DQ_BITS;
-
-       f = (kt_for_t*)calloc(1, sizeof(kt_for_t));
-       f->n = n_threads - 1, f->size = item_size;
-       f->shared = shared, f->items = items;
-       f->func = func;
-
-       f->w = (ktf_worker_t*)calloc(f->n, sizeof(ktf_worker_t));
-       for (i = 0; i < f->n; ++i)
-               f->w[i].f = f, f->w[i].i = i, f->w[i].q = dq_init(dq_bits);
-
-       tid = (pthread_t*)calloc(f->n, sizeof(pthread_t));
-       for (i = 0; i < f->n; ++i) pthread_create(&tid[i], 0, ktf_worker, &f->w[i]);
-
-       for (k = 0; k < n_items; ++k) {
-               int min, min_i;
-               for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i) // find the worker with the lowest load
-                       if (min > dq_size(f->w[i].q)) min = dq_size(f->w[i].q), min_i = i;
-               if (min < 1<<dq_bits) dq_enq(f->w[min_i].q, 0, &k);
-               else f->func(shared, k, (uint8_t*)f->items + f->size * k);
+       kthread_t *t = (kthread_t*)data;
+       int i, n_tasks = 0, to_sync = 0;
+       for (i = 0; i < t->n_threads; ++i)
+               pthread_create(&t->w[i].tid, 0, slave, &t->w[i]);
+       while (!to_sync) {
+               int next_tasks, tid, iid;
+               uint64_t sid;
+               pthread_mutex_lock(&t->lock);
+               while (n_tasks == t->n_tasks && !t->to_sync)
+                       pthread_cond_wait(&t->cv, &t->lock);
+               next_tasks = t->n_tasks, to_sync = t->to_sync;
+               pthread_mutex_unlock(&t->lock);
+               for (tid = n_tasks; tid < next_tasks; ++tid) {
+                       kt_task_t *s = &t->tasks[tid];
+                       for (iid = 0; iid < s->n_items; ++iid) {
+                               int min, min_i;
+                               for (i = 0, min = 1<<KT_DQ_BITS, min_i = -1; i < t->n_threads; ++i)
+                                       if (min > dq_size(t->w[i].q)) min = dq_size(t->w[i].q), min_i = i;
+                               sid = (uint64_t)tid<<32 | iid;
+                               if (min < 1<<KT_DQ_BITS) {
+                                       kt_worker_t *w = &t->w[min_i];
+                                       dq_enq(w->q, 0, &sid);
+                                       if (w->type == 0) {
+                                               pthread_mutex_lock(&w->lock);
+                                               w->type = 1;
+                                               pthread_cond_signal(&w->cv);
+                                               pthread_mutex_unlock(&w->lock);
+                                       }
+                               } else do_task(t, sid);
+                       }
+               }
+               while ((sid = steal_task(t)) != (uint64_t)-1) do_task(t, sid);
+               n_tasks = next_tasks;
        }
-       f->finished = 1;
-       while ((k = steal_work(f)) >= 0) func(shared, k, (uint8_t*)f->items + f->size * k); // help the unfinished workers
+       for (i = 0; i < t->n_threads; ++i) {
+               pthread_mutex_lock(&t->w[i].lock);
+               t->w[i].type = 2;
+               pthread_cond_signal(&t->w[i].cv);
+               pthread_mutex_unlock(&t->w[i].lock);
+       }
+       for (i = 0; i < t->n_threads; ++i) pthread_join(t->w[i].tid, 0);
+       return 0;
+}
 
-       for (i = 0; i < f->n; ++i) pthread_join(tid[i], 0);
-       for (i = 0; i < f->n; ++i) dq_destroy(f->w[i].q);
-       free(tid); free(f->w); free(f);
+kthread_t *kt_init(int n_threads)
+{
+       kthread_t *t;
+       int i;
+       t = calloc(1, sizeof(kthread_t));
+       t->n_threads = n_threads - 1;
+       t->w = calloc(t->n_threads, sizeof(kt_worker_t));
+       pthread_mutex_init(&t->lock, 0);
+       pthread_cond_init(&t->cv, 0);
+       for (i = 0; i < t->n_threads; ++i) {
+               t->w[i].q = dq_init(KT_DQ_BITS);
+               t->w[i].t = t;
+               pthread_mutex_init(&t->w[i].lock, 0);
+               pthread_cond_init(&t->w[i].cv, 0);
+       }
+       pthread_create(&t->self, 0, master, t);
+       return t;
+}
+
+void kt_sync(kthread_t *t)
+{
+       int i;
+       pthread_mutex_lock(&t->lock);
+       t->to_sync = 1;
+       pthread_cond_signal(&t->cv);
+       pthread_mutex_unlock(&t->lock);
+       pthread_join(t->self, 0);
+
+       pthread_cond_destroy(&t->cv);
+       pthread_mutex_destroy(&t->lock);
+       for (i = 0; i < t->n_threads; ++i) {
+               pthread_cond_destroy(&t->w[i].cv);
+               pthread_mutex_destroy(&t->w[i].lock);
+               dq_destroy(t->w[i].q);
+       }
+       free(t->tasks); free(t->w); free(t);
+}
+
+void kt_spawn(kthread_t *t, int (*func)(void*,int,void*), void *shared, int n_items, int item_size, void *items)
+{
+       kt_task_t *p;
+       pthread_mutex_lock(&t->lock);
+       if (t->n_tasks == t->max_tasks) {
+               t->max_tasks = t->max_tasks? t->max_tasks<<1 : 2;
+               t->tasks = realloc(t->tasks, t->max_tasks * sizeof(kt_task_t));
+       }
+       p = &t->tasks[t->n_tasks++];
+       p->func = func, p->shared = shared;
+       p->n_items = n_items, p->item_size = item_size, p->items = items, p->n_finished = 0;
+       pthread_cond_signal(&t->cv);
+       pthread_mutex_unlock(&t->lock);
 }
diff --git a/kthread.h b/kthread.h
new file mode 100644 (file)
index 0000000..bc2e5c7
--- /dev/null
+++ b/kthread.h
@@ -0,0 +1,11 @@
+#ifndef KTHREAD_H
+#define KTHREAD_H
+
+struct kthread_t;
+typedef struct kthread_t kthread_t;
+
+kthread_t *kt_init(int n_threads);
+void kt_spawn(kthread_t *t, int (*func)(void*,int,void*), void *shared, int n_items, int item_size, void *items);
+void kt_sync(kthread_t *t);
+
+#endif