]> git.kaiwu.me - klib.git/commitdiff
sync kthread with the one in minimap/bwa
authorHeng Li <lh3@me.com>
Sun, 31 Jul 2016 14:51:35 +0000 (10:51 -0400)
committerHeng Li <lh3@me.com>
Sun, 31 Jul 2016 14:51:35 +0000 (10:51 -0400)
kthread.c

index 80f84cb355e206ccdc2d7b9ee56215a037976ede..f991714c5c114f03812a186204203f125a552656 100644 (file)
--- a/kthread.c
+++ b/kthread.c
@@ -47,16 +47,21 @@ static void *ktf_worker(void *data)
 
 void kt_for(int n_threads, void (*func)(void*,long,int), void *data, long n)
 {
-       int i;
-       kt_for_t t;
-       pthread_t *tid;
-       t.func = func, t.data = data, t.n_threads = n_threads, t.n = n;
-       t.w = (ktf_worker_t*)alloca(n_threads * sizeof(ktf_worker_t));
-       tid = (pthread_t*)alloca(n_threads * sizeof(pthread_t));
-       for (i = 0; i < n_threads; ++i)
-               t.w[i].t = &t, t.w[i].i = i;
-       for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktf_worker, &t.w[i]);
-       for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0);
+       if (n_threads > 1) {
+               int i;
+               kt_for_t t;
+               pthread_t *tid;
+               t.func = func, t.data = data, t.n_threads = n_threads, t.n = n;
+               t.w = (ktf_worker_t*)alloca(n_threads * sizeof(ktf_worker_t));
+               tid = (pthread_t*)alloca(n_threads * sizeof(pthread_t));
+               for (i = 0; i < n_threads; ++i)
+                       t.w[i].t = &t, t.w[i].i = i;
+               for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktf_worker, &t.w[i]);
+               for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0);
+       } else {
+               long j;
+               for (j = 0; j < n; ++j) func(data, j, 0);
+       }
 }
 
 /*****************
@@ -67,13 +72,15 @@ struct ktp_t;
 
 typedef struct {
        struct ktp_t *pl;
-       int step, running;
+       int64_t index;
+       int step;
        void *data;
 } ktp_worker_t;
 
 typedef struct ktp_t {
        void *shared;
        void *(*func)(void*, int, void*);
+       int64_t index;
        int n_workers, n_steps;
        ktp_worker_t *workers;
        pthread_mutex_t mutex;
@@ -92,13 +99,12 @@ static void *ktp_worker(void *data)
                        // test whether another worker is doing the same step
                        for (i = 0; i < p->n_workers; ++i) {
                                if (w == &p->workers[i]) continue; // ignore itself
-                               if (p->workers[i].running && p->workers[i].step == w->step)
+                               if (p->workers[i].step <= w->step && p->workers[i].index < w->index)
                                        break;
                        }
-                       if (i == p->n_workers) break; // no other workers doing w->step; then this worker will
+                       if (i == p->n_workers) break; // no workers with smaller indices are doing w->step or the previous steps
                        pthread_cond_wait(&p->cv, &p->mutex);
                }
-               w->running = 1;
                pthread_mutex_unlock(&p->mutex);
 
                // working on w->step
@@ -107,7 +113,7 @@ static void *ktp_worker(void *data)
                // update step and let other workers know
                pthread_mutex_lock(&p->mutex);
                w->step = w->step == p->n_steps - 1 || w->data? (w->step + 1) % p->n_steps : p->n_steps;
-               w->running = 0;
+               if (w->step == 0) w->index = p->index++;
                pthread_cond_broadcast(&p->cv);
                pthread_mutex_unlock(&p->mutex);
        }
@@ -125,16 +131,18 @@ void kt_pipeline(int n_threads, void *(*func)(void*, int, void*), void *shared_d
        aux.n_steps = n_steps;
        aux.func = func;
        aux.shared = shared_data;
+       aux.index = 0;
        pthread_mutex_init(&aux.mutex, 0);
        pthread_cond_init(&aux.cv, 0);
 
-       aux.workers = alloca(n_threads * sizeof(ktp_worker_t));
+       aux.workers = (ktp_worker_t*)alloca(n_threads * sizeof(ktp_worker_t));
        for (i = 0; i < n_threads; ++i) {
                ktp_worker_t *w = &aux.workers[i];
-               w->step = w->running = 0; w->pl = &aux; w->data = 0;
+               w->step = 0; w->pl = &aux; w->data = 0;
+               w->index = aux.index++;
        }
 
-       tid = alloca(n_threads * sizeof(pthread_t));
+       tid = (pthread_t*)alloca(n_threads * sizeof(pthread_t));
        for (i = 0; i < n_threads; ++i) pthread_create(&tid[i], 0, ktp_worker, &aux.workers[i]);
        for (i = 0; i < n_threads; ++i) pthread_join(tid[i], 0);