#!/usr/bin/env python
#-*- coding: utf-8 -*-

import argparse
from collections import OrderedDict
import numpy as np
from matplotlib import use
use('Agg')
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties

# own modules
from deeptools import parserCommon
from deeptools import heatmapper

# when set to a truthy value, plotMatrix prints the geometry of the
# axes it creates
debug = 0
# switch off matplotlib's interactive mode
plt.ioff()


def parseArguments(args=None):
    """
    Parse and sanitize the command line options of the heatmap tool.

    :param args: list of command line tokens. When None, argparse
                 reads the arguments from sys.argv.
    :return: the argparse namespace with the following adjustments:
             zMin, zMax, yMin and yMax are floats or None;
             heatmapHeight falls back to 10 when outside (3, 100];
             regionsLabel is a list of labels (empty when the
             default value 'genes' was given).

    The program exits with status 1 when --missingDataColor is not a
    valid matplotlib color or when the region labels are repeated.
    """
    parser = argparse.ArgumentParser(
        parents=[parserCommon.heatmapperMatrixArgs(),
                 parserCommon.heatmapperOutputArgs(mode='heatmap'),
                 parserCommon.heatmapperOptionalArgs(mode='heatmap')],
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='This tool creates a heatmap for a '
        'score associated to genomic regions. '
        'The program requires a preprocessed matrix '
        'generated by the tool computeMatrix.',
        epilog='An example usage is: %(prog)s -m <matrix file>',
        add_help=False)

    args = parser.parse_args(args)

    # Because of galaxy, the value of these variables is normally
    # set to ''. Any value that can not be converted to float
    # (e.g. '' or None) is replaced by None.
    for attr in ['zMin', 'zMax', 'yMax', 'yMin']:
        try:
            setattr(args, attr, float(getattr(args, attr)))
        except (TypeError, ValueError):
            setattr(args, attr, None)

    # heights outside of the (3, 100] range are replaced by 10
    if not 3 < args.heatmapHeight <= 100:
        args.heatmapHeight = 10

    if not matplotlib.colors.is_color_like(args.missingDataColor):
        # the message is built as one single expression. A string
        # literal on its own line after a print statement is not
        # printed
        print("The value {0}  for --missingDataColor is "
              "not valid".format(args.missingDataColor))
        exit(1)

    if args.regionsLabel != 'genes':
        args.regionsLabel = \
            [x.strip() for x in args.regionsLabel.split(',')]

        if len(set(args.regionsLabel)) != len(args.regionsLabel):
            print("The group labels given contain repeated names. Please "
                  "give a unique name to each value. The values given are "
                  "{}\n".format(args.regionsLabel))
            exit(1)
    else:
        args.regionsLabel = []

    return args


def plot_profile(ax, ma, average_type, color, label,
                 plot_type='simple'):
    """
    Draw the summary profile of the matrix `ma` on the axes `ax`.

    The profile is obtained by applying the numpy function named
    `average_type` over axis 0 of `ma`. Depending on `plot_type`:

      'overlapped_lines': each row of `ma` is drawn as a transparent
                          yellow line over a black background. The
                          profile line is not drawn.
      'fill':             the profile line is drawn and filled.
      'std':              the profile line is drawn together with a
                          band of one standard deviation above and
                          below it.
      any other value:    only the profile line is drawn.
    """
    profile = getattr(np, average_type)(ma, axis=0)

    if plot_type == 'overlapped_lines':
        ax.patch.set_facecolor('black')
        for single_row in ma:
            ax.plot(single_row, 'yellow', alpha=0.1)
        return

    ax.plot(profile, color=color, label=label, alpha=0.6)
    positions = np.arange(len(profile))

    if plot_type == 'fill':
        ax.fill_between(positions, profile, facecolor=color, alpha=0.6)
    elif plot_type == 'std':
        spread = np.std(ma, axis=0)
        for bound in (profile + spread, profile - spread):
            ax.fill_between(positions, profile, bound,
                            facecolor=color, alpha=0.2)


