1 module bio.std.hts.snpcallers.maq;
2 
3 /*
4  * The code below is based on errmod.c from Samtools.
5  */
6 
7 import core.stdc.math;
8 import std.math : LN2, LN10, isNaN;
9 import std.traits;
10 import std.range;
11 import std.algorithm;
12 import std.random;
13 import std.typecons;
14 
15 import bio.std.hts.bam.md.reconstruct;
16 import bio.std.hts.bam.pileup;
17 
18 import bio.core.base;
19 import bio.core.genotype;
20 import bio.core.call;
21 import bio.core.tinymap;
22 
/// A nucleotide together with the strand it was observed on, packed into a
/// single byte: `code = base.internal_code * 2 + (is_reverse ? 1 : 0)`.
struct BaseWithStrand {
    /// Twice `Base.ValueSetSize` (one forward and one reverse code per base).
    ///
    /// Declared `static` on purpose: a non-static immutable field with an
    /// initializer is stored in every instance, is left uninitialized by
    /// `= void`, and makes the struct non-assignable.
    static immutable ValueSetSize = Base.ValueSetSize * 2;
    private ubyte _code;

    /// The packed base/strand code.
    ubyte internal_code() @property const {
        return _code;
    }

    /// Builds a value directly from a packed code; no validation is done.
    static BaseWithStrand fromInternalCode(ubyte code) {
        BaseWithStrand bws = void;
        bws._code = code;
        return bws;
    }

    /// Packs base `b` and the strand flag into one code.
    this(Base b, bool is_reverse) {
        _code = cast(ubyte)(b.internal_code * 2 + (is_reverse ? 1 : 0));
    }

    /// The nucleotide, recovered from the upper bits of the code.
    Base base() @property const {
        return Base.fromInternalCode(_code / 2);
    }

    /// True if the lowest bit of the code (the strand flag) is set.
    bool is_reverse_strand() @property const {
        return (_code & 1) == 1;
    }
}
48 
/// A base observed in a read: base and strand (via `BaseWithStrand`,
/// exposed through `alias this`) plus a quality value.
struct ReadBase {
    BaseWithStrand base_with_strand;
    alias base_with_strand this;
    private ubyte _quality;

    /// Stores the quality and packs `b` with the strand flag.
    this(Base b, ubyte quality, bool is_reverse) {
        _quality = quality;
        base_with_strand = BaseWithStrand(b, is_reverse);
    }

    /// Quality value passed to the constructor.
    @property ubyte quality() const {
        return _quality;
    }
}
63 
/// Precomputed lookup tables of the error model, plus the routine that
/// turns the read bases observed at a site into per-genotype scores.
struct ErrorModelCoefficients {
    private {

        // _fk[n] = (1 - depcorr)^n * (1 - eta) + eta
        double[] _fk;

        // _beta[q << 16 | n << 8 | k ] = see MAQ paper for meaning of \beta
        double[] _beta;

        // _lhet[n << 8 | k] = log(1/2^n * choose(n, k))
        double[] _lhet;

        // The four bases over which genotypes are enumerated.
        // `static`: a non-static immutable field with an initializer would
        // be stored per instance and make the struct non-assignable.
        static immutable Base[4] nucleotides = [Base('A'), Base('C'), Base('G'), Base('T')];
    }

