#!/usr/bin/env python2
#
# Copyright John Reid 2009, 2012, 2013
#

"""
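Read sequences from one or more FASTA files given on the command line,
count the occurrences of each nucleotide character in every file and print
a summary. With --pie, also save a pie chart of the composition per file.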
"""


import corebio.seq_io.fasta_io as F
import corebio.seq
import numpy
from optparse import OptionParser
import bioinfutils


def weblogo_color_to_tuple(col):
    """Return the ``(red, green, blue)`` components of a weblogolib color.

    Only the ``red``, ``green`` and ``blue`` attributes of ``col`` are read;
    their values are passed through unchanged, in that order.
    """
    channels = ('red', 'green', 'blue')
    return tuple(getattr(col, name) for name in channels)


#
# Parse the options
#
# The usage string spells out the positional arguments: each one is handed
# to bioinfutils.open_input() as a FASTA input further down the script. The
# optparse default ("%prog [options]") would not mention them at all.
option_parser = OptionParser(
    usage="%prog [options] <fasta file> [<fasta file> ...]"
)
option_parser.add_option(
    '--pie',
    action='store_true',
    default=False,
    help="Draw pie charts of the sequence compositions."
)
# options.pie is the --pie flag; args holds the remaining positional
# arguments (the FASTA inputs).
options, args = option_parser.parse_args()

for fasta in args:
    #
    # Accumulate per-symbol counts over every sequence in this input
    #
    alphabet = corebio.seq.reduced_nucleic_alphabet
    tally = numpy.zeros(len(alphabet), dtype=int)
    num_seqs = 0  # stays 0 when the input yields no sequences
    sequences = F.iterseq(bioinfutils.open_input(fasta), alphabet)
    for num_seqs, sequence in enumerate(sequences, 1):
        tally += sequence.tally()

    #
    # Percentages relative to the known bases (the first four tally
    # entries) and relative to everything that was tallied
    #
    known_counts = tally[:4]
    num_known = known_counts.sum()
    known_percentages = known_counts * 100. / num_known
    total = tally.sum()
    unknown_percentages = tally * 100. / total

    #
    # Report: one row per known base, then the N row and the GC content
    #
    print(fasta)
    print("    %8d bp in %d sequences" % (total, num_seqs))
    # Columns: symbol, count, % of all tallied symbols, % of known bases.
    # zip() stops with the shortest input, and known_percentages has at
    # most four entries, so only the known bases get a row here.
    rows = zip(alphabet, tally, unknown_percentages, known_percentages)
    for base, count, pct_of_total, pct_of_known in rows:
        print("%s : %8d : %5.1f%% : %5.1f%%" % (
            base, count, pct_of_total, pct_of_known))
    print("N : %8d : %5.1f%%" % (tally[4], unknown_percentages[4]))
    # Entries 1 and 2 are taken to be C and G, matching the A, C, G, T
    # label order used for the pie chart below.
    gc_content = known_percentages[1] + known_percentages[2]
    print("GC-content            : %5.1f%%" % gc_content)
    print("")

    if not options.pie:
        continue

    #
    # Pie chart of the known-base composition, saved as PNG and EPS
    #
    from pylab import figure, axes, pie, close
    import weblogolib.colorscheme as schemes
    scheme = schemes.nucleotide

    fig = figure(figsize=(2.64, 2.64))
    # Axes spanning the whole figure; pie() draws into the current axes.
    axes([.0, .0, 1., 1.])
    labels = 'A', 'C', 'G', 'T'
    wedge_colors = [
        weblogo_color_to_tuple(scheme.color(label)) for label in labels
    ]
    pie(
        tally[:4],
        labels=labels,
        # Renders e.g. "25.0\%"; the backslash presumably escapes the
        # percent sign for LaTeX text rendering -- confirm.
        autopct='%1.1f\\%%',
        shadow=False,
        colors=wedge_colors)
    for filename_template in ('%s-pie.png', '%s-pie.eps'):
        pie_filename = filename_template % fasta
        print('Saving pie chart to %s' % pie_filename)
        fig.savefig(pie_filename, bbox_inches='tight')
    close()
