#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2010 - 2012, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.

import sys
import argparse
import numpy
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import Oligotyping.lib.fastalib as u
from Oligotyping.utils.utils import pretty_print


def _read_sequence_lengths(source):
    """Return a list with the length of every sequence yielded by `source`.

    `source` is an already opened `SequenceSource`. '-' characters are removed
    from each sequence before it is measured. A progress line is written to
    stderr whenever `source.pos` is 1 or a multiple of 1000. The source is
    reset both before and after the pass.
    """
    lengths = []

    source.reset()

    while source.next():
        if source.pos % 1000 == 0 or source.pos == 1:
            sys.stderr.write('\r[fastalib] Reading: %s' % (source.pos))
            sys.stderr.flush()
        lengths.append(len(source.seq.replace('-', '')))

    source.reset()

    sys.stderr.write('\n')

    return lengths


def length_distribution(fasta, output = None, title = None):
    """Plot the sequence length distribution of a FASTA file.

    The figure has three panels: a one-line text summary (total, mean, std,
    min, max), a histogram of the number of sequences per length, and the
    cumulative percentage of reads by length.

    Parameters:
        fasta:  path of the FASTA file, handed to `SequenceSource`.
        output: file name prefix for the figure; defaults to the path reported
                by the sequence source. '.pdf' is appended, and '.png' is used
                instead if saving the PDF raises.
        title:  figure title; defaults to the path reported by the source.

    Raises:
        ValueError: if the file contains no sequences.

    After saving, the figure is shown with `plt.show()` on a best-effort
    basis. The sequence source is closed even when an error occurs.
    """
    source = u.SequenceSource(fasta)

    try:
        sequence_lengths = _read_sequence_lengths(source)

        if not sequence_lengths:
            # max() below would fail with a far less helpful message.
            raise ValueError("No sequences found in '%s'" % fasta)

        total_seqs = len(sequence_lengths)
        shortest = min(sequence_lengths)
        longest = max(sequence_lengths)

        # Leave ~1% of head room (10 positions for short reads) on the x axis.
        max_seq_len = longest + (int(longest / 100.0) or 10)

        seq_len_distribution = [0] * (max_seq_len + 1)
        for l in sequence_lengths:
            seq_len_distribution[l] += 1

        max_count = max(seq_len_distribution)

        # Floor division keeps the steps integral on Python 2 and 3 alike;
        # range() rejects floats.
        xtickstep = (max_seq_len // 50) or 1
        ytickstep = (max_count // 20) or 1

        plt.figure(figsize = (12, 8))
        plt.rcParams.update({'axes.linewidth' : 0.9})
        plt.rc('grid', color='0.50', linestyle='-', linewidth=0.1)

        gs = gridspec.GridSpec(20, 1)

        # ---- panel 1: summary statistics ------------------------------------

        plt.subplot(gs[1:3])
        plt.subplots_adjust(left=0.05, bottom = 0.03, top = 0.95, right = 0.98)
        plt.grid(False)
        plt.yticks([])
        plt.xticks([])
        plt.text(0.02, 0.5, 'total: %s / mean: %.2f / std: %.2f / min: %s / max: %s'
            % (pretty_print(total_seqs),
               numpy.mean(sequence_lengths), numpy.std(sequence_lengths),
               shortest,
               longest),
            va = 'center', alpha = 0.8, size = 12)

        # ---- panel 2: number of sequences per length ------------------------

        plt.subplot(gs[4:11])
        plt.grid(True)
        plt.subplots_adjust(left=0.05, bottom = 0.01, top = 0.95, right = 0.98)

        plt.plot(seq_len_distribution, color = 'black', alpha = 0.3)
        plt.fill_between(range(0, max_seq_len + 1), seq_len_distribution, y2 = 0, color = 'black', alpha = 0.30)
        plt.ylabel('number of sequences')

        y_ticks = list(range(0, max_count + 1, ytickstep))

        plt.xticks(range(xtickstep, max_seq_len + 1, xtickstep), rotation=90, size='xx-small')
        plt.yticks(y_ticks, y_ticks, size='xx-small')
        plt.xlim(xmin = 0, xmax = max_seq_len)
        plt.ylim(ymin = 0, ymax = max_count + (max_count / 20.0))

        plt.figtext(0.5, 0.96, '%s' % (title or source.fasta_file_path), weight = 'black', size = 'xx-large', ha = 'center')

        # ---- panel 3: cumulative percentage of reads ------------------------

        plt.subplot(gs[12:19])
        plt.subplots_adjust(left=0.05, bottom = 0.01, top = 0.95, right = 0.98)
        plt.grid(True)

        # Running total over lengths 0 .. max_seq_len - 1; the per-length
        # counts are already available in seq_len_distribution.
        percentages = []
        total_percentage = 0
        for count in seq_len_distribution[:max_seq_len]:
            if count:
                total_percentage += count * 100.0 / total_seqs
            percentages.append(total_percentage)

        plt.xticks(range(xtickstep, max_seq_len + 1, xtickstep), rotation=90, size='xx-small')
        plt.yticks(range(0, 101, 5),
                   ['%d%%' % y for y in range(0, 101, 5)],
                   size='xx-small')
        plt.ylabel('percent of reads')

        plt.xlim(xmin = 0, xmax = max_seq_len)
        plt.ylim(ymin = 0, ymax = 100)
        plt.plot(percentages)
        # The filled area is closed at 100% on the last x position.
        plt.fill_between(range(0, max_seq_len + 1), percentages + [100], y2 = 0, color = 'blue', alpha = 0.30)

        # ---- save and show --------------------------------------------------

        if output is None:
            output = source.fasta_file_path

        try:
            plt.savefig(output + '.pdf')
        except Exception:
            # Fall back to PNG when the PDF cannot be written.
            plt.savefig(output + '.png')

        try:
            plt.show()
        except Exception:
            # Displaying is best effort; the figure is already on disk.
            pass
    finally:
        source.close()

if __name__ == "__main__":
    # Command line entry point: plot the length distribution of one FASTA file.
    parser = argparse.ArgumentParser(description='Length Distribution of Reads in a FASTA File')
    parser.add_argument('fasta', metavar = 'FASTA',
                        help = 'FASTA formatted sequences file')
    parser.add_argument('-o', '--output', help = 'Output file name to store distribution figure', default = None)
    parser.add_argument('-t', '--title', help = 'Title for the figure', default = None)

    args = parser.parse_args()

    # Forward --output and --title; both default to None, which makes
    # length_distribution() fall back to the FASTA file path.
    length_distribution(args.fasta, output = args.output, title = args.title)