def plotMatrix(hm, outFileName,
               colorMap='binary', missingDataColor='black',
               plotTitle='',
               xAxisLabel='', yAxisLabel='', regionsLabel='',
               zMin=None, zMax=None,
               yMin=None, yMax=None,
               averageType='median',
               referencePointLabel='TSS',
               startLabel='TSS', endLabel="TES",
               heatmapHeight=25,
               heatmapWidth=7.5,
               onePlotPerGroup=False, whatToShow='plot, heatmap and scale',
               plotType='simple',
               image_format=None):
    """
    Plot the summary profile(s), the heatmap(s) and the colorbar for
    the groups stored in `hm` and save the figure to `outFileName`.

    :param hm: object providing `matrixDict` (group label -> 2D
               matrix), `parameters` (with the keys 'upstream',
               'downstream', 'body' and 'bin size') and `lengthDict`
    :param outFileName: file name (or file object) given to savefig
    :param colorMap: name of the matplotlib colormap of the heatmap
    :param missingDataColor: color used for masked (nan) values
    :param plotTitle: title of the figure
    :param xAxisLabel: x label set on each heatmap axes
    :param yAxisLabel: y label of the summary plots
    :param regionsLabel: not used by this function
    :param zMin: lower limit of the heatmap colors. When None, the
                 1st percentile of the non nan values is used
    :param zMax: upper limit of the heatmap colors. When None, the
                 98th percentile of the non nan values is used
    :param yMin: lower y limit of the summary plots
    :param yMax: upper y limit of the summary plots
    :param averageType: name of the numpy function used to summarize
                        the matrix columns (see plot_profile)
    :param referencePointLabel: label of the middle x tick when the
                                'body' parameter is 0
    :param startLabel: label of the region start when 'body' is not 0
    :param endLabel: label of the region end when 'body' is not 0
    :param heatmapHeight: height of the heatmap in cm
    :param heatmapWidth: width of the figure in cm
    :param onePlotPerGroup: if True each group gets its own summary
                            plot, otherwise all the profiles share
                            one plot with a legend below it
    :param whatToShow: one of 'plot and heatmap', 'plot only',
                       'heatmap only', 'colorbar only' or
                       'heatmap and colorbar'. Any other value shows
                       the plot, the heatmap and the colorbar
    :param plotType: plot type passed to plot_profile
    :param image_format: image format passed to savefig
    """

    tickPlotAdj = 0.5
    matrixFlatten = None
    if zMin is None:
        matrixFlatten = flattenMatrix(hm.matrixDict)
        # try to avoid outliers by using np.percentile
        zMin = np.percentile(matrixFlatten, 1.0)
        if np.isnan(zMin):
            zMin = None

    if zMax is None:
        if matrixFlatten is None:
            matrixFlatten = flattenMatrix(hm.matrixDict)
        # try to avoid outliers by using np.percentile
        zMax = np.percentile(matrixFlatten, 98.0)
        if np.isnan(zMax):
            zMax = None

    plt.rcParams['font.size'] = 10.0

    showSummaryPlot = False
    showHeatmap = False
    showColorbar = False

    if whatToShow == 'plot and heatmap':
        showSummaryPlot = True
        showHeatmap = True
    elif whatToShow == 'plot only':
        showSummaryPlot = True
    elif whatToShow == 'heatmap only':
        showHeatmap = True
    elif whatToShow == 'colorbar only':
        showColorbar = True
    elif whatToShow == 'heatmap and colorbar':
        showHeatmap = True
        showColorbar = True
    else:
        showSummaryPlot = True
        showHeatmap = True
        showColorbar = True

    sumFigSpacer = 0
    sumFigHeightInches = 0
    spaceBetweenClusterNames = 0

    if showSummaryPlot:
        spaceBetweenClusterNames = 0.25
        sumFigHeightInches = 1.5

    if showSummaryPlot and showHeatmap:
        sumFigSpacer = 0.3

    # the heatmapHeight and heatmapWidth values are given in cm
    heatmapHeightInches = float(heatmapHeight) / 2.54 if showHeatmap or \
        showColorbar else 0
    figWidth = float(heatmapWidth) / 2.54

    # measures are in inches
    topBorder = 0.36
    bottomBorder = 0.2
    smallSpacer = 0.05

    numGroups = len(hm.matrixDict)

    # Small groups are intentionally not merged here (see
    # mergeSmallGroups): the end result may be confusing and it is
    # better that the user takes action.

    if numGroups == 1:
        # when there is just one group, it is better to use the
        # layout for one plot per group because it does not try to
        # plot a legend below the plot, which for one group is
        # unnecessary
        onePlotPerGroup = True

    # heatmapHeightInches: height of heatmap plus bottom spacing
    if onePlotPerGroup:
        # the figure height depends on the number of groups: each
        # group adds the height of one summary plot plus its spacer.
        # The heatmap occupies heatmapHeightInches inches
        figHeight = topBorder + bottomBorder + numGroups * \
            (sumFigHeightInches + sumFigSpacer) + heatmapHeightInches
    else:
        # one summary plot plus room for a legend with two columns
        figHeight = topBorder + bottomBorder + (np.ceil(float(numGroups) / 2)) * \
            spaceBetweenClusterNames + sumFigHeightInches + sumFigSpacer + \
            heatmapHeightInches

    # fraction of the height for summary plots
    sumFracHeight = float(sumFigHeightInches) / figHeight

    # fraction of the height for the whole heatmap
    heatFracHeight = heatmapHeightInches / figHeight

    topBorderFrac = topBorder / figHeight
    bottomBorderFrac = bottomBorder / figHeight
    spacerFrac = sumFigSpacer / figHeight
    smallSpacerFrac = smallSpacer / figHeight

    # figsize: w,h tuple in inches
    fig = plt.figure(figsize=(figWidth, figHeight))

    b = hm.parameters['upstream']
    a = hm.parameters['downstream']
    m = hm.parameters['body']
    w = hm.parameters['bin size']

    # unit of the x tick labels
    if b < 1e5:
        quotient = 1000
        symbol = 'Kb'
    else:
        quotient = 1e6
        symbol = 'Mb'

    # x ticks are given in bins, that is why the distances
    # are divided by the bin size
    if m == 0:
        xTicks = [int(k / float(w)) - tickPlotAdj for k in [0, b, b + a]]
        xTicksLabel = ['{0:.1f}'.format(-(float(b) / quotient)),
                       referencePointLabel,
                       '{0:.1f}{1}'.format(float(a) / quotient, symbol)]
    else:
        xticks_values = [0]
        xTicksLabel = []

        # if no upstream region is set, do not set a x tick
        if b > 0:
            xticks_values.append(b)
            xTicksLabel.append('{0:.1f}'.format(-(float(b) / quotient)))
        # set the x tick for the body parameter, regardless if
        # upstream is 0 (not set)
        xticks_values.append(b + m)
        xTicksLabel.append(startLabel)
        xTicksLabel.append(endLabel)
        if a > 0:
            xticks_values.append(b + m + a)
            xTicksLabel.append(
                '{0:.1f}{1}'.format(float(a) / quotient, symbol))

        xTicks = [int(k / float(w)) - tickPlotAdj
                  for k in xticks_values]

    fig.suptitle(plotTitle, y=1 - (0.06 / figHeight))

    # colormap for the image
    cmap = plt.get_cmap(colorMap)
    # color map for the profile on top of the heatmap
    cmap_plot = plt.get_cmap('jet')
    cmap.set_bad(missingDataColor)  # nans are printed using this color

    # add_axes( rect ) where rect is [left, bottom, width, height] and
    # all quantities are in fractions of figure width and height.
    # summary plot
    # each group needs its axes
    left = 0.45 / figWidth
    width = 0.6

    if not showSummaryPlot:
        # when only the heatmap is to be printed use more space
        left = 0.28 / figWidth
        width = 0.7 - left
    if not showColorbar:
        width = 0.7

    # ##### plot summary plot ###########
    if showSummaryPlot:

        # i.e if multipleLines in one plot:
        if not onePlotPerGroup:
            bottom = 1 - (topBorderFrac + sumFracHeight)
            if debug:
                print([left, bottom, width, sumFracHeight, figHeight])
            ax = fig.add_axes([left, bottom, width, sumFracHeight])

        for index, (label, ma) in enumerate(hm.matrixDict.items()):
            if onePlotPerGroup:
                # create an axis for each sub plot
                bottom = 1 - topBorderFrac - (index + 1) * sumFracHeight \
                    - (index * spacerFrac)
                ax = fig.add_axes([left, bottom, width, sumFracHeight])
                plot_profile(ax, ma, averageType,
                             cmap_plot(1. * index / numGroups), label,
                             plot_type=plotType)
                ax.set_ylim(yMin, yMax)
                ax.axes.set_xticks(xTicks)
                ax.axes.set_xticklabels(xTicksLabel)
                # reduce the number of yticks by half. The ticks
                # computed for the first plot are used for all plots
                if index == 0:
                    numTicks = len(ax.get_yticks())
                    yTicks = [ax.get_yticks()[i]
                              for i in range(1, numTicks, 2)]
                ax.set_yticks(yTicks)
                ax.axes.set_ylabel(yAxisLabel)

            else:
                # add new lines to existing plot
                plot_profile(ax, ma, averageType,
                             cmap_plot(1. * index / numGroups), label,
                             plot_type=plotType)

        # i.e if multipleLines in one plot:
        if not onePlotPerGroup:
            # in the case of one box with all plots the
            # font of the legend and the positions
            # are changed.
            ax.set_ylim(yMin, yMax)
            ax.axes.set_xticks(xTicks)
            ax.axes.set_xticklabels(xTicksLabel)
            fontP = FontProperties()
            fontP.set_size('small')
            # the legend shows in a box below the plot
            ax.legend(bbox_to_anchor=(-0.15, -1.2, 1.35, 1),
                      loc='upper center',
                      ncol=2, mode="expand", borderaxespad=0., prop=fontP,
                      frameon=False, markerscale=0.5)

            # reduce the number of yticks by half
            numTicks = len(ax.get_yticks())
            yTicks = [ax.get_yticks()[i] for i in range(1, numTicks, 2)]
            ax.set_yticks(yTicks)

            ax.axes.set_ylabel(yAxisLabel)

    # ##### plot heatmap plot ###########
    if showHeatmap:
        startHeatmap = heatFracHeight + bottomBorderFrac
        groupLengths = np.array([len(x) for x in hm.matrixDict.values()],
                                dtype='float64')
        groupLengthsFrac = groupLengths / sum(groupLengths)
        for index, (label, ma) in enumerate(hm.matrixDict.items()):
            # mask nans
            ma = np.ma.array(ma, mask=np.isnan(ma))
            # the size of the heatmap is proportional to its length
            bottom = startHeatmap - \
                sum(groupLengthsFrac[:index + 1]) * heatFracHeight
            height = groupLengthsFrac[index] * heatFracHeight - smallSpacerFrac
            axHeat = fig.add_axes([left, bottom, width, height])
            interpolation_type = 'bicubic' if ma.shape[0] > 200 and \
                ma.shape[1] > 1000 else 'nearest'
            # the colormap object (and not its name) is used such
            # that the color set for the missing data is applied
            img = axHeat.imshow(ma,
                                aspect='auto',
                                interpolation=interpolation_type,
                                origin='upper',
                                vmin=zMin,
                                vmax=zMax,
                                cmap=cmap,
                                extent=[0, ma.shape[1], ma.shape[0], 0])

            # plot a line showing the border of the gene length
            # only if the matrix is using reference-point (b>0)
            # The lengthDict is set to none if the matrix is
            # not ordered according to region length
            if b > 0 and label in hm.lengthDict and \
                    hm.lengthDict[label] is not None:
                # the axes limits are restored after plotting the line
                x_lim = axHeat.get_xlim()
                y_lim = axHeat.get_ylim()
                axHeat.plot(hm.lengthDict[label],
                            np.arange(len(hm.lengthDict[label])),
                            '--', color='black', linewidth=0.5, dashes=(3, 2))
                axHeat.set_xlim(x_lim)
                axHeat.set_ylim(y_lim)
            axHeat.axes.get_xaxis().set_visible(False)
            axHeat.axes.set_xlabel(xAxisLabel)
            axHeat.axes.set_ylabel(label)
            axHeat.axes.set_yticks([])

    # ##### add heatmap colorbar ###########
    if showColorbar:
        # this is the case when only the colorbar wants to be
        # printed and nothing else
        if not showHeatmap:
            left = 0.2
            width = 0.6 - (0.20 / figWidth)
            legend = fig.add_axes([left, bottomBorderFrac, width,
                                   heatFracHeight - smallSpacerFrac])
            norm = matplotlib.colors.Normalize(vmin=zMin, vmax=zMax)
            matplotlib.colorbar.ColorbarBase(legend, cmap=cmap, norm=norm)
        else:
            left = 0.79
            width = 0.05
            legend = fig.add_axes([left, bottomBorderFrac, width,
                                   heatFracHeight - smallSpacerFrac])
            if debug:
                print([left, bottomBorderFrac, width,
                       heatFracHeight - smallSpacerFrac])
            # the colorbar corresponds to the last heatmap plotted. All
            # the heatmaps share the same colormap and limits
            fig.colorbar(img, cax=legend)

    # pad_inches=0 removes the padding around the tight bounding box
    plt.savefig(outFileName, bbox_inches='tight', pad_inches=0, dpi=200,
                format=image_format)


