]> git.kaiwu.me - klib.git/commitdiff
added 16-bit sse2; query-end seems wrong, sometime
authorHeng Li <lh3@live.co.uk>
Sun, 8 May 2011 22:20:54 +0000 (18:20 -0400)
committerHeng Li <lh3@live.co.uk>
Sun, 8 May 2011 22:20:54 +0000 (18:20 -0400)
ksw.c

diff --git a/ksw.c b/ksw.c
index f2870609301da9ecb273ab04d9022dd45134a321..92c376636d531e346c90825b3d5e7fbe9ecadf13 100644 (file)
--- a/ksw.c
+++ b/ksw.c
 
 struct _ksw_query_t {
        int qlen, slen;
-       uint8_t shift, mdiff, max;
+       uint8_t shift, mdiff, max, size;
        __m128i *qp, *H0, *H1, *E, *Hmax;
 };
 
-ksw_query_t *ksw_qinit(int p, int qlen, const uint8_t *query, int m, const int8_t *mat)
+ksw_query_t *ksw_qinit(int size, int qlen, const uint8_t *query, int m, const int8_t *mat)
 {
        ksw_query_t *q;
-       int8_t *t;
-       int qlen16, slen, a, tmp;
+       int slen, a, tmp, p;
 
-       slen = (qlen + p - 1) / p;
-       qlen16 = (qlen + 15) >> 4 << 4;
-       q = malloc(sizeof(ksw_query_t) + 256 + qlen16 * (m + 4)); // a single block of memory
+       size = size > 1? 2 : 1;
+       p = 8 * (3 - size); // # values per __m128i
+       slen = (qlen + p - 1) / p; // segmented length
+       q = malloc(sizeof(ksw_query_t) + 256 + 16 * slen * (m + 4)); // a single block of memory
        q->qp = (__m128i*)(((size_t)q + sizeof(ksw_query_t) + 15) >> 4 << 4); // align memory
-       q->H0 = q->qp + (qlen16 * m) / 16;
-       q->H1 = q->H0 + qlen16 / 16;
-       q->E  = q->H1 + qlen16 / 16;
-       q->Hmax = q->E + qlen16 / 16;
-       q->slen = slen; q->qlen = qlen;
+       q->H0 = q->qp + slen * m;
+       q->H1 = q->H0 + slen;
+       q->E  = q->H1 + slen;
+       q->Hmax = q->E + slen;
+       q->slen = slen; q->qlen = qlen; q->size = size;
        // compute shift
        tmp = m * m;
        for (a = 0, q->shift = 127, q->mdiff = 0; a < tmp; ++a) { // find the minimum and maximum score
@@ -68,13 +68,24 @@ ksw_query_t *ksw_qinit(int p, int qlen, const uint8_t *query, int m, const int8_
        q->mdiff += q->shift; // this is the difference between the min and max scores
        // An example: p=8, qlen=19, slen=3 and segmentation:
        //  {{0,3,6,9,12,15,18,-1},{1,4,7,10,13,16,-1,-1},{2,5,8,11,14,17,-1,-1}}
-       t = (int8_t*)q->qp;
-       for (a = 0; a < m; ++a) {
-               int i, k;
-               const int8_t *ma = mat + a * m;
-               for (i = 0; i < slen; ++i)
-                       for (k = i; k < qlen16; k += slen) // p iterations
-                               *t++ = (k >= qlen? 0 : ma[query[k]]) + q->shift;
+       if (size == 1) {
+               int8_t *t = (int8_t*)q->qp;
+               for (a = 0; a < m; ++a) {
+                       int i, k, nlen = slen * p;
+                       const int8_t *ma = mat + a * m;
+                       for (i = 0; i < slen; ++i)
+                               for (k = i; k < nlen; k += slen) // p iterations
+                                       *t++ = (k >= qlen? 0 : ma[query[k]]) + q->shift;
+               }
+       } else {
+               int16_t *t = (int16_t*)q->qp;
+               for (a = 0; a < m; ++a) {
+                       int i, k, nlen = slen * p;
+                       const int8_t *ma = mat + a * m;
+                       for (i = 0; i < slen; ++i)
+                               for (k = i; k < nlen; k += slen) // p iterations
+                                       *t++ = (k >= qlen? 0 : ma[query[k]]) + q->shift;
+               }
        }
        return q;
 }
@@ -200,6 +211,108 @@ end_loop:
        return a->score;
 }
 
+int ksw_sse2_8(ksw_query_t *q, int tlen, const uint8_t *target, ksw_aux_t *a) // the first gap costs -(_o+_e)
+{
+       int slen, i, m_b, n_b, te = -1, gmax = 0;
+       uint64_t *b;
+       __m128i zero, gapoe, shift, gape, *H0, *H1, *E, *Hmax;
+
+#define __max_8(ret, xx) do { \
+               (xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 8)); \
+               (xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 4)); \
+               (xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 2)); \
+       (ret) = _mm_extract_epi16((xx), 0); \
+       } while (0)
+
+       // initialization
+       m_b = n_b = 0; b = 0;
+       zero = _mm_set1_epi32(0);
+       gapoe = _mm_set1_epi16(a->gapo + a->gape);
+       gape = _mm_set1_epi16(a->gape);
+       shift = _mm_set1_epi16(q->shift);
+       H0 = q->H0; H1 = q->H1; E = q->E; Hmax = q->Hmax;
+       slen = q->slen;
+       for (i = 0; i < slen; ++i) {
+               _mm_store_si128(E + i, zero);
+               _mm_store_si128(H0 + i, zero);
+               _mm_store_si128(Hmax + i, zero);
+       }
+       // the core loop
+       for (i = 0; i < tlen; ++i) {
+               int j, k, imax;
+               __m128i e, h, f = zero, max = zero, *S = q->qp + target[i] * slen; // s is the 1st score vector
+               h = _mm_load_si128(H0 + slen - 1); // h={2,5,8,11,14,17,-1,-1} in the above example
+               h = _mm_slli_si128(h, 2);
+               for (j = 0; LIKELY(j < slen); ++j) {
+                       h = _mm_adds_epu16(h, *S++);
+                       h = _mm_subs_epu16(h, shift);
+                       //int k;for (k=0;k<16;++k)printf("%d ", ((int16_t*)&h)[k]);printf("\n");
+                       e = _mm_load_si128(E + j);
+                       h = _mm_max_epi16(h, e);
+                       h = _mm_max_epi16(h, f);
+                       max = _mm_max_epi16(max, h);
+                       _mm_store_si128(H1 + j, h);
+                       h = _mm_subs_epu16(h, gapoe);
+                       e = _mm_subs_epu16(e, gape);
+                       e = _mm_max_epi16(e, h);
+                       _mm_store_si128(E + j, e);
+                       f = _mm_subs_epu16(f, gape);
+                       f = _mm_max_epi16(f, h);
+                       h = _mm_load_si128(H0 + j);
+               }
+               for (k = 0; LIKELY(k < 16); ++k) {
+                       f = _mm_slli_si128(f, 2);
+                       for (j = 0; LIKELY(j < slen); ++j) {
+                               h = _mm_load_si128(H1 + j);
+                               h = _mm_max_epi16(h, f);
+                               _mm_store_si128(H1 + j, h);
+                               h = _mm_subs_epu16(h, gapoe);
+                               f = _mm_subs_epu16(f, gape);
+                               if(UNLIKELY(!_mm_movemask_epi8(_mm_cmpgt_epi16(f, h)))) goto end_loop8;
+                       }
+               }
+end_loop8:
+               __max_8(imax, max);
+               if (imax >= a->T) {
+                       if (n_b == 0 || (int32_t)b[n_b-1] + 1 != i) {
+                               if (n_b == m_b) {
+                                       m_b = m_b? m_b<<1 : 8;
+                                       b = realloc(b, 8 * m_b);
+                               }
+                               b[n_b++] = (uint64_t)imax<<32 | i;
+                       } else if ((int)(b[n_b-1]>>32) < imax) b[n_b-1] = (uint64_t)imax<<32 | i; // modify the last
+               }
+               if (imax > gmax) {
+                       gmax = imax; te = i;
+                       for (j = 0; LIKELY(j < slen); ++j)
+                               _mm_store_si128(Hmax + j, _mm_load_si128(H1 + j));
+               }
+               S = H1; H1 = H0; H0 = S;
+       }
+       a->score = gmax; a->te = te;
+       {
+               int max = -1, low, high;
+               uint16_t *t = (uint16_t*)Hmax;
+               for (i = 0, a->qe = -1; i < q->qlen; ++i, ++t)
+                       if ((int)*t > max) max = *t, a->qe = i / 8 + i % 8 * slen;
+               i = (a->score + q->max - 1) / q->max;
+               low = te - i; high = te + i;
+               for (i = 0, a->score2 = 0; i < n_b; ++i) {
+                       int e = (int32_t)b[i];
+                       if ((e < low || e > high) && b[i]>>32 > (uint32_t)a->score2)
+                               a->score2 = b[i]>>32, a->te2 = e;
+               }
+       }
+       free(b);
+       return a->score;
+}
+
+int ksw_sse2(ksw_query_t *q, int tlen, const uint8_t *target, ksw_aux_t *a)
+{
+       if (q->size == 1) return ksw_sse2_16(q, tlen, target, a);
+       else return ksw_sse2_8(q, tlen, target, a);
+}
+
 /*******************************************
  * Main function (not compiled by default) *
  *******************************************/
