#!/usr/bin/env python3.2

from __future__ import division, print_function

import argparse, sys

from json import dump as json_dump
from os.path import exists

from Bio import SeqIO, AlignIO

from BioExt import hxb2, nl4_3

from hy454 import align_to_refseq, to_positional


# Built-in reference sequences, keyed by the names offered to the user through
# the --reference option. main() calls .load() on the selected entry.
_refseqs = dict(
    ('HXB2_' + region, getattr(hxb2, region))
    for region in (
        'env', 'gag', 'int', 'nef', 'pol', 'pr', 'prrt',
        'rev', 'rt', 'tat', 'vif', 'vpr', 'vpu',
    )
)
_refseqs['NL4-3_prrt'] = nl4_3.prrt


def positional_write(msa, fh):
    """Write *msa* to *fh* as compact JSON followed by a newline.

    The alignment is first converted with ``to_positional``; the result is
    serialized with no whitespace after the item and key separators.
    """
    positional = to_positional(msa)
    compact = (',', ':')
    json_dump(positional, fh, separators=compact)
    fh.write('\n')


def main(infile=None, reference=None, outfile=None, positional=False, expected_identity=None, keep_insertions=False, discardfile=None):
    """Align the FASTA records in *infile* to a reference and write the result.

    Parameters:
        infile: handle (or path) passed to ``SeqIO.parse``; defaults to
            ``sys.stdin``.
        reference: a key of ``_refseqs`` selecting a built-in reference. When
            ``None``, the first record of *infile* is removed from the input
            and used as the reference.
        outfile: handle the alignment is written to; defaults to
            ``sys.stdout``.
        positional: when exactly ``True``, write with ``positional_write``
            instead of FASTA.
        expected_identity: forwarded to ``align_to_refseq``; ``None`` is
            treated as ``0.``.
        keep_insertions: forwarded to ``align_to_refseq``. For FASTA output it
            also selects ``SeqIO.write`` instead of ``AlignIO.write``.
        discardfile: when not ``None``, the discarded records returned by
            ``align_to_refseq`` are written here as FASTA (only if there are
            any).

    Returns:
        0 on completion.

    Raises:
        KeyError: *reference* is not a key of ``_refseqs``.
        IndexError: *reference* is ``None`` and *infile* holds no records.
    """
    if infile is None:
        infile = sys.stdin

    if outfile is None:
        outfile = sys.stdout

    if expected_identity is None:
        expected_identity = 0.

    # Pick the writer once, up front.
    if positional is True:
        write = positional_write
    elif keep_insertions:
        write = lambda msa, fh: SeqIO.write(msa, fh, 'fasta')
    else:
        write = lambda msa, fh: AlignIO.write(msa, fh, 'fasta')

    seqrecords = list(SeqIO.parse(infile, 'fasta'))

    if reference is None:
        refseq = seqrecords.pop(0)
    else:
        # Use the parameter, not the script-level ``args`` global, so that
        # main() also works when called from other code.
        refseq = _refseqs[reference].load()

    msa, discarded = align_to_refseq(
        refseq,
        seqrecords,
        expected_identity=expected_identity,
        keep_insertions=keep_insertions
    )

    write(msa, outfile)

    if discardfile is not None and len(discarded) > 0:
        SeqIO.write(discarded, discardfile, 'fasta')

    return 0



def probability(string):
    """argparse ``type=`` callable: parse *string* as a float in [0, 1].

    Returns the parsed float. Raises ``argparse.ArgumentTypeError`` when the
    text is not a valid float or the value lies outside the closed interval
    [0, 1] (NaN included).
    """
    try:
        p = float(string)
    except ValueError:
        p = None
    # The chained comparison is False for NaN, so NaN is rejected along with
    # out-of-range values; ``p < 0 or p > 1`` would have let it through.
    if p is None or not 0 <= p <= 1:
        msg = "'%s' is not a probability in [0, 1]" % string
        raise argparse.ArgumentTypeError(msg)
    return p


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='align sequences to a reference using a codon alignment algorithm, returning FASTA by default'
    )

    # --positional requires MSA form, which --insertions precludes, so the
    # two options live in a mutually exclusive group.
    group = parser.add_mutually_exclusive_group()

    parser.add_argument(
        'input',
        metavar='FASTAFILE',
        type=argparse.FileType('r'),
        help='unaligned FASTA file'
    )
    parser.add_argument(
        '-r', '--reference',
        choices=sorted(_refseqs.keys()),
        help='use a provided default reference sequence'
    )
    # member of the mutually exclusive group
    group.add_argument(
        '-p', '--positional',
        action='store_true',
        help='return JSON-formatted reference-relative codon positional format'
    )
    parser.add_argument(
        '-o', '--output',
        type=argparse.FileType('w'),
        default=sys.stdout,
        help='save alignment to OUTPUT'
    )
    parser.add_argument(
        '-i', '--identity',
        metavar='PROBABILITY',
        type=probability,
        default=0.,
        help='discard sequences that are insufficiently identical to the reference'
    )
    parser.add_argument(
        '-d', '--discard',
        metavar='DISCARDFILE',
        type=argparse.FileType('w'),
        help='discarded sequences are sent to DISCARDFILE'
    )
    # member of the mutually exclusive group
    group.add_argument(
        '-I', '--insertions',
        action='store_true',
        help='keep insertions (output is not an MSA)'
    )

    # parse_args() exits on its own for bad usage; nothing to clean up then.
    args = parser.parse_args()
    try:
        retcode = main(args.input, args.reference, args.output, args.positional, args.identity, args.insertions, args.discard)
    finally:
        # Close every handle argparse opened, but never the process's own
        # standard streams (FileType returns those for '-').
        for handle in (args.input, args.output, args.discard):
            if handle not in (None, sys.stdin, sys.stdout):
                handle.close()
    sys.exit(retcode)