    /// Fills the three tables for the given `depcorr` and `eta` parameters.
    this(double depcorr, double eta) {
        _fk.length = 256;
        _beta.length = 256 * 256 * 64;
        _lhet.length = 256 * 256;

        foreach (n, ref v; _fk) {
            v = core.stdc.math.pow(1.0 - depcorr, cast(double)n) * (1.0 - eta) + eta;
        }

        // lC[n][k] = log(choose(n, k))
        double[256][256] lC;

        // lG[n] = logGamma(n + 1)
        double[256] lG;

        for (size_t n = 0; n <= 255; ++n) {
            lG[n] = core.stdc.math.lgamma(cast(double)(n + 1));
            // choose(n, k) == choose(n, n - k), so only half of the row
            // is computed and mirrored.
            for (size_t k = 0; k <= n / 2; ++k) {
                lC[n][n-k] = lC[n][k] = lG[n] - lG[k] - lG[n-k];

                // fill _lhet simultaneously
                _lhet[n << 8 | (n-k)] = _lhet[n << 8 | k] = lC[n][k] - n * cast(double)LN2;
            }
        }

        for (size_t q = 1; q < 64; ++q) {
            // e is 10^(-q/10); le and le1 are log(e) and log(1 - e)
            real e = 10.0 ^^ (-(cast(real)q) / 10.0);
            real le = core.stdc.math.logl(e);
            real le1 = core.stdc.math.logl(1.0 - e);

            for (int n = 1; n <= 255; ++n) {
                // sum accumulates terms for k = n down to the current k;
                // sum1 is the same sum without the current term.
                real sum, sum1;
                sum = sum1 = 0.0;
                for (int k = n; k >= 0; --k) {
                    sum = sum1 + core.stdc.math.expl(lC[n][k] + k * le + (n-k) * le1);
                    _beta[q << 16 | n << 8 | k] = -10.0 / LN10 * core.stdc.math.logl(sum1 / sum);
                    sum1 = sum;
                }
            }
        }
    }

    /// Table lookup: _fk[n]. `n` must be below 256.
    double fk(size_t n) const {
        return _fk[n];
    }

    /// Table lookup: _beta[quality << 16 | n << 8 | k].
    /// The table is sized for quality < 64 and n, k < 256.
    double beta(uint quality, size_t n, size_t k) const {
        return _beta[quality << 16 | n << 8 | k];
    }

    /// Table lookup: _lhet[n << 8 | k]. `n` and `k` must be below 256.
    double lhet(size_t n, size_t k) const {
        return _lhet[n << 8 | k];
    }

    /// Map from diploid genotype to score, as returned by computeLikelihoods.
    alias TinyMap!(DiploidGenotype!Base5, float, useDefaultValue) Dict;

    // Scale factor 10 / ln(10) applied to the lhet terms.
    // `static` for the same reason as `nucleotides`.
    private static immutable C = 10.0 / LN10;

