#!/usr/bin/env python

# Copyright (C) 2011-2012 CRS4.
#
# This file is part of Seal.
#
# Seal is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Seal is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Seal.  If not, see <http://www.gnu.org/licenses/>.



import os
import sys

import fileinput
import logging
import optparse

from bl.lib.seq.aligner.bwa.bwa_aligner import BwaAligner
from bl.lib.seq.aligner.io.sam_formatter import SamFormatter
from bl.lib.tools.standard_monitor import StandardMonitor

from bl.mr.lib.hit_processor_chain_link import HitProcessorChainLink
from bl.mr.lib.filter_link import FilterLink
import bl.lib.seq.aligner.io.protobuf_mapping as protobuf_mapping

"""
An example align script
=======================

This script can be used to align read pairs with libbwa from
the command line.

The expected input line format is the following tab-separated fields::
 <id> <read 1 seq> <read 1 qualseq> <read 2 seq> <read 2 qualseq>

The script aligns each input pair against each of the given genome references
and outputs all the mappings returned by bwa.
"""

class SamHitProcessor(object):
  """Hit visitor that prints every mapped hit of a pair to stdout.

  Each mapped hit is written as one line: the hit formatted by the
  SamFormatter, a tab, and an ``XFlag:A:<flag string>`` field built
  from ``hit.flag_string()``.
  """
  def __init__(self):
    self.output_formatter = SamFormatter(strip_pe_tag=True)

  def process(self, pair):
    """Write one output line for each mapped hit in `pair`.

    `pair` is any iterable of hit objects.  Slots that are None are
    skipped, as are hits for which ``is_mapped()`` is false.
    """
    for h in pair:
      # A slot may be None (the sibling emitter in this file guards against
      # that as well), so test for it before calling is_mapped().
      if h is not None and h.is_mapped(): # output all mapped hits
        # sys.stdout.write emits exactly what the print statement did, and
        # is valid syntax on both Python 2 and Python 3.
        sys.stdout.write("%s\tXFlag:A:%s\n" % (self.output_formatter.format(h), h.flag_string()))

class SimplifiedMarkDuplicatesEmitter(HitProcessorChainLink):
  """like the one in seqal mapper, but without the Hadoop context

  For every mapped hit in a pair it builds the coordinate key and bumps the
  "emitted_coordinates" counter on the event monitor; nothing is actually
  emitted.  The pair is then handed to the next link in the chain, if any.
  """
  def __init__(self, event_monitor, next_link = None):
    # Name the class explicitly.  super(type(self), self) recurses forever as
    # soon as this class is subclassed, because type(self) is then the
    # subclass and the lookup lands back on this very __init__.
    super(SimplifiedMarkDuplicatesEmitter, self).__init__(next_link)
    self.event_monitor = event_monitor

  def process(self, pair):
    """Count one emitted coordinate per mapped hit, then forward `pair`."""
    # The serialized record and the key below are the arguments of the
    # commented-out emit call; they are computed but not used here.
    # NOTE(review): presumably kept to mirror the mapper's work -- confirm
    # before removing.
    record = protobuf_mapping.serialize_pair(pair)
    for hit in pair:
      if hit and hit.is_mapped():
        # key layout: <tid>:<untrimmed position>:<R if on reverse strand, else F>
        key = ':'.join((hit.tid, str(hit.get_untrimmed_pos()), 'R' if hit.is_on_reverse() else 'F'))
        #self.ctx.emit(key, record)
        self.event_monitor.count("emitted_coordinates", 1)

    if self.next_link:
      self.next_link.process(pair)

class HelpFormatter(optparse.IndentedHelpFormatter):
  """Help formatter that returns the parser description unmodified."""
  def format_description(self, description):
    """Return `description` followed by a newline, or "" if it is empty."""
    if not description:
      return ""
    return description + "\n"
