#!/usr/bin/env python
#-*- coding: utf-8 -*-

from __future__ import division
import sys

import argparse
import numpy as np
from matplotlib import use
use('Agg')
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib import colors as pltcolors

# own modules
from deeptools import parserCommon
from deeptools import heatmapper

# Set to a true value to make plotProfile print the computed figure
# geometry (sizes and axes rectangle) while it builds the plot.
debug = 0
# Turn off matplotlib's interactive mode; the figures produced by this
# script are only written to a file with savefig.
plt.ioff()


def parseArguments(args=None):
    """
    Parses the command line options of the profiler and normalizes them.

    :param args: list of argument strings. When None, argparse reads
                 the arguments from the command line.
    :return: the argparse namespace, post-processed as follows:
             * yMin and yMax are converted to float, or set to None when
               the given value can not be converted (e.g. an empty string);
             * plotHeight is clamped to the range [0.5, 100];
             * regionsLabel becomes a list of stripped labels, or an
               empty list when it was left at the value 'genes'.

    The program exits with status 1 if the region labels are not unique.
    """
    parser = argparse.ArgumentParser(
        parents=[parserCommon.heatmapperMatrixArgs(),
                 parserCommon.heatmapperOutputArgs(mode='profile'),
                 parserCommon.heatmapperOptionalArgs(mode='profile')],
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='This tool creates a profile plot for a '
        'score associated to genomic regions. '
        'Typically, these regions are genes, but '
        'any other regions defined in a BED or GFF '
        'format will work. A preprocessed matrix generated '
        'by the tool computeMatrix is required.',
        epilog='An example usage is: %(prog)s -m <matrix file>',
        add_help=False)

    args = parser.parse_args(args)

    # Because of galaxy, the value of these variables is normally
    # set to ''. Therefore this check is needed
    for attr in ['yMax', 'yMin']:
        try:
            # a missing attribute reads as None, which float() rejects
            # with a TypeError, so it also ends up as None
            value = float(getattr(args, attr, None))
        except (ValueError, TypeError):
            value = None
        setattr(args, attr, value)

    # keep the plot height within a sensible range
    if args.plotHeight < 0.5:
        args.plotHeight = 0.5
    elif args.plotHeight > 100:
        args.plotHeight = 100

    if args.regionsLabel != 'genes':
        args.regionsLabel = \
            [x.strip() for x in args.regionsLabel.split(',')]

        if len(set(args.regionsLabel)) != len(args.regionsLabel):
            # the whole message is passed as a single argument, otherwise
            # only its first line is emitted
            sys.stderr.write(
                "The group labels given contain repeated names. Please "
                "give a unique name to each value. The values given are "
                "{}\n".format(args.regionsLabel))
            sys.exit(1)
    else:
        # 'genes' means that no custom labels were requested
        args.regionsLabel = []

    return args