    /// Computes a score for every homozygous genotype and every unordered
    /// heterozygous pair over A, C, G, T from the given read bases.
    ///
    /// Params:
    ///     read_bases = range of ReadBase with a known length; if it holds
    ///                  more than 255 elements, a random sample of 255 is used
    ///     symmetric  = if true, each heterozygous score is stored under
    ///                  both orderings of the two bases
    ///
    /// Returns: a Dict whose stored scores are all non-negative.
    Dict computeLikelihoods(R)(R read_bases, bool symmetric=false) const
        if (is(ElementType!R == ReadBase) && hasLength!R) 
    {
        // if there're more than 255 reads, subsample them
        ReadBase[255] buf = void;
        if (read_bases.length > buf.length) {
            copy(randomSample(read_bases, buf.length), buf[]);
        } else {
            copy(read_bases, buf[]);
        }
        auto bases = buf[0 .. min(read_bases.length, $)];

        sort!"a.quality < b.quality"(bases);

        // w: bases seen so far per (base, strand); c: per base.
        // fsum / bsum: running sums of fk and fk * beta per base.
        auto w = TinyMap!(BaseWithStrand, uint, fillNoRemove)(0);
        auto c = TinyMap!(Base, uint, fillNoRemove)(0);
        auto fsum = TinyMap!(Base, double, fillNoRemove)(0.0);
        auto bsum = TinyMap!(Base, double, fillNoRemove)(0.0);

        // Walk from the highest quality to the lowest.
        foreach_reverse (ref read_base; bases) {
            // Clamp quality to [4, 63]; the beta table covers q < 64.
            auto quality = read_base.quality;
            if (quality < 4) quality = 4;
            if (quality > 63) quality = 63;
           
            auto bws = read_base.base_with_strand;
            auto b = bws.base;

            fsum[b] += fk(w[bws]);
            bsum[b] += fk(w[bws]) * beta(quality, bases.length, c[b]);
            c[b] += 1;
            w[bws] += 1;
        }

        alias diploidGenotype dG;

        // float.min_normal is the D2 name of the old floating-point `.min`
        // property; it is the map's default (entry absent) value.
        auto q = Dict(float.min_normal);

        foreach (i, b1; nucleotides) {
            // tmp1 / tmp2 / tmp3: sums of bsum / c / fsum over the bases
            // that are NOT part of the genotype being scored.
            // tmp3 is accumulated but does not enter any score below.
            float tmp1 = 0.0;
            int tmp2;
            float tmp3 = 0.0;

            // homozygous
            foreach (k, b2; nucleotides) {
                if (k != i) {
                    tmp1 += bsum[b2];
                    tmp2 += c[b2];
                    tmp3 += fsum[b2];
                }
            }

            auto b1_5 = cast(Base5)b1;
            if (tmp2 > 0) {
                q[dG(b1_5)] = tmp1;
            } else {
                // every observed base equals b1
                q[dG(b1_5)] = 0.0;
            }

            // heterozygous
            for (size_t j = i + 1; j < nucleotides.length; ++j) {
                auto b2 = nucleotides[j];
                int cij = c[b1] + c[b2];
                tmp1 = tmp3 = 0.0;
                tmp2 = 0;
                foreach (k, b3; nucleotides) {
                    if (k != i && k != j) {
                        tmp1 += bsum[b3];
                        tmp2 += c[b3];
                        tmp3 += fsum[b3];
                    }
                }

                auto b2_5 = cast(Base5)b2;
                if (tmp2 > 0) {
                    q[dG(b2_5, b1_5)] = tmp1 - C * lhet(cij, c[b2]);
                } else {
                    // all observed bases are either b1 or b2
                    q[dG(b2_5, b1_5)] = -C * lhet(cij, c[b2]);
                }

                if (symmetric) {
                    q[dG(b1_5, b2_5)] = q[dG(b2_5, b1_5)];
                }
            }

            // Clamp negative scores of genotypes starting with b1 to zero.
            foreach (k, b2; nucleotides) {
                auto g = dG(b1_5, cast(Base5)b2);
                if (g in q) {
                    if (q[g] < 0.0) q[g] = 0.0;
                }
            }
        }

        return q;
    }
}
232 
// Encapsulates information about genotype likelihoods at a site.
// Genotypes are kept ordered by ascending score, so index 0 is the
// genotype with the smallest score.
struct GenotypeLikelihoodInfo {

    alias ErrorModelCoefficients.Dict ScoreDict;

    alias DiploidGenotype!Base5 Gt;

    // Stores the dictionary and builds the score-ordered genotype list.
    this(ScoreDict dict) {

        _dict = dict;
        size_t stored = 0;

        // Insertion sort while copying: shift every already stored genotype
        // with a strictly larger score one slot up, then drop the new
        // genotype into the gap.
        foreach (gt, score; _dict) {
            size_t slot = stored;
            for (; slot > 0 && _dict[gt_buf[slot - 1]] > score; --slot) {
                gt_buf[slot] = gt_buf[slot - 1];
            }
            gt_buf[slot] = gt;
            ++stored;
        }

        assert(stored >= 2);

        _count = cast(ubyte)stored;
    }

    // Number of genotypes stored (0 for a default-constructed value).
    @property size_t count() const {
        return _count;
    }

    // A genotype paired with its score.
    static struct GtInfo {
        private {
            Gt _gt;
            float _prob;
        } 

        @property Gt genotype() const {
            return _gt;
        }

        @property float score() const {
            return _prob;
        }
    }

    // Returns the index-th genotype in ascending score order.
    GtInfo opIndex(size_t index) {
        assert(index < count);
        auto chosen = gt_buf[index];
        auto chosen_score = _dict[chosen];
        return GtInfo(chosen, chosen_score);
    }

    private Gt[25] gt_buf;
    private ubyte _count;
    private ScoreDict _dict;
}
294 
/// Holds the error model parameters and the coefficient tables built from
/// them. Through `alias this` an ErrorModel can be used wherever its
/// (const) coefficients are expected.
class ErrorModel {

    private float _depcorr;
    private float _eta;
    private ErrorModelCoefficients _coef;

    /// Remembers the parameters and builds the coefficient tables.
    this(float depcorr, float eta=0.03) {
        _depcorr = depcorr;
        _eta = eta;
        _coef = ErrorModelCoefficients(depcorr, eta);
    }

