From 7d5daa15e63af4dc372b8a2de6bb5e2bc79e4c4e Mon Sep 17 00:00:00 2001 From: Heng Li Date: Fri, 11 Oct 2013 13:17:26 -0400 Subject: [PATCH] added a more flexible scheduler; not tested not thoroughly --- kthread.c | 131 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 127 insertions(+), 4 deletions(-) diff --git a/kthread.c b/kthread.c index 43797ea..e83a01d 100644 --- a/kthread.c +++ b/kthread.c @@ -1,5 +1,10 @@ #include #include +#include "kthread.h" + +/************** + *** kh_for *** + **************/ struct kt_for_t; @@ -15,12 +20,11 @@ typedef struct kt_for_t { void *data; } kt_for_t; -static inline int steal_work(kt_for_t *t) +static inline int ktf_steal_work(kt_for_t *t) { - int i, k, min = t->n, min_i = -1; + int i, k, min = 0x7fffffff, min_i = 0; for (i = 0; i < t->n_threads; ++i) if (min > t->w[i].i) min = t->w[i].i, min_i = i; - if (min_i < 0) return -1; k = __sync_fetch_and_add(&t->w[min_i].i, t->n_threads); return k >= t->n? -1 : k; } @@ -34,7 +38,7 @@ static void *ktf_worker(void *data) if (i >= w->t->n) break; w->t->func(w->t->data, i, w - w->t->w); } - while ((i = steal_work(w->t)) >= 0) + while ((i = ktf_steal_work(w->t)) >= 0) w->t->func(w->t->data, i, w - w->t->w); pthread_exit(0); } @@ -52,3 +56,122 @@ void kt_for(int n_threads, int n, void (*func)(void*,int,int), void *data) for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktf_worker, &t.w[i]); for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0); } + +/************************ + *** kt_spawn/kt_sync *** + ************************/ + +typedef long long ktint64_t; + +typedef struct { + kthread_t *t; + int i; // slot ID + pthread_t tid; + int pending; +} kts_worker_t; + +typedef struct { + int n; + void (*func)(void*,int,int); + void *data; +} kts_task_t; + +struct kthread_t { + kts_worker_t *w; + int n_threads, to_sync; + int n_tasks, n_slots; + kts_task_t tasks[KT_MAX_TASKS]; + pthread_mutex_t lock; + pthread_cond_t cv; +}; + +static inline void process_slot(const kthread_t *t, int i, int tid) +{ + int j; + for (j = 0;; ++j) { + kts_task_t task = t->tasks[j]; + if (i < task.n) { + task.func(task.data, i, tid); + break; + } else i -= task.n; + } +} + +static void *kts_worker(void *data) +{ + kts_worker_t *w = (kts_worker_t*)data; + for (;;) { + int i, to_sync, n_slots; + // update the task and slot information + pthread_mutex_lock(&w->t->lock); + while (w->i >= w->t->n_slots && !w->t->to_sync) + pthread_cond_wait(&w->t->cv, &w->t->lock); + to_sync = w->t->to_sync, n_slots = w->t->n_slots; + pthread_mutex_unlock(&w->t->lock); + // process the pending slot if there is any + if (w->pending >= 0 && w->pending < n_slots) { + process_slot(w->t, w->pending, w - w->t->w); + w->pending = -1; + } + // process slots assigned to the current worker + for (;;) { + i = __sync_fetch_and_add(&w->i, w->t->n_threads); + if (i >= n_slots) break; + process_slot(w->t, i, w - w->t->w); + } + // steal slots from other workers + for (;;) { + int min = 0x7fffffff, min_i = 0; + for (i = 0; i < w->t->n_threads; ++i) + if (min > w->t->w[i].i) min = w->t->w[i].i, min_i = i; + i = __sync_fetch_and_add(&w->t->w[min_i].i, w->t->n_threads); + if (i >= n_slots) { + w->pending = i; + break; + } else process_slot(w->t, i, w - w->t->w); + } + if (to_sync) break; + } + pthread_exit(0); +} + +kthread_t *kt_init(int n_threads) +{ + kthread_t *t; + int i; + t = (kthread_t*)calloc(1, sizeof(kthread_t)); + t->n_threads = n_threads; + t->w = (kts_worker_t*)calloc(t->n_threads, sizeof(kts_worker_t)); + pthread_mutex_init(&t->lock, 0); + pthread_cond_init(&t->cv, 0); + for (i = 0; i < t->n_threads; ++i) { + t->w[i].i = i, t->w[i].t = t, t->w[i].pending = -1; + pthread_create(&t->w[i].tid, 0, kts_worker, &t->w[i]); + } + return t; +} + +void kt_sync(kthread_t *t) +{ + int i; + pthread_mutex_lock(&t->lock); + t->to_sync = 1; + pthread_cond_broadcast(&t->cv); + pthread_mutex_unlock(&t->lock); + for (i = 0; i < t->n_threads; ++i) + pthread_join(t->w[i].tid, 0); + free(t->w); free(t); +} + +int kt_spawn(kthread_t *t, int n, void (*func)(void*,int,int), void *data) +{ + kts_task_t *task; + if (t->n_tasks >= KT_MAX_TASKS) return -1; + pthread_mutex_lock(&t->lock); + task = &t->tasks[t->n_tasks++]; + task->n = n, task->func = func, task->data = data; + t->n_slots += n; + pthread_cond_broadcast(&t->cv); + pthread_mutex_unlock(&t->lock); + return 0; +} -- 2.47.3