#!/usr/bin/env python

# Copyright (C) 2011-2012 CRS4.
#
# This file is part of Seal.
#
# Seal is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Seal is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Seal.  If not, see <http://www.gnu.org/licenses/>.

"""
SNP Realignment tool
====================

Given a tsv file with at least the following columns::

 vid        mask
 V009090320 GGATACATTTTATTGC[A/G]CTTGCAGAGTATTTTT
 V009090320 GGATACATTTTATTGC[A/G]CTTGCAGAGTATTTTT
 ...

and a reference genome (pre-indexed for bwa use), it will generate a
new tsv file with the following columns::

 vid        ref_genome chromosome pos    strand allele copies
 V902439092 hg18       1          493589 +      A      2
 V902439092 hg18       1          493594 +      B      2
 ...

Where strand is the alignment strand, while pos is the position of the
SNP. If there are multiple hits, there will be multiple rows, each of
which will have the copies column set to the number of hits.

Note that we are not considering secondary alignments: the number of
multiple hits can't be greater than 2, with 2 meaning that both
variants scored a perfect unambiguous hit.

In the case where neither variant can be aligned, a row will be output
with the chromosome, position and number of copies set to 0.
"""

# The current implementation uses a strategy based on expressing
# flanking as a pair of reads. An alternative would be to use direct
# searches for the two possible flanks + allele sequences.

import logging, sys, os, optparse, re, csv
import itertools as it

from bl.lib.seq.aligner.bwa.bwa_aligner import BwaAligner
from bl.lib.tools.standard_monitor import StandardMonitor



def id_mux(vid, allele):
  """
  Build a read name from a variant id and an allele label.

  The result is ``vid`` followed by '-' and the string form of
  ``allele``, e.g. ('V1', 'A') -> 'V1-A'.
  """
  suffix = '-%s' % allele
  return vid + suffix


def id_demux(ids):
  """
  Split a read name of the form '<vid>-<allele>' into its two parts.

  The split is done on the LAST '-' only, so a vid that itself contains
  hyphens (e.g. 'rs-123-A' -> ('rs-123', 'A')) is returned intact.

  Returns a (vid, allele) tuple. Raises ValueError if ``ids`` contains
  no '-' at all.
  """
  vid, allele = ids.rsplit('-', 1)
  return vid, allele


class SnpHitProcessor(object):
  """
  Collect perfect, unambiguous hits per variant id and write them as
  tab-separated rows (vid, ref_genome, chromosome, pos, strand, allele,
  copies) to ``outf``.

  Hits are grouped by the id encoded in the read name. The rows of a
  group are written when a hit with a different id arrives, so the
  caller has to invoke ``process_current_hits()`` once more after the
  last pair in order to flush the final group.
  """

  def __init__(self, ref_genome_tag, outf=sys.stdout, logger=None):
    """
    ref_genome_tag: value written in the ref_genome column of every row.
    outf: writable file-like object that receives the rows.
    logger: logger to use; if None, the "SnpHitProcessor" logger is set
      to DEBUG level and used.
    """
    self.ref_genome_tag = ref_genome_tag
    self.outf = outf
    if logger is None:
      logger = logging.getLogger("SnpHitProcessor")
      logger.setLevel(logging.DEBUG)
      # getLogger() returns the same object on every call: attach a
      # handler only once, otherwise each new instance would make every
      # message be emitted one more time.
      if not logger.handlers:
        logger.addHandler(logging.StreamHandler())  # writes to stderr
    self.logger = logger
    self.current_id = None
    self.current_hits = []

  def process(self, pair):
    """
    Process a hit pair, looking for a perfect (edit distance, i.e., NM
    tag value == 0) and unambiguous (mapping quality > 0) hit.

    Only the first hit of the pair is examined. When the id decoded
    from its name differs from the current one, the rows collected for
    the previous id are written out before the new group is started.
    """
    l_hit, _r_hit = pair
    name = l_hit.get_name()
    id_, allele = id_demux(name)
    if id_ != self.current_id:
      if self.current_id is not None:
        self.process_current_hits()
      self.current_id = id_
      self.current_hits = []
    nm = l_hit.tag_value('NM')
    seq = l_hit.get_seq_5()
    if nm <= 0 and l_hit.qual > 0:
      # Offset of the middle base of the sequence. Floor division keeps
      # the position an integer on Python 3 as well (plain '/' would
      # give a float there and 'pos' would be printed as e.g. 493589.0).
      # NOTE(review): assumes the SNP is the middle base, i.e. both
      # flanks have the same length -- confirm against the input data.
      snp_pos = l_hit.pos + (len(seq) - 1) // 2
      chromosome = l_hit.tid or '*'
      strand = '-' if l_hit.is_on_reverse() else '+'
      self.current_hits.append(
        [id_, self.ref_genome_tag, chromosome, str(snp_pos), strand, allele]
        )
    else:
      self.logger.info("%r: NM:%d; qual:%d", name, nm, l_hit.qual)

  def process_current_hits(self):
    """
    Write one row per hit collected for the current id, appending the
    number of hits as the copies column.

    If no hit was collected, a single placeholder row with chromosome,
    position and copies set to 0 is written instead. A hit count other
    than 1 is reported through the logger at error level. Nothing is
    written when no id has been seen yet.
    """
    if self.current_id is None:
      # No pair has been processed (e.g. empty input): there is no id
      # to report, and joining a row holding None would raise TypeError.
      return
    nh = len(self.current_hits)
    if nh != 1:
      self.logger.error("hit count for %s: %d != 1", self.current_id, nh)
    rows = self.current_hits
    if nh == 0:
      rows = [[self.current_id, self.ref_genome_tag, '0', '0', '+', 'A']]
    copies = str(nh)
    for hit in rows:
      assert hit[0] == self.current_id
      # Build the output row without touching the stored hit, so that a
      # repeated call cannot append the copies column twice.
      self.outf.write("\t".join(hit + [copies]) + "\n")

  def close_open_handles(self):
    """Close the output stream."""
    self.outf.close()