def flattenMatrix(matrixDict):
    """
    Concatenate all the matrices stored as values of `matrixDict`
    and return their values as one flat array without the nans.
    """
    stacked = np.concatenate(list(matrixDict.values()))
    values = stacked.flatten()
    # keep only the entries that are not nan
    return values[~np.isnan(values)]

def mergeSmallGroups(matrixDict):
    """
    Merge the groups that are too small to be properly visualized.

    A group is small when its number of rows is not larger than 1% of
    the total number of rows. Consecutive small groups are merged
    with the next group that is not small. The small groups left at
    the end are merged together into one group. The label of a
    merged group is the labels of its members joined by a space.

    :param matrixDict: ordered mapping of group label -> matrix
    :return: OrderedDict of group label -> matrix. The given
             `matrixDict` is not modified.
    """
    totalLength = sum(len(ma) for ma in matrixDict.values())
    minGroupLength = totalLength * 0.01

    toMerge = []
    _mergedHeatMapDict = OrderedDict()

    for label, ma in matrixDict.items():
        if len(ma) > minGroupLength:
            if toMerge:
                # the pending small groups are merged with this group
                toMerge.append(label)
                newLabel = " ".join(toMerge)
                newMa = np.concatenate([matrixDict[item]
                                        for item in toMerge], axis=0)
                toMerge = []
            else:
                newLabel = label
                newMa = ma

            _mergedHeatMapDict[newLabel] = newMa
        else:
            toMerge.append(label)

    if toMerge:
        # small groups at the end have no larger group to be merged
        # with. They are kept (merged together) instead of being lost
        _mergedHeatMapDict[" ".join(toMerge)] = \
            np.concatenate([matrixDict[item] for item in toMerge], axis=0)

    return _mergedHeatMapDict


