#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2010 - 2012, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.

import sys
import Oligotyping.lib.fastalib as u
from Oligotyping.utils.aligner import nw_align
from Oligotyping.utils.utils import Progress

# Module-level Progress instance; main() drives it through new(), update()
# and end() while the pairwise comparisons run.
progress = Progress()


def get_trim_loc(seq):
    """Return the number of consecutive gap characters ('-') ending `seq`.

    main() takes the larger of this count for the two sequences of a pair
    and trims both of them by that many characters, so that the comparison
    stops at the last column in which both sequences still have a base.

    Returns 0 for an empty sequence or one that does not end with a gap,
    and len(seq) for a sequence that consists of gaps only.
    """
    # Comparing the length before and after stripping trailing gaps counts
    # the terminal gap run directly; unlike an index loop it is well defined
    # for '' and does not under-count an all-gap sequence by one.
    return len(seq) - len(seq.rstrip('-'))


def remove_common_gaps(seq1, seq2):
    """Drop every column in which both sequences have a gap ('-').

    The positions of `seq1` are walked one by one and compared with the
    same position of `seq2`. A column is kept unless both sequences hold
    '-' there, so a gap facing a base is preserved in both outputs.

    `seq2` must be at least as long as `seq1`: characters of `seq2` beyond
    len(seq1) are ignored, and a shorter `seq2` raises IndexError.

    Returns a tuple of two strings (seq1, seq2) of equal length.
    """
    kept1, kept2 = [], []
    for pos, base1 in enumerate(seq1):
        base2 = seq2[pos]
        if base1 == '-' and base2 == '-':
            continue
        kept1.append(base1)
        kept2.append(base2)

    # collect the kept characters in lists and join once at the end instead
    # of growing two strings character by character.
    return ''.join(kept1), ''.join(kept2)


def main(input_file, output_file, align = False):
    """Write an all-vs-all percent identity matrix for a FASTA file.

    Every sequence in `input_file` is compared with every other one. For
    each pair, terminal gaps are trimmed (by the longer of the two terminal
    gap runs), columns that are gaps in both sequences are removed, and the
    percentage of matching positions over the remaining length is recorded.

    input_file  -- path of the FASTA file to read.
    output_file -- path of the TAB-delimited matrix to write. The first row
                   holds the sequence ids; each following row starts with an
                   id followed by one value per sequence, formatted with two
                   decimals (100.00 on the diagonal).
    align       -- when true, gap characters are stripped from both
                   sequences of a pair and the pair is aligned with
                   nw_align() before it is compared. Otherwise sequences are
                   compared as they appear in the file (presumably they are
                   expected to be aligned already -- confirm with callers).

    Returns None.
    """
    sequences = {}
    similarities = {}
    fasta = u.SequenceSource(input_file)

    while fasta.next():
        sequences[fasta.id] = fasta.seq
        similarities[fasta.id] = {}

    # materialize the ids as a list: they are indexed below and concatenated
    # to a list for the header row, neither of which works on the keys view
    # that Python 3 returns.
    keys = list(sequences.keys())
    num_keys = len(keys)

    progress.new('Processing sequences')
    for i in range(0, num_keys):
        progress.update('%d of %d' % (i + 1, num_keys))
        key1 = keys[i]

        if align:
            # the gap-free form of the first sequence is the same for every
            # partner, so compute it once per row.
            ungapped1 = sequences[key1].replace('-', '')
        else:
            # without re-alignment the first sequence does not change
            # either, and neither does its terminal gap length.
            p1 = get_trim_loc(sequences[key1])

        # the matrix is symmetric: only visit pairs above the diagonal.
        for j in range(i + 1, num_keys):
            key2 = keys[j]

            if align:
                seq1, seq2 = nw_align(ungapped1, sequences[key2].replace('-', ''))
                p1 = get_trim_loc(seq1)
            else:
                seq1, seq2 = sequences[key1], sequences[key2]

            p2 = get_trim_loc(seq2)

            # trim both sequences back to the last column they both cover.
            trim = max(p1, p2)
            if trim:
                seq1 = seq1[:-trim]
                seq2 = seq2[:-trim]

            if '-' in seq1 or '-' in seq2:
                # yes, this is very expensive, but common gaps must be removed
                # before assessing percent similarity...
                seq1, seq2 = remove_common_gaps(seq1, seq2)

            if seq1:
                mismatches = sum(1 for x in range(0, len(seq1)) if seq1[x] != seq2[x])
                percent_identity = 100 - mismatches * 100.0 / len(seq1)
            else:
                # nothing is left to compare (e.g. an all-gap sequence):
                # report no identity instead of dividing by zero.
                percent_identity = 0.0

            similarities[key1][key2] = similarities[key2][key1] = percent_identity

    progress.end()

    # the context manager closes the file even if a write fails.
    with open(output_file, 'w') as output:
        output.write('\t'.join([''] + keys) + '\n')
        for key1 in keys:
            line = [key1]
            for key2 in keys:
                line.append('%.2f' % (100 if key1 == key2 else similarities[key1][key2]))
            output.write('\t'.join(line) + '\n')
        

# Command-line entry point: parse the arguments and hand them to main().
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Generates a distance matrix for all sequences in a given FASTA file')
    parser.add_argument('input_file', metavar = 'FASTA',
                        help = 'FASTA file that contains -representative?- sequences')
    parser.add_argument('-o', '--output_file', metavar = 'FILE', default = None,
                        help = 'Output file to store results')
    parser.add_argument('-A', '--align', action="store_true", default = False,
                        help = 'If sequences require pairwise alignment')

    args = parser.parse_args()

    # without -o, write the matrix next to the input file.
    if not args.output_file:
        args.output_file = args.input_file + '-DIST.txt'

    # main() returns None, so a run that completes exits with status 0.
    sys.exit(main(args.input_file, args.output_file, args.align))
