#! /usr/bin/env python2.7
#-*- encoding: utf-8 -*-
#   Copyright (C) 2011-2013 by
#   Nicholas Mancuso <nick.mancuso@gmail.com>
#   All rights reserved.
#   BSD license.

import argparse
import itertools as itools
import logging
import os
import subprocess as sub
import sys
import bioa
import pysam

from multiprocessing import cpu_count
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# Open modes, presumably intended for pysam.Samfile: "rb" reads BAM, "r" reads SAM.
# NOTE(review): main() currently passes "rb" literally instead of BAM_FILE.
BAM_FILE, SAM_FILE = "rb", "r"

# Names of the environment variables main() reads to locate the directories
# containing InDelFixer.jar and KGEM.jar.
INDEL_PATH, KGEM_PATH = "INDELFIXER_PATH", "KGEM_PATH"


class readable_dir(argparse.Action):
    """ Make a readable directory class action for the arg parser.

    Stores the supplied value on the namespace when it names an existing
    directory that the current process may read; otherwise the parser
    reports a usage error (message on stderr, exit status 2).
    """
    def __call__(self, parser, namespace, values, option_string=None):
        prospective_dir = values
        # ArgumentError (unlike ArgumentTypeError) is what parse_args turns
        # into a usage error when raised from an Action; ArgumentTypeError is
        # only handled around `type=` converters and would escape as a
        # traceback here.
        if not os.path.isdir(prospective_dir):
            raise argparse.ArgumentError(
                    self, "readable_dir:{0} is not a valid path".format(prospective_dir))
        if not os.access(prospective_dir, os.R_OK):
            raise argparse.ArgumentError(
                    self, "readable_dir:{0} is not a readable dir".format(prospective_dir))
        setattr(namespace, self.dest, prospective_dir)


def seq_func(idx, read):
    """ Wrap a raw read sequence in a SeqRecord whose id is ``Read<idx>``.

    Used when writing reads back out as FASTA.
    """
    record_id = "Read{0}".format(idx)
    return SeqRecord(Seq(read), id=record_id)


def call_indelfixer(tech, read_file_path, ref_file_path, refine, output_dir, indel_path):
    """ Align reads against a reference by running InDelFixer.jar under java.

    tech            -- sequencing technology name, passed as ``-<tech>``;
                       a falsy value means no technology flag is passed.
    read_file_path  -- path of the reads to align (``-i``).
    ref_file_path   -- path of the reference sequence (``-g``).
    refine          -- value for ``-refine``; the option is only passed when > 0.
    output_dir      -- directory handed to InDelFixer via ``-o``.
    indel_path      -- directory containing InDelFixer.jar; concatenated
                       directly, so it is expected to end with os.sep.

    Raises subprocess.CalledProcessError if java exits with a non-zero status
    (and OSError if java cannot be started).
    """
    # Only emit the technology flag when one was given; previously an empty
    # string was inserted into argv, handing java a stray blank argument.
    tech_args = ["-{0}".format(tech)] if tech else []
    refine_args = ["-refine", str(refine)] if refine > 0 else []
    indelargs = (["-jar", "{0}InDelFixer.jar".format(indel_path),
                  "-i", read_file_path,
                  "-g", ref_file_path]
                 + tech_args + refine_args + ["-o", output_dir])
    sub.check_call(["java"] + indelargs)


def call_samtools(ref_file_path, output_dir):
    """ Convert ``reads.sam`` in output_dir to BAM, then sort it.

    Runs ``samtools view`` to write ``reads.bam`` and then ``samtools sort``
    with the output prefix ``reads_sorted``, all inside output_dir (which is
    concatenated directly, so it is expected to end with os.sep).

    Raises subprocess.CalledProcessError if either samtools step fails; the
    sort step is not attempted if the view step fails.
    """
    sam_in = "{0}reads.sam".format(output_dir)
    bam_out = "{0}reads.bam".format(output_dir)
    sorted_prefix = "{0}reads_sorted".format(output_dir)
    # NOTE(review): "sort <in.bam> <out_prefix>" is the legacy samtools sort
    # syntax; newer samtools releases expect "sort -o <out.bam> <in.bam>".
    # Confirm which samtools version this pipeline targets.
    steps = (
        ["view", "-bt", ref_file_path, sam_in, "-o", bam_out],
        ["sort", bam_out, sorted_prefix],
    )
    for step_args in steps:
        sub.check_call(["samtools"] + step_args)