def plot_single(ax, ma, average_type, color, label,
                 plot_type='simple'):

    """
    Adds the profile of one group of regions to the given axes.

    :param ax: matplotlib axes (or compatible object) to draw on
    :param ma: 2-D matrix with one row per region and one column per bin
    :param average_type: name of the numpy function (e.g. 'mean',
                         'median') used to summarize each column
    :param color: color of the summary line and of its filled area
    :param label: legend label of the summary line
    :param plot_type: one of
        'fill'             summary line plus the area below it,
        'se'               summary line plus a band of one standard
                           error above and below it,
        'overlapped_lines' every row is drawn as a faint yellow line
                           over a black background, without summary,
        any other value    only the summary line.

    The x limits are set to span all the bins. They are left untouched
    when the matrix has no columns.
    """

    # resolved first so that an invalid name fails for every plot type
    summarize = getattr(np, average_type)

    if plot_type == 'overlapped_lines':
        # the summary is not drawn in this mode, so it is not computed
        ax.patch.set_facecolor('black')
        for row in ma:
            ax.plot(row, 'yellow', alpha=0.1)
        # taken from the matrix and not from the last row plotted, such
        # that a group without regions does not fail
        num_bins = np.shape(ma)[1]
    else:
        sumry = summarize(ma, axis=0)
        ax.plot(sumry, color=color, label=label, alpha=0.9)
        num_bins = len(sumry)
        x = np.arange(num_bins)

        if plot_type == 'fill':
            ax.fill_between(x, sumry, facecolor=color, alpha=0.6)

        elif plot_type == 'se':  # standard error
            std = np.std(ma, axis=0) / np.sqrt(ma.shape[0])
            alpha = 0.2
            # to_rgba accepts color names, hex strings, RGB and RGBA
            # sequences and replaces the alpha channel in all cases
            f_color = pltcolors.colorConverter.to_rgba(color, alpha)
            # ideally the edgecolor should be None,
            # but when generating a pdf image an edge is still
            # drawn.
            ax.fill_between(x, sumry, sumry + std, facecolor=f_color,
                            edgecolor=f_color, lw=0.01)
            ax.fill_between(x, sumry, sumry - std, facecolor=f_color,
                            edgecolor=f_color, lw=0.01)

    if num_bins > 0:
        ax.set_xlim(0, num_bins - 1)


def _profile_x_ticks(parameters, referencePointLabel, startLabel, endLabel,
                     tickPlotAdj=0.5):
    """
    Returns the x tick positions (in bins) and their labels.

    :param parameters: dictionary with the keys 'upstream', 'downstream',
                       'body' and 'bin size'
    :param referencePointLabel: label of the reference point, used when
                                the body length is 0
    :param startLabel: label of the start of the body
    :param endLabel: label of the end of the body
    :param tickPlotAdj: value subtracted from every tick position
    :return: tuple (list of tick positions, list of tick labels)
    """
    b = parameters['upstream']
    a = parameters['downstream']
    m = parameters['body']
    w = parameters['bin size']

    # distances are reported in Kb unless the upstream region is large
    if b < 1e5:
        quotient = 1000
        symbol = 'Kb'
    else:
        quotient = 1e6
        symbol = 'Mb'

    upstream_label = '{0:.1f}'.format(-(float(b) / quotient))
    downstream_label = '{0:.1f}{1}'.format(float(a) / quotient, symbol)

    if m == 0:
        positions = [0, b, b + a]
        labels = [upstream_label, referencePointLabel, downstream_label]
    else:
        positions = [0]
        labels = []
        # if no upstream region is set, do not set a x tick
        if b > 0:
            positions.append(b)
            labels.append(upstream_label)
        # the end of the body always gets a tick. Together with the
        # tick at 0 this gives one position per label appended so far
        positions.append(b + m)
        labels.append(startLabel)
        labels.append(endLabel)
        if a > 0:
            positions.append(b + m + a)
            labels.append(downstream_label)

    ticks = [int(k / float(w)) - tickPlotAdj for k in positions]
    return ticks, labels


