int i;
} ktf_worker_t;
+/*
+ * Steal one job from the worker whose queue currently reports the largest size.
+ *
+ * f: shared kt_for state; f->w[0 .. f->n-1] are the per-worker queues.
+ *
+ * Returns the stolen job index, or -1 when there is no worker to steal from
+ * or the dequeue on the chosen queue fails. The -1 is returned through a
+ * uint64_t, so callers must store the result in an int (as the visible
+ * callers do) before testing it with `>= 0`.
+ */
+static inline uint64_t steal_work(kt_for_t *f)
+{
+	int i, max = -1, max_i = -1, k = -1;
+	for (i = 0; i < f->n; ++i) {
+		// Read the size once so the recorded max matches the value that was compared.
+		// max is not accurate as other workers may steal from the same queue, but it does not matter.
+		int sz = dq_size(f->w[i].q);
+		if (max < sz) max = sz, max_i = i;
+	}
+	// No candidate queue (e.g. f->n <= 0): bail out instead of indexing f->w[-1].
+	if (max_i < 0) return (uint64_t)-1;
+	if (dq_deq(f->w[max_i].q, 0, &k) < 0) k = -1;
+	return k;
+}
+
static void *ktf_worker(void *data)
{
ktf_worker_t *w = (ktf_worker_t*)data;
for (;;) {
int k = -1;
- if (dq_deq(w->q, 1, &k) < 0) { // if the queue associated with the worker is full, steal
- int i, max, max_i;
- for (i = 0, max = -1, max_i = -1; i < w->f->n; ++i) // find the worker with most pending jobs
- if (max < dq_size(w->f->w[i].q))
- max = dq_size(w->f->w[i].q), max_i = i;
- if (dq_deq(w->f->w[max_i].q, 0, &k) < 0) k = -1; // steal a job
- }
+ if (dq_deq(w->q, 1, &k) < 0) k = steal_work(w->f); // no job obtained from this worker's own queue (presumably it is empty -- confirm against dq_deq): try to steal one from the most loaded worker; k stays negative if that fails too
if (k >= 0) w->f->func(w->f->global, (uint8_t*)w->f->local + w->f->size * k);
else if (w->f->finished) break;
}
return 0;
}
-void kt_for(int n, int (*func)(void*,void*), void *global, int m, int size, void *local)
+void kt_for(int n_threads, int (*func)(void*,void*), void *global, int m, int size, void *local)
{
kt_for_t *f;
pthread_t *tid;
int i, k, dq_bits = HT_DQ_BITS;
f = (kt_for_t*)calloc(1, sizeof(kt_for_t));
- f->n = n - 1, f->size = size;
+ f->n = n_threads - 1, f->size = size;
f->global = global, f->local = local;
f->func = func;
for (k = 0; k < m; ++k) {
int min, min_i;
- for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i)
+ for (i = 0, min = 1<<dq_bits, min_i = -1; i < f->n; ++i) // find the worker with the lowest load
if (min > dq_size(f->w[i].q)) min = dq_size(f->w[i].q), min_i = i;
if (min < 1<<dq_bits) dq_enq(f->w[min_i].q, 0, &k);
else func(global, (uint8_t*)f->local + f->size * k);
}
f->finished = 1;
+ while ((k = steal_work(f)) >= 0) func(global, (uint8_t*)f->local + f->size * k); // help the unfinished workers
for (i = 0; i < f->n; ++i) pthread_join(tid[i], 0);
for (i = 0; i < f->n; ++i) dq_destroy(f->w[i].q);