def call_kgem(out_amp_path, in_amp_path, kgem_path):
    """ Run KGEM.jar under java to error-correct one amplicon's reads.

    Despite the parameter names, ``out_amp_path`` is the file handed to kGEM
    as its input (main passes an amplicon's aligned ``reads.sam``), and
    ``in_amp_path`` is where kGEM is told to write its result (``-o``).
    ``kgem_path`` is the directory containing KGEM.jar; it is concatenated
    directly, so it is expected to end with os.sep.

    Raises subprocess.CalledProcessError if java exits with a non-zero status.
    """
    jar_path = "{0}KGEM.jar".format(kgem_path)
    # NOTE(review): "-tr 2" is hard-coded; its meaning is kGEM-specific.
    command = ["java", "-jar", jar_path, "-tr", "2", out_amp_path, "-o", in_amp_path]
    sub.check_call(command)


def parse_args(args):
    """ Build the command-line parser and parse ``args`` with it.

    Returns the resulting argparse.Namespace. ``read_file``, ``ref_file`` and
    ``primers`` arrive as file objects already opened for reading.
    """
    parser = argparse.ArgumentParser(
        description="Infer viral quasispecies from amplicon sequence data.")

    # Positional inputs, opened for reading by argparse.
    parser.add_argument("read_file", type=argparse.FileType("r"),
                        help="Path to the FASTA file.")
    parser.add_argument("ref_file", type=argparse.FileType("r"),
                        help="Reference sequence file.")
    parser.add_argument("primers", type=argparse.FileType("r"),
                        help="FASTA file containing primers for each amplicon."
                             " Order should be in amplicon coverage order:"
                             " forward primer then backward primer per amplicon.")

    # Optional settings.
    parser.add_argument("-s", "--seq", default="",
                        choices=["454", "illumina", "pacbio"],
                        help="Technology used to produce the read data."
                             " Leave blank if tech used is not one of these three."
                             " Default is blank.")
    parser.add_argument("-k", "--flows", type=int, default=0,
                        help="Number of commodities to use for MCF formulation."
                             " Default is 0, which indicates to use Max Bandwidth heuristic."
                             " Note: k > 0 requires CPLEX and CPLEX python module to be installed.")
    parser.add_argument("-m", "--num_mismatch", type=int, default=0,
                        help="Number of allowed mismatches in read overlap."
                             " Default is 0.")
    parser.add_argument("-n", "--num_threads", type=int, default=cpu_count(),
                        help="Number of threads for CPLEX in MCF formulation."
                             " Default is total number of CPUs.")
    parser.add_argument("-t", "--timeout", type=int, default=600,
                        help="Timeout in seconds until CPLEX must return an answer."
                             " This is used only when k > 0. Default is 600 seconds.")
    parser.add_argument("-o", "--output_dir", action=readable_dir,
                        default=os.getcwd(),
                        help="Directory to output the results. Default is current directory.")

    return parser.parse_args(args)