def get_aligner(opt):
  """
  Build a BwaAligner configured from the parsed command line options.

  Opens opt.output_file for writing, writes the tsv header line to it
  and installs a SnpHitProcessor writing to that file as the aligner's
  hit visitor. The reference genome tag is the base name of
  opt.reference with its extension removed. The root logger is used
  both for the event monitor and for the hit processor.
  """
  ref_basename = os.path.basename(opt.reference)
  genome_tag, _ext = os.path.splitext(ref_basename)

  out_stream = open(opt.output_file, 'w')
  columns = ("vid", "ref_genome", "chromosome", "pos", "strand", "allele",
             "copies")
  out_stream.write("\t".join(columns) + "\n")

  root_logger = logging.getLogger()
  aligner = BwaAligner()
  aligner.event_monitor = StandardMonitor(root_logger)
  aligner.nthreads = opt.n_threads
  aligner.mmap_enabled = opt.enable_mmap
  aligner.reference = opt.reference
  aligner.hit_visitor = SnpHitProcessor(
    genome_tag, outf=out_stream, logger=root_logger
    )

  monitor = aligner.event_monitor
  monitor.log_info("mmap_enabled is %s", aligner.mmap_enabled)
  monitor.log_info("using reference at %s", aligner.reference)
  monitor.log_info("Reading pairs")
  return aligner


def make_pairs(vid, lflank, rflank, alleles):
  """
  Build the two pair records (allele 'A' and allele 'B') for a SNP.

  For each of the first two entries of ``alleles`` the sequence
  lflank + allele + rflank is assembled, and a 5-tuple
  (name, seq, qual, seq, qual) is produced, where name comes from
  id_mux() and qual is a string of 'E' as long as the sequence.

  Returns a list with the 'A' record followed by the 'B' record.
  """
  labelled = (('A', alleles[0]), ('B', alleles[1]))
  records = []
  for label, base in labelled:
    seq = lflank + base + rflank
    qual = 'E' * len(seq)
    records.append((id_mux(vid, label), seq, qual, seq, qual))
  return records


# Mask layout, e.g. GGATAC[A/G]CTTGCA: left flank, the two alleles
# between brackets, right flank. Written as a raw string so that \[ and
# \] reach the regex engine as they are (in a plain string they are
# invalid escape sequences), and compiled once rather than per row.
_MASK_RE = re.compile(r'^([ACGT]+)\[([ACGT])/([ACGT])\]([ACGT]+)$',
                      flags=re.IGNORECASE)