def plotProfile(hm, outFileName,
                plotTitle='',
                xAxisLabel='', yAxisLabel='',
                yMin=None, yMax=None,
                averageType='median',
                referencePointLabel='TSS',
                startLabel='TSS', endLabel="TES",
                plotHeight=5,
                plotWidth=8,
                onePlotPerGroup=False,
                plotType='simple',
                image_format=None,
                color_list=None):

    """
    Using the hm matrix, makes a line plot per group
    using the specified parameters and saves it to outFileName.

    :param hm: object with a `matrixDict` mapping (group label -> matrix)
               and a `parameters` dictionary with the keys 'upstream',
               'downstream', 'body' and 'bin size'
    :param outFileName: file to which the image is saved
    :param plotTitle: title shown at the top of the figure
    :param xAxisLabel: accepted for compatibility, currently not drawn
    :param yAxisLabel: label of the y axis
    :param yMin: lower limit of the y axis (None for automatic)
    :param yMax: upper limit of the y axis (None for automatic)
    :param averageType: name of the numpy function used by plot_single
                        to summarize the matrix columns
    :param referencePointLabel: x label used when the body length is 0
    :param startLabel: x label for the start of the body
    :param endLabel: x label for the end of the body
    :param plotHeight: height of a plot in cm
    :param plotWidth: width of the figure in cm
    :param onePlotPerGroup: if True each group gets its own sub plot,
                            otherwise all groups share a single plot.
                            Forced to True when there is only one group.
    :param plotType: plot type passed to plot_single
    :param image_format: image format passed to savefig
    :param color_list: list of colors, one per group. When empty or
                       None the colors are taken from the 'jet' colormap.

    The program exits with status 1 when color_list has fewer colors
    than groups or contains an invalid color.
    """
    plt.rcParams['font.size'] = 10.0
    fontP = FontProperties()
    fontP.set_size('small')

    # convert cm values to inches
    plotHeightInches = float(plotHeight) / 2.54
    figWidth = float(plotWidth) / 2.54

    # measures are in inches
    topBorder = 0.36
    spacer = 0.25
    bottomBorder = 0.3

    numGroups = len(hm.matrixDict)

    if numGroups == 1:
        # when the number of regions is just one, it is better
        # to use the defaults for one plot per group, because
        # it does not try to plot a legend below the plot,
        # which for one region is unnecessary
        onePlotPerGroup = True

    # plotHeightInches  height of plot plus bottom spacing
    if not onePlotPerGroup:
        # explicit float division: room for the legend, which is
        # drawn in two columns below the plot
        figHeight = topBorder + bottomBorder + (numGroups / 2.0) * \
            spacer + plotHeightInches
    else:
        figHeight = plotHeightInches

    # fraction of the height for summary plots
    sumFracHeight = float(plotHeightInches) / figHeight

    topBorderFrac = topBorder / figHeight

    if debug:
        print (figWidth, figHeight)
    # figsize: w,h tuple in inches
    fig = plt.figure(figsize=(figWidth, figHeight))

    xTicks, xTicksLabel = _profile_x_ticks(hm.parameters,
                                           referencePointLabel,
                                           startLabel, endLabel)

    fig.suptitle(plotTitle, y=1 - (0.06 / figHeight))

    if not color_list:
        cmap_plot = plt.get_cmap('jet')
        color_list = cmap_plot(np.arange(numGroups) / float(numGroups))
    else:
        if len(color_list) < numGroups:
            sys.stderr.write("\nThe given list of colors is too small, "
                             "at least {} colors are needed\n".format(
                                 numGroups))
            sys.exit(1)
        for color in color_list:
            if not pltcolors.is_color_like(color):
                sys.stderr.write("\nThe color name {} is not valid. Check "
                                 "the name or try with a html hex string "
                                 "for example #eeff22\n".format(color))
                sys.exit(1)

    # add_axes( rect ) where rect is [left, bottom, width, height] and
    # all quantities are in fractions of figure width and height.
    left = 0.45 / figWidth
    width = 0.7

    # i.e if multipleLines in one plot:
    if not onePlotPerGroup:
        bottom = 1 - (topBorderFrac + sumFracHeight)
        if debug:
            print ([left, bottom, width, sumFracHeight, figHeight])
        ax = fig.add_axes([left, bottom, width, sumFracHeight])

    for index, (label, ma) in enumerate(hm.matrixDict.items()):
        if onePlotPerGroup:
            # create an axis for each sub plot
            ax = fig.add_subplot(numGroups, 1, index + 1, label=label)
            plot_single(ax, ma, averageType,
                        color_list[index], label,
                        plot_type=plotType)
            ax.set_ylim(yMin, yMax)
            ax.axes.set_xticks(xTicks)
            ax.axes.set_xticklabels(xTicksLabel)
            # reduce the number of yticks by half. The ticks are taken
            # from the first sub plot and reused for all the others
            if index == 0:
                numTicks = len(ax.get_yticks())
                yTicks = [ax.get_yticks()[i]
                          for i in range(1, numTicks, 2)]
            ax.set_yticks(yTicks)
            ax.axes.set_ylabel(yAxisLabel)

            ax.legend(loc='upper right', prop=fontP,
                      frameon=False, markerscale=0.5)

        else:
            # add new lines to existing plot
            plot_single(ax, ma, averageType,
                        color_list[index], label,
                        plot_type=plotType)

    # i.e if multipleLines in one plot:
    if not onePlotPerGroup:
        # in the case of one box with all plots the
        # font of the legend and the positions
        # are changed.
        ax.set_ylim(yMin, yMax)
        ax.axes.set_xticks(xTicks)
        ax.axes.set_xticklabels(xTicksLabel)
        # the legend shows in a box below the plot
        ax.legend(bbox_to_anchor=(-0.05, -1.13, 1., 1),
                  loc='upper center',
                  ncol=2, mode="expand", prop=fontP,
                  frameon=False, markerscale=0.5)

        # reduce the number of yticks by half
        numTicks = len(ax.get_yticks())
        yTicks = [ax.get_yticks()[i] for i in range(1, numTicks, 2)]
        ax.set_yticks(yTicks)

        ax.axes.set_ylabel(yAxisLabel)

    # the figure is cropped to its content without extra padding
    plt.savefig(outFileName, bbox_inches='tight',
                pad_inches=0, dpi=200, format=image_format)
    # release the figure, pyplot otherwise keeps it alive
    plt.close(fig)


