#!/usr/bin/env python
'''
Simple Smith-Waterman aligner
'''

import sys
import os
import swalign


def usage():
    '''
    Write the module description and the command-line help to stderr, then
    terminate the process with exit status 1 (raises SystemExit).
    '''
    # __doc__ is None when docstrings are stripped (python -OO); fall back to
    # an empty string so the help is still shown instead of a TypeError.
    sys.stderr.write(__doc__ or '')
    sys.stderr.write('''
Usage: swalign {options} ref query

Reference and query arguments can either be written on the command-line, read
from stdin, or read as FASTA format files. If there is more than one sequence
in the reference FASTA file, the query will be aligned to all reference
sequences and only the best scoring alignment will be displayed. If more than
one sequence is in a query FASTA file, each query sequence will be aligned to
the reference.

Alignments will be made in both forward and reverse directions.

Options:
  -m N         Match score (default: 2)
  -mm N        Mismatch penalty (default: 1)
  -gap N       Gap penalty (default: 1)
  -gapext N    Gap extension penalty (default: 1)
  -gapdecay N  Decay the gap extension penalty (default: 0.0)
  -wrap N      Wrap alignments when they are longer than N bases
  -global      Enable the aligner's global alignment mode
  -v           Verbose output

Example:
    ~$ swalign AAGGGGAGGACGATGCGGATGTTC AGGGAGGACGATGCGG

''')
    sys.exit(1)


def _sequence_source(arg):
    '''
    Build the sequence source for one positional argument.

    An existing path, or '-', is handed to swalign.fasta_gen; anything else
    is treated as a literal sequence and handed to swalign.seq_gen under the
    name 'cmdline'. The returned object is later called with no arguments and
    iterated as (name, sequence) pairs.
    '''
    if arg == '-' or os.path.exists(arg):
        return swalign.fasta_gen(arg)
    return swalign.seq_gen('cmdline', arg)


def _parse_args(argv):
    '''
    Parse the command-line arguments (program name already removed).

    Returns a dict with the keys 'ref', 'query', 'match', 'mismatch',
    'gap_penalty', 'gap_extension_penalty', 'gap_extension_decay', 'wrap',
    'verbose' and 'globalalign'. 'ref' and 'query' stay None when the
    corresponding positional argument was not given.

    Raises ValueError when an option value is not a valid number, or when the
    final option is missing its value.
    '''
    opts = {
        'ref': None,
        'query': None,
        'match': 2,
        'mismatch': -1,
        'gap_penalty': -1,
        'gap_extension_penalty': -1,
        'gap_extension_decay': 0.0,
        'wrap': None,
        'verbose': False,
        'globalalign': False,
    }

    def penalty(text):
        # Penalties are stored as non-positive integers whichever sign the
        # user typed ("-mm 1" and "-mm -1" both give -1).
        return -abs(int(text))

    def decay(text):
        # The decay is stored as a non-negative float.
        return abs(float(text))

    # Option flag -> (key in opts, converter for the following argument).
    valued = {
        '-m': ('match', int),
        '-mm': ('mismatch', penalty),
        '-gap': ('gap_penalty', penalty),
        '-gapext': ('gap_extension_penalty', penalty),
        '-gapdecay': ('gap_extension_decay', decay),
        '-wrap': ('wrap', int),
    }

    pending = None  # flag still waiting for its value, if any
    for arg in argv:
        if pending is not None:
            key, convert = valued[pending]
            try:
                opts[key] = convert(arg)
            except ValueError:
                raise ValueError(
                    'invalid value for %s: %r' % (pending, arg))
            pending = None
        elif arg in valued:
            pending = arg
        elif arg == '-global':
            opts['globalalign'] = True
        elif arg == '-v':
            opts['verbose'] = True
        elif not opts['ref']:
            opts['ref'] = _sequence_source(arg)
        elif not opts['query']:
            opts['query'] = _sequence_source(arg)
        # Any further positional arguments are ignored.

    if pending is not None:
        raise ValueError('option %s requires a value' % pending)

    return opts


def main(argv=None):
    '''
    Command-line entry point: align every query sequence against every
    reference sequence on both strands and dump the best-scoring alignment
    for each query.

    argv is the argument list without the program name; it defaults to
    sys.argv[1:]. Invalid or missing arguments end in usage(), which exits
    with status 1.
    '''
    if argv is None:
        argv = sys.argv[1:]

    try:
        opts = _parse_args(argv)
    except ValueError as err:
        sys.stderr.write('ERROR: %s\n' % err)
        usage()  # exits the process
        return

    if not opts['ref'] or not opts['query']:
        usage()  # exits the process
        return

    sw = swalign.LocalAlignment(
        swalign.NucleotideScoringMatrix(opts['match'], opts['mismatch']),
        opts['gap_penalty'], opts['gap_extension_penalty'],
        gap_extension_decay=opts['gap_extension_decay'],
        verbose=opts['verbose'], globalalign=opts['globalalign'])

    for q_name, q_seq in opts['query']():
        best = None
        for r_name, r_seq in opts['ref']():
            # Forward strand first, then the reverse complement; the strict
            # ">" below keeps the earlier alignment when scores tie.
            forward = sw.align(r_seq, q_seq, ref_name=r_name,
                               query_name=q_name)
            reverse = sw.align(r_seq, swalign.revcomp(q_seq),
                               ref_name=r_name, query_name=q_name, rc=True)
            for aln in (forward, reverse):
                if best is None or aln.score > best.score:
                    best = aln

        if best is None:
            # The reference produced no sequences, so there is nothing to
            # dump for this query.
            sys.stderr.write(
                'No reference sequences to align %s against\n' % q_name)
            continue

        best.dump(opts['wrap'])
        # Blank separator line; written this way so the file parses under
        # both Python 2 and Python 3.
        sys.stdout.write('\n')


if __name__ == '__main__':
    main()