def load_batch(aligner, batch_size, i_tsv):
  """
  Load up to ``batch_size`` rows from ``i_tsv`` into ``aligner``.

  aligner: object providing load_pair_record(); it receives the two
    records returned by make_pairs() for every row that can be parsed.
  batch_size: maximum number of rows consumed from ``i_tsv``.
  i_tsv: iterator of mappings with at least the 'vid' and 'mask' keys.

  Rows whose mask does not match the expected layout are reported
  through the root logger at error level and skipped; they still count
  towards ``batch_size``.
  """
  logger = logging.getLogger()
  for r in it.islice(i_tsv, batch_size):
    m = _MASK_RE.match(r['mask'])
    if m:
      lflank, allele_a, allele_b, rflank = m.groups()
      for p in make_pairs(r['vid'], lflank, rflank, (allele_a, allele_b)):
        aligner.load_pair_record(p)
    else:
      logger.error('cannot parse %s', r['vid'])


def process(opt):
  """
  Run the whole realignment: read opt.input_file as a tab-separated
  file with a header row, feed it to the aligner in batches of
  opt.batch_size rows and write the results to opt.output_file.

  The input file is opened before the aligner is created, so a missing
  input does not create or truncate the output file. The output handle
  is closed even when loading or alignment raises.
  """
  with open(opt.input_file) as f:
    i_tsv = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
    batch_size = opt.batch_size
    aligner = get_aligner(opt)
    try:
      load_batch(aligner, batch_size, i_tsv)
      # A batch with no loaded pairs ends the loop.
      while aligner.get_batch_size() > 0:
        aligner.run_alignment()
        aligner.clear_batch()
        load_batch(aligner, batch_size, i_tsv)
      aligner.hit_visitor.process_current_hits()  # last pair
    finally:
      # Do not leak the output file if anything above fails.
      aligner.hit_visitor.close_open_handles()


class HelpFormatter(optparse.IndentedHelpFormatter):
  """
  Help formatter that returns the description unchanged, followed by a
  newline, in place of the inherited description formatting.
  """

  def format_description(self, description):
    """Return ``description`` plus a newline, or '' if it is empty/None."""
    if not description:
      return ""
    return description + "\n"


def make_parser():
  """
  Build the command line parser.

  Options: -r/--reference, -n/--n-threads (default 1), -b/--batch-size
  (default 40000), --enable-mmap (flag, default False), -o/--output-file
  and -i/--input-file.
  """
  # (flags, keyword arguments) for each option, in help-output order.
  option_specs = (
    (("-r", "--reference"),
     dict(type="str", metavar="STRING", help="reference indices")),
    (("-n", "--n-threads"),
     dict(type="int", metavar="INT", default=1,
          help="number of threads [1]")),
    (("-b", "--batch-size"),
     dict(type="int", metavar="INT", default=40000,
          help="how many rows will be read in one batch [40000]")),
    (("--enable-mmap",),
     dict(action="store_true", default=False,
          help="enable memory mapping [False]")),
    (("-o", "--output-file"),
     dict(type="str", metavar="FILE", help="output tsv file")),
    (("-i", "--input-file"),
     dict(type="str", metavar="FILE", help="input tsv file")),
    )
  parser = optparse.OptionParser(
    usage="%prog [OPTIONS] --reference=REFERENCE  -i INTSV -o OUTTSV",
    formatter=HelpFormatter(),
    )
  for flags, settings in option_specs:
    parser.add_option(*flags, **settings)
  return parser


def main():
  """
  Script entry point: parse the command line and run the realignment.

  Prints the help text and exits with status 2 unless the reference,
  the input file and the output file have all been given. Logging is
  configured at DEBUG level before processing starts.
  """
  parser = make_parser()
  opt, _positional = parser.parse_args()
  mandatory = (opt.reference, opt.input_file, opt.output_file)
  if not all(mandatory):
    parser.print_help()
    sys.exit(2)
  logging.basicConfig(level=logging.DEBUG)
  process(opt)


# Run the command line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()


# Local Variables: **
# mode: python **
# End: **