def main(args):
    r"""
    Reads the matrix file, optionally clusters the regions of the first
    group with k-means, relabels the groups, saves the requested data
    and region files and finally draws the profile plot.

    :param args: namespace as returned by parseArguments()

    >>> import filecmp
    >>> import os
    >>> args = parseArguments(
    ... "-m ../deeptools/test/test_heatmapper/master.mat.gz \
    ... --outFileName /tmp/_test.png --regionsLabel uno,dos \
    ... --plotType std".split())
    >>> main(args)
    >>> filecmp.cmp(
    ... '../deeptools/test/test_heatmapper/profile_master.png',
    ... '/tmp/_test.png')
    True
    >>> os.remove('/tmp/_test.png')

    """

    hm = heatmapper.heatmapper()
    # only the name is needed, the file is opened again by readMatrixFile
    matrix_file = args.matrixFile.name
    args.matrixFile.close()
    hm.readMatrixFile(matrix_file,
                      default_group_name=args.regionsLabel)

    if args.kmeans is not None:
        k = args.kmeans
        # only the first group of the matrix is clustered
        label = list(hm.matrixDict.keys())[0]
        idx = hm.hmcluster(hm.matrixDict[label],
                           k,
                           method='kmeans')

        # split the matrix and the regions into clusters; each cluster
        # replaces the original groups
        _matrixDict = {}
        _regionsDict = {}
        for cluster in range(k):
            c_label = "cluster {}".format(cluster + 1)
            cluster_ids = np.flatnonzero(idx == cluster)
            _matrixDict[c_label] = \
                hm.matrixDict[label][cluster_ids, ]
            _regionsDict[c_label] = \
                hm.regionsDict[label][cluster_ids]

        hm.matrixDict = _matrixDict
        hm.regionsDict = _regionsDict

    if args.regionsLabel:
        hm.reLabelGroups(args.regionsLabel)

    if args.outFileNameData:
        hm.saveTabulatedValues(args.outFileNameData)

    if args.outFileSortedRegions:
        hm.saveBED(args.outFileSortedRegions)

    plotProfile(hm,
                args.outFileName,
                args.plotTitle,
                args.xAxisLabel, args.yAxisLabel,
                args.yMin, args.yMax,
                args.averageType,
                args.refPointLabel,
                args.startLabel,
                args.endLabel,
                args.plotHeight,
                args.plotWidth,
                args.onePlotPerGroup,
                plotType=args.plotType,
                image_format=args.plotFileFormat,
                color_list=args.colors)


if __name__ == "__main__":
    # script entry point: parse the command line and draw the profile
    main(parseArguments())