    /// Coefficient tables, returned by value as const.
    @property const(ErrorModelCoefficients) coefficients() const {
        return _coef;
    }

    alias coefficients this;
}
315 
/// Class for calling SNPs using MAQ model.
///
/// Typical usage:
///     auto caller = new MaqSnpCaller();
///     caller.minimum_call_quality = 20.0f;
///     caller.minimum_base_quality = 13;
///     foreach (snp; caller.findSNPs(pileup_columns)) { ... }
///
/// The columns passed to findSNPs must provide `reference_base`.
final class MaqSnpCaller {
    
    private float _depcorr = 0.17;
    private float _eta = 0.03;
    private float _minimum_call_quality = 6.0;
    private ubyte _minimum_base_quality = 13;
    // Set whenever depcorr or eta changes; cleared when errmod is rebuilt.
    private bool _need_to_recompute_errmod = true;

    /// `depcorr` parameter handed to the error model.
    /// Setting it makes the next `errmod` access rebuild the model.
    float depcorr() @property const {
        return _depcorr;
    }

    /// ditto
    void depcorr(float f) @property {
        _depcorr = f;
        _need_to_recompute_errmod = true;
    }

    /// `eta` parameter handed to the error model.
    /// Setting it makes the next `errmod` access rebuild the model.
    float eta() @property const {
        return _eta;
    }

    /// ditto
    void eta(float f) @property {
        _eta = f;
        _need_to_recompute_errmod = true;
    }
    
    /// Minimum call quality
    float minimum_call_quality() @property const {
        return _minimum_call_quality;
    }

    /// ditto
    void minimum_call_quality(float f) @property {
        _minimum_call_quality = f;
    }

    /// Discard reads with base quality less than this at a site
    ubyte minimum_base_quality() @property const {
        return _minimum_base_quality;
    }

    /// ditto
    void minimum_base_quality(ubyte q) @property {
        _minimum_base_quality = q;
    }

    /// Error model for the current depcorr/eta, rebuilt only when one of
    /// them has been changed since the last access.
    ErrorModel errmod() @property {
        if (_need_to_recompute_errmod) {
            synchronized {
                // re-check inside the synchronized block
                if (_need_to_recompute_errmod) {
                    _errmod = new ErrorModel(_depcorr, _eta);
                    _need_to_recompute_errmod = false;
                }
            }
        }
        return _errmod;
    }

    private ErrorModel _errmod;

    /// Get genotype likelihoods
    ///
    /// Reads whose current base quality is below minimum_base_quality, or
    /// whose current base is '-', are ignored. Each remaining read
    /// contributes its base, strand, and min(base quality, mapping quality).
    /// If no read remains, a default GenotypeLikelihoodInfo (count == 0)
    /// is returned.
    final GenotypeLikelihoodInfo genotypeLikelihoodInfo(C)(C column) {

        // With version MaqCaller8192, up to 8192 bases are first copied
        // into a stack buffer instead of being read lazily.
        version(MaqCaller8192) {
            ReadBase[8192] buf = void;
        }

        size_t num_of_valid_bases = 0;

        // First pass: count (and, with MaqCaller8192, collect) valid bases.
        foreach (read; column.reads) {

            version(MaqCaller8192) {
                if (num_of_valid_bases == 8192) break;
            }

            if (read.current_base_quality < minimum_base_quality)
                continue;
            if (read.current_base == '-')
                continue;

            version(MaqCaller8192) {
                buf[num_of_valid_bases] = ReadBase(Base(read.current_base),
                                                   min(read.current_base_quality, read.mapping_quality),
                                                   read.is_reverse_strand);
            }

            num_of_valid_bases++;
        }

        // Lazy view of the reads as ReadBase values that skips the reads
        // rejected by the same two filters as the counting loop above.
        static struct ReadBaseRange(R) {
            private R _reads = void;
            private ubyte minimum_base_quality = void;
                                              
            this(R reads, ubyte minbq) { 
                _reads = reads; minimum_base_quality = minbq; _findNextValid();
            }

            ReadBase front() @property { 
                auto read = _reads.front;
                return ReadBase(Base(read.current_base), 
                                min(read.current_base_quality, read.mapping_quality),
                                read.is_reverse_strand);
            }
            bool empty() @property { return _reads.empty; }
            void popFront() { _reads.popFront(); _findNextValid(); }
            ReadBaseRange save() @property { return ReadBaseRange!R(_reads, minimum_base_quality); }

            // Advance the underlying range to the next accepted read.
            private void _findNextValid() {
                while (!_reads.empty && 
                        (_reads.front.current_base_quality < minimum_base_quality ||
                         _reads.front.current_base == '-')) 
                {
                    _reads.popFront();
                }
            }
        }

        if (num_of_valid_bases == 0) {
            GenotypeLikelihoodInfo result;
            return result;
        }

        version(MaqCaller8192) {
            ReadBase[] rbs = buf[0 .. num_of_valid_bases];
            auto likelihood_dict = errmod.computeLikelihoods(rbs);
        } else {
            // takeExactly supplies the length that computeLikelihoods needs.
            auto rbs = ReadBaseRange!(typeof(column.reads))(column.reads, minimum_base_quality);
            auto likelihood_dict = errmod.computeLikelihoods(takeExactly(rbs, num_of_valid_bases));
        }
        return GenotypeLikelihoodInfo(likelihood_dict);
    }