@@ -233,14 +346,14 @@ unsigned char seq_nt4_table[256] = {
 
 int main(int argc, char *argv[])
 {
-       int c, sa = 1, sb = 3, i, j, k, forward_only = 0;
+       int c, sa = 1, sb = 3, i, j, k, forward_only = 0, size = 2;
        int8_t mat[25];
        ksw_aux_t a;
        gzFile fpt, fpq;
        kseq_t *kst, *ksq;
        // parse command line
        a.gapo = 5; a.gape = 2; a.T = 10;
-       while ((c = getopt(argc, argv, "a:b:q:r:ft:")) >= 0) {
+       while ((c = getopt(argc, argv, "a:b:q:r:ft:s:")) >= 0) {
                switch (c) {
                        case 'a': sa = atoi(optarg); break;
                        case 'b': sb = atoi(optarg); break;
@@ -248,6 +361,7 @@ int main(int argc, char *argv[])
                        case 'r': a.gape = atoi(optarg); break;
                        case 't': a.T = atoi(optarg); break;
                        case 'f': forward_only = 1; break;
+                       case 's': size = atoi(optarg); break;
                }
        }
        if (optind + 2 > argc) {
@@ -268,7 +382,7 @@ int main(int argc, char *argv[])
        while (kseq_read(ksq) > 0) {
                ksw_query_t *q[2];
                for (i = 0; i < ksq->seq.l; ++i) ksq->seq.s[i] = seq_nt4_table[(int)ksq->seq.s[i]];
-               q[0] = ksw_qinit(16, ksq->seq.l, (uint8_t*)ksq->seq.s, 5, mat);
+               q[0] = ksw_qinit(size, ksq->seq.l, (uint8_t*)ksq->seq.s, 5, mat);
                if (!forward_only) { // reverse
                        for (i = 0; i < ksq->seq.l/2; ++i) {
                                int t = ksq->seq.s[i];
@@ -277,16 +391,16 @@ int main(int argc, char *argv[])
                        }
                        for (i = 0; i < ksq->seq.l; ++i)
                                ksq->seq.s[i] = ksq->seq.s[i] == 4? 4 : 3 - ksq->seq.s[i];
-                       q[1] = ksw_qinit(16, ksq->seq.l, (uint8_t*)ksq->seq.s, 5, mat);
+                       q[1] = ksw_qinit(size, ksq->seq.l, (uint8_t*)ksq->seq.s, 5, mat);
                } else q[1] = 0;
                gzrewind(fpt); kseq_rewind(kst);
                while (kseq_read(kst) > 0) {
                        int s;
                        for (i = 0; i < kst->seq.l; ++i) kst->seq.s[i] = seq_nt4_table[(int)kst->seq.s[i]];
-                       s = ksw_sse2_16(q[0], kst->seq.l, (uint8_t*)kst->seq.s, &a);
+                       s = ksw_sse2(q[0], kst->seq.l, (uint8_t*)kst->seq.s, &a);
                        printf("%s\t%s\t+\t%d\t%d\t%d\n", ksq->name.s, kst->name.s, s, a.te+1, a.qe+1);
                        if (q[1]) {
-                               s = ksw_sse2_16(q[1], kst->seq.l, (uint8_t*)kst->seq.s, &a);
+                               s = ksw_sse2(q[1], kst->seq.l, (uint8_t*)kst->seq.s, &a);
                                printf("%s\t%s\t-\t%d\t%d\t%d\n", ksq->name.s, kst->name.s, s, a.te+1, a.qe+1);
                        }
                }