#include <stdint.h>
#include <stdio.h>
-#define HT_DQ_BITS 5 // 1<<HT_DQ_BITS is size of deque associated with each worker
+#define KT_DQ_BITS 5 // 1<<KT_DQ_BITS is the size of the deque associated with each worker
/*************************
*** Fixed-sized deque ***
*************************/
-typedef int dqval_t;
+typedef uint64_t dqval_t; // deque element type; presumably 64-bit so it can hold the packed task ids that master() enqueues - confirm against dq_enq()
typedef struct { // a ring buffer
int lock;
return ret;
}
-/**********************************
- *** Paralelize simple for loop ***
- **********************************/
+/****************************
+ *** Spawn/sync interface ***
+ ****************************/
-struct ktf_worker_t;
+#include "kthread.h"
+// One batch registered by kt_spawn(): func is applied to each of the n_items elements of items.
typedef struct {
- int n, size; // n: number of workers; size: size of each items element
- void *shared;
- void *items;
+ int n_items, item_size, n_finished; // n_items: number of elements; item_size: bytes per element; n_finished: zeroed by kt_spawn(), not otherwise used in the code shown
int (*func)(void*,int,void*);
- struct ktf_worker_t *w;
- int finished;
-} kt_for_t;
+ void *shared, *items; // shared: first argument handed to func; items: base address of the element array
+} kt_task_t;
-typedef struct ktf_worker_t {
- kt_for_t *f;
+// Per-worker state; one kt_worker_t exists for each slave thread.
+typedef struct {
+ struct kthread_t *t; // scheduler this worker belongs to
deque_t *q;
- int i;
-} ktf_worker_t;
-
-static inline int steal_work(kt_for_t *f) // steal work from the worker with the highest load
+ int type; // 0: waiting (also the initial value); 1: woken by master(); 2: asked to exit
+ pthread_t tid; // thread running slave() for this worker
+ pthread_mutex_t lock; // held around every write to type and around waits on cv
+ pthread_cond_t cv; // signalled by master() after it changes type
+} kt_worker_t;
+
+// Scheduler state shared by the caller, the master thread and the slave threads.
+struct kthread_t {
+ int n_threads; // number of slave threads (kt_init() stores its argument minus one)
+ kt_worker_t *w; // array of n_threads workers
+ int n_tasks, max_tasks; // used and allocated entries of tasks
+ kt_task_t *tasks; // task array, grown by kt_spawn()
+ pthread_t self; // the thread running master()
+ pthread_mutex_t lock; // held by kt_spawn(), kt_sync() and master() around n_tasks, tasks and to_sync
+ pthread_cond_t cv; // signalled by kt_spawn() and kt_sync() to wake master()
+ int to_sync, done; // to_sync: set by kt_sync(); done: not referenced in the code shown
+};
+
+// Steal one task id from the worker whose deque reports the largest dq_size().
+// Returns the packed 64-bit id, or (uint64_t)-1 when there is no worker or the
+// dequeue fails. The return type is uint64_t rather than int so that the task
+// index kept in the upper 32 bits of the id is not truncated away.
+static inline uint64_t steal_task(kthread_t *t)
{
- int i, max = -1, max_i = -1, k = -1;
- for (i = 0; i < f->n; ++i)
- if (max < dq_size(f->w[i].q)) // max is not accurate as other workers may steal from the same queue, but it does not matter.
- max = dq_size(f->w[i].q), max_i = i;
- if (max_i < 0 || dq_deq(f->w[max_i].q, 0, &k) < 0) k = -1;
+ int i, max = -1, max_i = -1;
+ uint64_t k = (uint64_t)-1;
+ for (i = 0; i < t->n_threads; ++i)
+ if (max < dq_size(t->w[i].q)) // max is not accurate as other workers may steal from the same queue, but it does not matter.
+ max = dq_size(t->w[i].q), max_i = i;
+ if (max_i < 0 || dq_deq(t->w[max_i].q, 0, &k) < 0) k = (uint64_t)-1;
return k;
}
-static void *ktf_worker(void *data)
+// Run one item of one task. sid packs the task index in its upper 32 bits and
+// the item index in its lower 32 bits.
+static inline void do_task(kthread_t *t, uint64_t sid)
+{
+ kt_task_t *s = &t->tasks[sid>>32];
+ // compute the byte offset in size_t so it cannot wrap at 32 bits for large item arrays
+ s->func(s->shared, (int)sid, (uint8_t*)s->items + (size_t)s->item_size * (uint32_t)sid);
+}
+
+// Worker thread body; data points to this worker's kt_worker_t.
+// Repeatedly takes a task id from its own deque, falling back to steal_task(),
+// and runs it with do_task(). When no id is found it exits if type == 2,
+// otherwise it sets type = 0 and sleeps on cv until master() changes type.
+static void *slave(void *data)
{
- ktf_worker_t *w = (ktf_worker_t*)data;
+ kt_worker_t *w = (kt_worker_t*)data;
for (;;) {
- int k = -1;
- if (dq_deq(w->q, 1, &k) < 0) k = steal_work(w->f);
- if (k >= 0) w->f->func(w->f->shared, k, (uint8_t*)w->f->items + w->f->size * k);
- else if (w->f->finished) break;
+ uint64_t sid;
+ if (dq_deq(w->q, 1, &sid) < 0)
+ sid = steal_task(w->t);
+ if (sid == (uint64_t)-1) { // no task found anywhere: exit if asked to, otherwise sleep until signalled
+ int quit;
+ pthread_mutex_lock(&w->lock);
+ if (w->type != 2) { // never overwrite an exit request that arrived before the lock was taken; doing so would sleep forever
+ w->type = 0; // mark this worker as waiting
+ while (w->type == 0) pthread_cond_wait(&w->cv, &w->lock);
+ }
+ quit = (w->type == 2); // read the state while still holding the lock
+ pthread_mutex_unlock(&w->lock);
+ if (quit) break;
+ } else do_task(w->t, sid);
}
return 0;
}
-/**
- * Parallelize a simple "for" loop
- *
- * @param n_threads total number of threads
- * @param func function in the form of func(void *shared, int item_id, void *item);
- * @param shared shared data used by $func
- * @param n_items number of items to process
- * @param item_size size of each item
- * @param items item
- *
- * This function parallelizes such a "for" loop:
- *
- * shared_type *shared;
- * item_type items[n_items];
- * for (int i = 0; i < n_items; ++i)
- * func(shared, &items[i]);
- *
- * with:
- *
- * ht_for(n_threads, func, shared, n_items, sizeof(item_type), items);
- */
-void kt_for(int n_threads, int (*func)(void*,int,void*), void *shared, int n_items, int item_size, void *items)
+// Scheduler thread body; data is the kthread_t created by kt_init().
+// Starts the slaves, then repeatedly waits for newly spawned tasks or a sync
+// request, hands each item of the new tasks to the worker with the shortest
+// deque (running the item itself when every deque is full) and helps drain the
+// deques. Once a sync has been requested it tells every slave to exit and
+// joins them.
+static void *master(void *data)
{
- kt_for_t *f;
- pthread_t *tid;
- int i, k, dq_bits = HT_DQ_BITS;
-
- f = (kt_for_t*)calloc(1, sizeof(kt_for_t));
- f->n = n_threads - 1, f->size = item_size;
- f->shared = shared, f->items = items;
- f->func = func;
-
- f->w = (ktf_worker_t*)calloc(f->n, sizeof(ktf_worker_t));
- for (i = 0; i < f->n; ++i)
- f->w[i].f = f, f->w[i].i = i, f->w[i].q = dq_init(dq_bits);
-
- tid = (pthread_t*)calloc(f->n, sizeof(pthread_t));
- for (i = 0; i < f->n; ++i) pthread_create(&tid[i], 0, ktf_worker, &f->w[i]);
-
- for (k = 0; k < n_items; ++k) {
- int min, min_i;
- for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i) // find the worker with the lowest load
- if (min > dq_size(f->w[i].q)) min = dq_size(f->w[i].q), min_i = i;
- if (min < 1<<dq_bits) dq_enq(f->w[min_i].q, 0, &k);
- else f->func(shared, k, (uint8_t*)f->items + f->size * k);
+ kthread_t *t = (kthread_t*)data;
+ int i, n_tasks = 0, to_sync = 0; // n_tasks: number of tasks already distributed; to_sync: local snapshot of t->to_sync
+ for (i = 0; i < t->n_threads; ++i)
+ pthread_create(&t->w[i].tid, 0, slave, &t->w[i]);
+ while (!to_sync) {
+ int next_tasks, tid, iid;
+ uint64_t sid;
+ // sleep until kt_spawn() has added a task or kt_sync() has requested a sync
+ pthread_mutex_lock(&t->lock);
+ while (n_tasks == t->n_tasks && !t->to_sync)
+ pthread_cond_wait(&t->cv, &t->lock);
+ next_tasks = t->n_tasks, to_sync = t->to_sync;
+ pthread_mutex_unlock(&t->lock);
+ for (tid = n_tasks; tid < next_tasks; ++tid) { // tasks added since the previous round
+ kt_task_t *s = &t->tasks[tid];
+ for (iid = 0; iid < s->n_items; ++iid) {
+ int min, min_i;
+ // find the worker with the shortest deque; min stays at 1<<KT_DQ_BITS if none is shorter than that
+ for (i = 0, min = 1<<KT_DQ_BITS, min_i = -1; i < t->n_threads; ++i)
+ if (min > dq_size(t->w[i].q)) min = dq_size(t->w[i].q), min_i = i;
+ sid = (uint64_t)tid<<32 | iid; // task index in the upper half, item index in the lower half
+ if (min < 1<<KT_DQ_BITS) {
+ kt_worker_t *w = &t->w[min_i];
+ dq_enq(w->q, 0, &sid);
+ // NOTE(review): type is read without w->lock; a slave that sets type = 0 just after this check is not signalled here - confirm this is acceptable
+ if (w->type == 0) {
+ pthread_mutex_lock(&w->lock);
+ w->type = 1;
+ pthread_cond_signal(&w->cv);
+ pthread_mutex_unlock(&w->lock);
+ }
+ } else do_task(t, sid); // no deque shorter than 1<<KT_DQ_BITS (or no worker at all): run the item in this thread
+ }
+ }
+ while ((sid = steal_task(t)) != (uint64_t)-1) do_task(t, sid); // help with whatever is still queued
+ n_tasks = next_tasks;
}
- f->finished = 1;
- while ((k = steal_work(f)) >= 0) func(shared, k, (uint8_t*)f->items + f->size * k); // help the unfinished workers
+ // a sync was requested: ask every slave to exit, then wait for all of them
+ for (i = 0; i < t->n_threads; ++i) {
+ pthread_mutex_lock(&t->w[i].lock);
+ t->w[i].type = 2;
+ pthread_cond_signal(&t->w[i].cv);
+ pthread_mutex_unlock(&t->w[i].lock);
+ }
+ for (i = 0; i < t->n_threads; ++i) pthread_join(t->w[i].tid, 0);
+ return 0;
+}
- for (i = 0; i < f->n; ++i) pthread_join(tid[i], 0);
- for (i = 0; i < f->n; ++i) dq_destroy(f->w[i].q);
- free(tid); free(f->w); free(f);
+// Create a scheduler that uses n_threads threads in total: one master thread
+// plus n_threads-1 slave threads (none when n_threads <= 1), each slave with
+// its own deque from dq_init(KT_DQ_BITS). The master thread is started here;
+// it starts the slaves itself.
+// Returns the new scheduler, or 0 if memory allocation or creation of the
+// master thread fails. Release it with kt_sync().
+kthread_t *kt_init(int n_threads)
+{
+ kthread_t *t;
+ int i;
+ t = calloc(1, sizeof(kthread_t));
+ if (t == 0) return 0;
+ t->n_threads = n_threads > 1? n_threads - 1 : 0; // never negative, so it is safe as an array length
+ t->w = calloc(t->n_threads? t->n_threads : 1, sizeof(kt_worker_t)); // at least one slot: calloc(0, ...) may legitimately return NULL
+ if (t->w == 0) { free(t); return 0; }
+ pthread_mutex_init(&t->lock, 0);
+ pthread_cond_init(&t->cv, 0);
+ for (i = 0; i < t->n_threads; ++i) {
+ t->w[i].q = dq_init(KT_DQ_BITS);
+ t->w[i].t = t;
+ pthread_mutex_init(&t->w[i].lock, 0);
+ pthread_cond_init(&t->w[i].cv, 0);
+ }
+ if (pthread_create(&t->self, 0, master, t) != 0) { // without a master thread kt_sync() would join an invalid thread id
+ for (i = 0; i < t->n_threads; ++i) {
+ pthread_cond_destroy(&t->w[i].cv);
+ pthread_mutex_destroy(&t->w[i].lock);
+ dq_destroy(t->w[i].q);
+ }
+ pthread_cond_destroy(&t->cv);
+ pthread_mutex_destroy(&t->lock);
+ free(t->w); free(t);
+ return 0;
+ }
+ return t;
+}
+
+// Request a sync, wait for the master thread to return (it joins every slave
+// first), then destroy all locks, condition variables and deques and free the
+// scheduler. t must not be used after this call.
+void kt_sync(kthread_t *t)
+{
+ int k = 0;
+
+ // raise the sync flag under the lock and wake the master
+ pthread_mutex_lock(&t->lock);
+ t->to_sync = 1;
+ pthread_cond_signal(&t->cv);
+ pthread_mutex_unlock(&t->lock);
+ pthread_join(t->self, 0);
+
+ // every thread has exited by now; tear the scheduler down
+ pthread_cond_destroy(&t->cv);
+ pthread_mutex_destroy(&t->lock);
+ while (k < t->n_threads) {
+ kt_worker_t *wk = &t->w[k++];
+ pthread_cond_destroy(&wk->cv);
+ pthread_mutex_destroy(&wk->lock);
+ dq_destroy(wk->q);
+ }
+ free(t->tasks);
+ free(t->w);
+ free(t);
+}
+
+// Register a batch of n_items calls
+//   func(shared, i, (uint8_t*)items + item_size * i),  i = 0 .. n_items-1
+// and wake the master thread so that it distributes them to the workers.
+//
+// @param t          scheduler returned by kt_init()
+// @param func       function applied to every item
+// @param shared     passed unchanged as the first argument of func
+// @param n_items    number of items
+// @param item_size  size in bytes of one item
+// @param items      base address of the item array
+//
+// If the task array cannot be grown, the old array is kept and the whole
+// batch is executed in the calling thread before this function returns.
+void kt_spawn(kthread_t *t, int (*func)(void*,int,void*), void *shared, int n_items, int item_size, void *items)
+{
+ kt_task_t *p;
+ pthread_mutex_lock(&t->lock);
+ if (t->n_tasks == t->max_tasks) {
+ int new_max = t->max_tasks? t->max_tasks<<1 : 2;
+ kt_task_t *tmp = realloc(t->tasks, new_max * sizeof(kt_task_t)); // via a temporary: on failure the old array stays valid and is not leaked
+ if (tmp == 0) { // out of memory: do the work here instead of dereferencing a null array
+ int i;
+ pthread_mutex_unlock(&t->lock);
+ for (i = 0; i < n_items; ++i) func(shared, i, (uint8_t*)items + (size_t)item_size * i);
+ return;
+ }
+ // NOTE(review): realloc() may move the array while do_task() reads t->tasks without taking t->lock - confirm that no task can still be running at this point
+ t->tasks = tmp, t->max_tasks = new_max;
+ }
+ p = &t->tasks[t->n_tasks++];
+ p->func = func, p->shared = shared;
+ p->n_items = n_items, p->item_size = item_size, p->items = items, p->n_finished = 0;
+ pthread_cond_signal(&t->cv);
+ pthread_mutex_unlock(&t->lock);
}