    /// Make call on a pileup column
    ///
    /// Returns a null Nullable when fewer than two genotypes were scored.
    /// Otherwise the call carries the genotype with the smallest score and
    /// the difference between the second smallest and the smallest score.
    /// If `sample` is empty, it is taken from the "RG" value of the first
    /// read in the column when that value is present. The reference base is
    /// 'N' when the column type has no `reference_base`.
    final Nullable!DiploidCall5 makeCall(C)(C column, string reference="", string sample="") {

        auto gts = genotypeLikelihoodInfo(column);

        Nullable!DiploidCall5 result;

        if (gts.count < 2) return result;

        static if (__traits(compiles, column.reference_base)) {
            auto refbase = Base5(column.reference_base);
        } else {
            auto refbase = Base5('N');
        }
        
        if (sample == "") {
            auto rg = column.reads.front["RG"];
            if (!rg.is_nothing) {
                sample = cast(string)rg;
            }
        }

        result = DiploidCall5(sample, reference, column.position,
                              refbase, gts[0].genotype,
                              gts[1].score - gts[0].score);
                
        return result;
    }

    /// main method of this class
    ///
    /// Returns an input range of DiploidCall5 holding, for the given pileup
    /// columns, only the calls that are variants and whose quality is
    /// strictly greater than minimum_call_quality. `reference` and `sample`
    /// are forwarded to makeCall for every column.
    auto findSNPs(P)(P pileup_columns, string reference="", string sample="") {
        static assert(__traits(compiles, {pileup_columns.front.reference_base;}));

        static struct Result {
            private MaqSnpCaller _caller;
            private P _pileup;
            private DiploidCall5 _front;
            private bool _empty;
            private string _reference;
            private string _sample;

            this(MaqSnpCaller caller, P pileup, string reference, string sample) {
                _caller = caller;
                _pileup = pileup;
                _reference = reference;
                _sample = sample;
                _fetchNextSNP();
            }

            DiploidCall5 front() @property {
                return _front;
            }
           
            bool empty() @property {
                return _empty;
            }

            void popFront() {
                _pileup.popFront();
                _fetchNextSNP();
            }

            // Skip columns until one yields an accepted call (stored in
            // _front) or the pileup is exhausted (_empty is set).
            private void _fetchNextSNP() {
                while (true) {
                    if (_pileup.empty) {
                        _empty = true;
                        break;
                    }

                    auto call = _caller.makeCall(_pileup.front, _reference, _sample);
                    // Explicit .get: the implicit Nullable -> payload
                    // conversion is deprecated in Phobos.
                    if (!call.isNull
                        && call.get.is_variant
                        && call.get.quality > _caller.minimum_call_quality)
                    {
                        _front = call.get;
                        break;
                    } else {
                        _pileup.popFront();
                    }
                }
            }
        }

        return Result(this, pileup_columns, reference, sample);
    }
}