def main(args):
    """ Run the amplicon-based viral quasispecies inference pipeline.

    Steps, in order:
      1. align the reads against the reference with InDelFixer;
      2. convert the alignment to a sorted BAM file with samtools;
      3. load the aligned reads and partition them into amplicons using the
         primer pairs located in the consensus sequence;
      4. realign each amplicon with InDelFixer and error-correct it with kGEM;
      5. realign all corrected reads, rebuild the amplicons and the read graph,
         and reconstruct quasispecies with Max Bandwidth (flows == 0) or the
         CPLEX min unsplittable flows formulation (flows > 0);
      6. write the variants and their frequencies to quasispecies.fa.

    args -- command-line arguments, excluding the program name.

    Returns 0 on success and 1 on failure. Failures are written to stderr and
    to vira.log in the output directory.
    """
    # Grab and parse the arguments.
    args = parse_args(args)

    # Do this once here for usability/readability.
    output_dir = "{0}{1}".format(args.output_dir, os.sep)

    # Set up the log.
    log_name = output_dir + "vira.log"
    logging.basicConfig(filename=log_name, level=logging.DEBUG,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Reject a negative flow number before doing any real work.
    if args.flows < 0:
        log_str = "Flow number cannot be negative!"
        logging.error(log_str)
        sys.stderr.write(log_str + os.linesep)
        return 1

    # Try to get paths for other tools; default to the current directory.
    indel_path = os.environ.get(INDEL_PATH, os.getcwd()) + os.sep
    kgem_path = os.environ.get(KGEM_PATH, os.getcwd()) + os.sep

    # Parse primers.
    primers = list(SeqIO.parse(args.primers, "fasta"))

    # Start alignment.
    logging.info("Beginning alignment of reads.")
    try:
        call_indelfixer(args.seq, args.read_file.name, args.ref_file.name,
                1, output_dir, indel_path)
    except (sub.CalledProcessError, OSError):
        # OSError: java itself could not be started (e.g. not on PATH).
        log_str = "Cannot call InDelFixer!"
        sys.stderr.write(log_str + os.linesep)
        logging.error(log_str)
        return 1
    logging.info("Finished alignment of reads.")

    # Convert to BAM and sort aligned reads.
    logging.info("Converting aligned output to sorted BAM file.")
    try:
        call_samtools(args.ref_file.name, output_dir)
    except (sub.CalledProcessError, OSError):
        # OSError: samtools itself could not be started (e.g. not on PATH).
        log_str = "Cannot call samtools!"
        sys.stderr.write(log_str + os.linesep)
        logging.error(log_str)
        return 1
    logging.info("Finished output conversion.")

    # The user reference must contain at least one record. Only the check
    # matters here: downstream steps use the consensus sequence instead.
    refs = [seq.seq for seq in SeqIO.parse(args.ref_file, "fasta")]
    if not refs:
        log_str = "Reference file {0} must contain a reference!"
        log_str = log_str.format(args.ref_file.name)
        sys.stderr.write(log_str + os.linesep)
        logging.error(log_str)
        return 1

    try:
        # Load new sorted BAM file.
        # NOTE(review): the Samfile handles opened here and below are never
        # closed; confirm whether bioa reads them lazily before adding close().
        sam_path = "{0}reads_sorted.bam".format(output_dir)
        sam_file = pysam.Samfile(sam_path, "rb")

        # Try to do some naive indel adjustment.
        logging.info("Beginning read normalization process on {0}.".format(sam_path))
        # NOTE(review): consensus.fasta is presumably written into output_dir
        # by InDelFixer; its first record replaces the user reference from
        # here on.
        new_consensus = "{0}consensus.fasta".format(output_dir)
        reference = list(SeqIO.parse(new_consensus, "fasta"))[0]
        reads = bioa.reads_from_sam_quick(sam_file)
        logging.info("Finished read normalization process on {0}.".format(sam_path))

        # Locate the primers in the reference and partition reads by amplicon.
        logging.info("Using primers to find amplicon boundaries.")
        pairs = bioa.find_amplicons_by_primers(reference, primers)
        amplicons = bioa.ReadBuckets(reads, pairs)

        if not amplicons:
            raise Exception("Estimation of amplicons failed! Are primers correct?")

        logging.info("Finished read partitioning process.")

        # Write out amplicons for error correction.
        logging.info("Correcting reads in each amplicon.")
        counter = itools.count()
        for idx, amplicon in enumerate(amplicons):
            # Write this amplicon's reads into its own directory.
            local_path = "{0}amplicon{1}{2}".format(output_dir, idx, os.sep)
            if not os.path.exists(local_path):
                os.mkdir(local_path)
            uncor_amp_path = "{0}uncorrected_reads.fa".format(local_path)
            with open(uncor_amp_path, "w") as amp_file:
                for read in amplicon:
                    amp_file.write(seq_func(next(counter), read).format("fasta"))

            # Realign and correct with kGEM.
            call_indelfixer(args.seq, uncor_amp_path, new_consensus,
                1, local_path, indel_path)
            aligned_amp_path = "{0}reads.sam".format(local_path)
            cor_amp_path = "{0}corrected_reads.fa".format(local_path)
            call_kgem(aligned_amp_path, cor_amp_path, kgem_path)

            # Replace this amplicon's reads, in place, with the corrected ones.
            reads = SeqIO.parse(cor_amp_path, "fasta")
            amplicons[idx][:] = map(lambda x: str(x.seq), reads)

        # Gather every corrected read into one FASTA for a final realignment.
        all_corrected = "{0}all_corrected.fa".format(output_dir)
        counter = itools.count()
        with open(all_corrected, "w") as amp_file:
            for amplicon in amplicons:
                for read in amplicon:
                    amp_file.write(seq_func(next(counter), read).format("fasta"))

        call_indelfixer(args.seq, all_corrected, new_consensus,
                1, output_dir, indel_path)
        call_samtools(new_consensus, output_dir)

        # Load new sorted BAM file.
        sam_path = "{0}reads_sorted.bam".format(output_dir)
        sam_file = pysam.Samfile(sam_path, "rb")

        # Try to do some naive indel adjustment.
        logging.info("Beginning final read normalization process on {0}.".format(sam_path))
        align_start, align_stop, reads, reference = bioa.reads_from_sam(sam_file, reference)
        logging.info("Finished final read normalization process on {0}.".format(sam_path))

        # Rebuild amplicons
        amplicons = bioa.Amplicons(reads, pairs, align_start, align_stop)
        logging.info("Finished correcting amplicon reads.")

        # Need to infer quasispecies and frequencies.
        logging.info("Beginning read-graph construction.")
        graph = bioa.DecliquedAmpliconReadGraph(amplicons, args.num_mismatch,
                    nthreads=args.num_threads)
        logging.info("Finished read-graph construction.")
        logging.info("Beginning quasispecies reconstruction and " +
            "frequency estimation.")

        # Perform Max Bandwidth if flows == 0, else the flow formulation.
        if args.flows == 0:
            qs = bioa.max_bandwidth_strategy(graph)
        else:
            qs = bioa.min_unsplittable_flows_resolution(graph, args.flows,
                        nthreads=args.num_threads, timeout=args.timeout)
        logging.info("Finished quasispecies reconstruction and " +
                "frequency estimation.")

        # Output results: one record per variant, its frequency in the description.
        counter = itools.count()
        output_pop = "{0}quasispecies.fa".format(output_dir)
        map_func = lambda variant: SeqRecord(Seq(variant),
                                    id="Variant{0}".format(next(counter)),
                                    description="Frequency: {0}".format(qs[variant]))
        SeqIO.write(itools.imap(map_func, qs), output_pop, "fasta")
    except Exception as exp:
        # Top-level boundary: report and fail. str(exp) is used because
        # BaseException.message is empty for multi-argument exceptions such as
        # CalledProcessError and OSError (and is absent on Python 3).
        err_msg = str(exp) or exp.__class__.__name__
        sys.stderr.write(err_msg + os.linesep)
        logging.exception(err_msg)
        return 1

    # We're all done here!
    return 0


if __name__ == "__main__":
    # Use main()'s integer status (0 on success, 1 on failure) as the exit code.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
