#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2010 - 2012, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.

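#
# usage: python $me alignment.fa quals_dict
#
# quals_dict is a pickled dictionary of per-read quality scores generated by
# gen_dicts_for_qual_stats.py. Per-position quality statistics are computed by
# this script from the reads whose ids also appear in alignment.fa, so no
# separate qual_stats_dict argument is needed.
#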

import sys
import cPickle
from scipy import log2 as log

import matplotlib.pyplot as plt
import Oligotyping.lib.fastalib as u

from Oligotyping.utils.utils import get_qual_stats_dict
from Oligotyping.utils.random_colors import get_list_of_colors

# Fixed colour per nucleotide symbol ('N' is drawn in white).
# NOTE(review): not referenced anywhere else in this script -- candidate for removal.
COLORS = dict(A='red',
              T='blue',
              C='green',
              G='purple',
              N='white')

# Read the alignment and the pickled per-read quality scores, then keep only
# the quality records of reads that actually appear in the alignment; the
# per-position statistics are computed from that subset.
alignment = u.SequenceSource(sys.argv[1])

# Pickles must be read in binary mode; the with-block also closes the handle.
# NOTE(review): unpickling can execute arbitrary code -- only load pickles you
# generated yourself (e.g. with gen_dicts_for_qual_stats.py).
with open(sys.argv[2], 'rb') as quals_file:
    quals_dict = cPickle.load(quals_file)

# Collect every read id in the alignment straight into a set for O(1) lookups.
ids_in_alignment_file = set()
while alignment.next():
    ids_in_alignment_file.add(alignment.id)

quals_dict_filtered = {read_id: quals
                       for read_id, quals in quals_dict.items()
                       if read_id in ids_in_alignment_file}

# Without any overlap there is nothing to summarise or plot; fail with a clear
# message instead of an opaque error further down.
if not quals_dict_filtered:
    sys.exit("None of the reads in '%s' have quality scores in '%s'." % (sys.argv[1], sys.argv[2]))

qual_stats_dict = get_qual_stats_dict(quals_dict_filtered)

# Colour ramp indexed by the rounded mean quality of a column: 21 colours from
# the RdYlGn colormap, left-padded with 20 copies of its first colour so that
# indices 0..40 are all valid.
colors = get_list_of_colors(21, colormap="RdYlGn")
colors = [colors[0] for _ in range(0, 20)] + colors
max_count = max([qual_stats_dict[q]['count'] for q in qual_stats_dict if qual_stats_dict[q]])
log_max_count = log(max_count)

# All reads are expected to share the alignment length; take it from any one.
alignment_length = len(next(iter(quals_dict.values())))

plt.figure(figsize = (25, 8))
plt.rc('grid', color='0.50', linestyle='-', linewidth=0.1)
plt.grid(True)

plt.subplots_adjust(left=0.02, bottom = 0.09, top = 0.95, right = 0.98)

for position in range(0, alignment_length):
    print(position)

    if not qual_stats_dict[position]:
        continue

    # Use the same population the statistics were computed from: only reads
    # present in the alignment. Falsy entries (no quality at this column) are skipped.
    column_q_list = [quals[position] for quals in quals_dict_filtered.values() if quals[position]]

    if not column_q_list:
        continue

    b = plt.boxplot(column_q_list, positions=[position + 0.5], sym=',', widths=0.9)
    plt.setp(b['medians'], color='red')
    plt.setp(b['whiskers'], color='black', alpha=0.6)
    plt.setp(b['boxes'], color='black', alpha=0.8)
    plt.setp(b['caps'], color='black', alpha=0.6)
    plt.setp(b['fliers'], color='#EEEEEE', alpha=0.01)

    # Clamp so quality means outside 0..40 neither crash nor wrap around.
    mean = int(round(qual_stats_dict[position]['mean']))
    column_color = colors[min(max(mean, 0), len(colors) - 1)]

    # Background shade scales with log(count) relative to the busiest column.
    # When max_count <= 1 the log ratio would be 0/0, so treat the column as
    # the densest one.
    count = qual_stats_dict[position]['count']
    if log_max_count > 0:
        shade_alpha = (log(count) / log_max_count) / 5
    else:
        shade_alpha = 0.2
    plt.fill_between([position, position + 1], y1 = 0, y2 = 41, color = column_color, alpha = shade_alpha)

# Positional limits work across matplotlib versions (ymin=/xmin= keywords were removed).
plt.ylim(0, 41)
plt.xlim(0, alignment_length)

plt.show()
