#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2010 - 2011, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.

import os
import sys
import numpy as np
import matplotlib.pyplot as plt

# y axis coordinate at which each character of an oligotype is drawn
base_pos = {
    '-': 5,
    'A': 4,
    'T': 3,
    'C': 2,
    'G': 1,
}

# marker color used for each character of an oligotype
base_colors = {
    '-': 'white',
    'A': 'red',
    'T': 'green',
    'C': 'blue',
    'G': 'yellow',
}

def _read_environment_file(environment_file_path):
    """Parse an environment file into a per-sample dictionary of oligotypes.

    Every non-blank line is expected to hold three TAB separated fields:
    oligotype, sample name and read count.

    Returns a dictionary of the form {sample: {oligo: (m, count)}}, where `m`
    is the list of `base_pos` values of the characters of the oligotype and
    `count` is the read count converted to int.

    Prints an error message and exits with status -1 if an oligotype contains
    a character that is not a key of `base_pos`. A non-blank line that does
    not have exactly three fields, or a count that is not an integer, raises
    ValueError.
    """
    samples_dict = {}

    # the context manager releases the file handle even if parsing fails
    with open(environment_file_path) as environment_file:
        for line in environment_file:
            line = line.strip()
            if not line:
                # blank lines (e.g. trailing newlines) carry no data
                continue

            oligo, sample, count = line.split("\t")

            m = []
            for base in oligo:
                if base not in base_pos:
                    print('Error: The environment file does not seem to be generated by an oligotyping analysis.')
                    sys.exit(-1)
                m.append(base_pos[base])

            samples_dict.setdefault(sample, {})[oligo] = (m, int(count))

    return samples_dict


def oligotype_network_structure(environment_file_path, output_dir = None):
    """Draw one figure per sample from an oligotyping environment file.

    For every sample found in `environment_file_path`:

      - at each position, one marker per character seen at that position is
        drawn, sized by the fraction of the sample's reads that carry it.
      - each oligotype is drawn as a line through the characters it is made
        of, with an opacity that depends on its share of the sample's reads.

    If `output_dir` is given, the figure is saved as "<sample>.png" in that
    directory (which must already exist); otherwise it is displayed with
    plt.show().

    Returns None. Exits with status -1 if the file contains an oligotype with
    an unexpected character.
    """
    samples_dict = _read_environment_file(environment_file_path)

    for sample in samples_dict:
        oligos = samples_dict[sample]
        total_reads = sum(count for _, count in oligos.values())

        # the length of the first oligotype sets the number of positions; the
        # plotting below assumes all oligotypes of a sample share that length
        N = len(next(iter(oligos)))
        ind = np.arange(N) + 1

        fig = plt.figure(figsize = (N + 2, 6))

        plt.rcParams.update({'axes.linewidth' : 0.1})
        plt.rc('grid', color='0.70', linestyle='-', linewidth=0.1)
        plt.grid(True)

        plt.subplots_adjust(hspace = 0, wspace = 0, right = 0.995, left = 0.025, top = 0.92, bottom = 0.05)

        ax = fig.add_subplot(111)
        # invisible points; presumably here to make the automatic axis limits
        # cover the area between (0, 0) and (N, 6)
        ax.plot([0], [0], visible = False)
        ax.plot([N], [0], visible = False)
        ax.plot([0], [6], visible = False)

        # markers: reads per character at each position. the positions are
        # driven by N rather than by a variable left over from another loop
        for pos in range(N):
            bases = {}
            for oligo in oligos:
                base = oligo[pos]
                bases[base] = bases.get(base, 0) + oligos[oligo][1]
            for base in bases:
                ratio = bases[base] * 1.0 / total_reads
                ax.plot([pos + 1], [base_pos[base]], 'o', c = 'white', lw = 1, ls="--", alpha = 0.75, ms = ratio * 100)
                ax.plot([pos + 1], [base_pos[base]], 'o', c = base_colors[base], lw = 1, alpha = ratio / 5, ms = ratio * 100)

        # lines: every oligotype is drawn as a stack of strokes that get
        # thinner from the first one to the last
        for oligo in oligos:
            m, count = oligos[oligo]
            ratio = count * 1.0 / total_reads
            strokes = (('black', 8, ratio / 5),
                       ('black', 6, ratio / 3),
                       ('black', 4, ratio),
                       ('black', 1, 0.1),
                       ('white', 1, ratio))
            for color, width, alpha in strokes:
                ax.plot(ind, m, c = color, solid_capstyle = "round", solid_joinstyle = "round", lw = width, alpha = alpha)

        plt.yticks(np.arange(6), ('', 'G', 'C', 'T', 'A', '--'), size = 'x-large')
        plt.title(sample + " (total reads: %s)" % total_reads)

        # ticks 0 and N + 1 are left without a label, positions 1..N get "VL <n>"
        locs = list(range(0, N + 2))
        plt.xticks(locs, [''] + ["VL " + str(x) for x in range(1, N + 1)] + [''])

        if output_dir:
            plt.savefig(os.path.join(output_dir,  sample + ".png"))
        else:
            plt.show()

        plt.clf()
        plt.close('all')


if __name__ == '__main__':
    # command line entry point: parse the arguments, make sure the output
    # directory (if one is requested) exists, and draw the figures.
    import argparse

    parser = argparse.ArgumentParser(description='Oligotype Network Structure Per Sample')
    parser.add_argument('environment_file', metavar = 'ENVIRONMENT_FILE', help = 'Environment file generated by oligotyping analysis')
    parser.add_argument('-o', '--output-dir', default = None, metavar = 'OUTPUT_DIR',
                        help = 'Directory path in which "[sample].png" files will be stored for each sample.')


    args = parser.parse_args()

    # --output-dir is optional: when it is not given there is no directory to
    # create, and the figures are shown on the screen instead of being saved
    if args.output_dir and not os.path.exists(args.output_dir):
        try:
            os.makedirs(args.output_dir)
        except OSError:
            print("Error: Attempt to create the output directory ('%s') have failed" % args.output_dir)
            sys.exit(-1)

    sys.exit(oligotype_network_structure(args.environment_file, args.output_dir))