def main(args):
    r"""
    Read the matrix file, optionally cluster and sort the regions,
    save the requested output files and plot the heatmap.

    :param args: namespace as returned by parseArguments

    >>> import filecmp
    >>> import os
    >>> args = parseArguments(
    ... "-m ../deeptools/test/test_heatmapper/master.mat.gz \
    ... --outFileName /tmp/_test.png".split())
    >>> main(args)
    >>> filecmp.cmp(
    ... '../deeptools/test/test_heatmapper/master.png', '/tmp/_test.png') #may fail if diff version of  matplotlib library is used
    True
    >>> os.remove('/tmp/_test.png')
    >>> args = parseArguments(
    ... "-m ../deeptools/test/test_heatmapper/master.mat.gz \
    ... --outFileName /tmp/_test2.png --regionsLabel uno,dos".split())
    >>> main(args)
    >>> filecmp.cmp(
    ... '../deeptools/test/test_heatmapper/master_relabeled.png',
    ... '/tmp/_test2.png') #may fail because diff matplotlib library was used
    True
    >>> os.remove('/tmp/_test2.png')
    >>> args = parseArguments(
    ... "-m ../deeptools/test/test_heatmapper/master_scale_reg.mat.gz \
    ... --outFileName /tmp/_test3.png".split())
    >>> main(args)
    >>> filecmp.cmp(
    ... '../deeptools/test/test_heatmapper/master_scale_reg.png',
    ... '/tmp/_test3.png') #may fail because diff matplotlib library was used
    True
    >>> os.remove('/tmp/_test3.png')

    """

    hm = heatmapper.heatmapper()
    # only the name of the file is needed, the file object
    # opened by argparse is closed
    matrix_file = args.matrixFile.name
    args.matrixFile.close()
    hm.readMatrixFile(matrix_file,
                      default_group_name=args.regionsLabel)

    if args.kmeans is not None:
        k = args.kmeans
        # only the first group is clustered
        label = list(hm.matrixDict.keys())[0]
        idx = hm.hmcluster(hm.matrixDict[label],
                           k,
                           method='kmeans')

        # split the matrix and the regions into clusters. Ordered
        # dictionaries are used such that the clusters are plotted
        # following their number
        _matrixDict = OrderedDict()
        _regionsDict = OrderedDict()
        for cluster in range(k):
            c_label = "cluster {}".format(cluster + 1)
            cluster_ids = np.flatnonzero(idx == cluster)
            _matrixDict[c_label] = \
                hm.matrixDict[label][cluster_ids, ]
            _regionsDict[c_label] = \
                hm.regionsDict[label][cluster_ids]

        hm.matrixDict = _matrixDict
        hm.regionsDict = _regionsDict

    # check if maybe some group is too small to be plotted
    # this causes a segmentation fault so is better to
    # try catch the problem and return a meaningful output
    total_regions = sum(group_matrix.shape[0]
                        for group_matrix in hm.matrixDict.values())

    for label, group_matrix in hm.matrixDict.items():
        # a group needs to have at least 0.5% of all the regions
        if float(group_matrix.shape[0]) / total_regions < 5.0 / 1000:
            print("Group {} contains too few regions {}. It can't "
                  "be plotted. Try removing this group ".format(
                      label,
                      group_matrix.shape[0]))
            # NOTE(review): the program ends with status 0 although
            # nothing is plotted -- confirm that this is intended
            exit(0)

    if len(args.regionsLabel):
        hm.reLabelGroups(args.regionsLabel)

    if args.sortRegions != 'no':
        hm.sortMatrix(sort_using=args.sortUsing,
                      sort_method=args.sortRegions)

    if args.outFileNameMatrix:
        hm.saveMatrixValues(args.outFileNameMatrix)

    if args.outFileNameData:
        hm.saveTabulatedValues(args.outFileNameData)

    if args.outFileSortedRegions:
        hm.saveBED(args.outFileSortedRegions)

    plotMatrix(hm,
               args.outFileName,
               args.colorMap, args.missingDataColor, args.plotTitle,
               args.xAxisLabel, args.yAxisLabel, args.regionsLabel,
               args.zMin, args.zMax,
               args.yMin, args.yMax,
               args.averageTypeSummaryPlot,
               args.refPointLabel,
               args.startLabel,
               args.endLabel,
               args.heatmapHeight,
               args.heatmapWidth,
               args.onePlotPerGroup,
               args.whatToShow,
               image_format=args.plotFileFormat)


# when run as a script: parse the command line options (taken from
# sys.argv) and create the heatmap
if __name__ == "__main__":
    args = parseArguments()
    main(args)