def make_parser():
  """Build the command-line option parser for this script.

  Options defined:
    -r, --reference   string; no default is set (None when omitted)
    -n, --n-threads   int, default 1
    --enable-mmap     sets opt.enable_mmap to True (the default)
    --disable-mmap    sets opt.enable_mmap to False

  Returns the configured optparse.OptionParser.
  """
  parser = optparse.OptionParser(
    usage="%prog [OPTIONS] --reference=REFERENCE IPAIRS IPAIRS..",
    formatter=HelpFormatter(),
    )
  parser.add_option("-r", "--reference", type="str", metavar="STRING",
                    help="reference indices")
  parser.add_option("-n", "--n-threads", type="int", metavar="INT",
                    default=1,
                    help="number of threads [1]")
  parser.add_option("--enable-mmap", action="store_true",
                    dest="enable_mmap", default=True,
                    help="enable memory mapping [True]")
  # --enable-mmap alone is a no-op because the default is already True;
  # this companion flag writes False to the same destination so memory
  # mapping can actually be switched off.
  parser.add_option("--disable-mmap", action="store_false",
                    dest="enable_mmap",
                    help="disable memory mapping")
  return parser

def do_alignment(aligner):
  """Log the current batch size, align the batch, then clear it."""
  monitor = aligner.event_monitor
  message = "Aligning %d" % aligner.get_batch_size()
  monitor.log_info(message)

  aligner.run_alignment()
  aligner.clear_batch()

def _memory_report():
  """Return the lines of /proc/<pid>/status that start with "Vm".

  Returns an empty string when the status file cannot be read.
  """
  status_path = "/proc/%d/status" % os.getpid()
  try:
    with open(status_path) as status:
      return "".join(entry for entry in status if entry.startswith("Vm"))
  except (IOError, OSError):
    return ""

def main(argv=None, batch_size=30000):
  """Align the read pairs named on the command line and print the hits.

  argv: full argument vector, program name first; defaults to sys.argv.
  batch_size: number of loaded pairs that triggers an alignment run.

  Input lines are read with fileinput from the positional arguments.
  Lines starting with "#" and blank lines are skipped; every other line is
  split on tabs and loaded into the aligner.  Mapped hits are written to
  stdout by SamHitProcessor.  Exits with status 1 if no reference is given.
  """
  if argv is None:
    argv = sys.argv

  parser = make_parser()
  # Parse the argument vector we were given instead of letting optparse
  # silently fall back to sys.argv.
  opt, args = parser.parse_args(argv[1:])

  #--
  logging.basicConfig(level=logging.DEBUG)
  logger = logging.getLogger()
  #--

  if opt.reference is None:
    sys.stderr.write("Error:  you must provide a reference\n")
    # stdout carries the alignment output, so send the help text to stderr
    parser.print_help(sys.stderr)
    sys.exit(1)

  align = BwaAligner()
  align.event_monitor = StandardMonitor(logger)
  align.nthreads     = opt.n_threads
  align.mmap_enabled = opt.enable_mmap
  align.reference    = opt.reference

  #filter_link = FilterLink(align.event_monitor)
  #filter_link.set_next(SimplifiedMarkDuplicatesEmitter(align.event_monitor))
  #align.hit_visitor = filter_link
  align.hit_visitor = SamHitProcessor()

  align.event_monitor.log_info("mmap_enabled is %s", align.mmap_enabled)
  align.event_monitor.log_info("using reference at %s", align.reference)

  align.event_monitor.log_info("Reading pairs")
  for line in fileinput.input(args):
    # Skip comments and blank lines; a blank line would otherwise be loaded
    # as a single empty field.
    if line.startswith("#") or not line.strip():
      continue
    align.load_pair_record(line.rstrip("\r\n").split("\t"))
    if align.get_batch_size() >= batch_size:
      do_alignment(align)

  # flush whatever is left in the last, partially filled batch
  if align.get_batch_size() >= 1:
    do_alignment(align)

  align.event_monitor.log_info("All finished")
  for counter, value in align.event_monitor.each_counter():
    align.event_monitor.log_info("counted %s: %d", counter, value)
  align.event_monitor.log_info("memory consumption:\n%s" % _memory_report())
  align.release_resources()

# Script entry point: run main() with the process argument vector only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
  main(sys.argv)
