#!/usr/bin/env python
#-*- coding: utf-8 -*-

import random
import sys
import argparse
import os.path
import numpy as np
from matplotlib import use as mplt_use

mplt_use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

import deeptools.getScorePerBigWigBin as score_bw
from deeptools import parserCommon
from deeptools._version import __version__


def parseArguments(args=None):
    """
    Build the command line parser, parse the arguments and make sure that
    every bigWig file has a label.

    Two sub-commands are registered ('bins' and 'BED-file'). Each of them
    combines the options returned by bigwigCorrelateArgs() for that mode,
    the parent parser from parserCommon and the heat map options.

    :param args: list of command line tokens. With None, argparse reads
                 the tokens from sys.argv.
    :return: argparse.Namespace. Its 'labels' attribute is always a list
             with one entry per bigWig file: either the user supplied
             labels or the base names of the bigWig files.

    The process exits with status 1 if the number of user supplied labels
    differs from the number of bigWig files.
    """
    parser = \
        argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="""

bigwigCorrelate computes the overall similarity between two or more bigWig
files based on coverage means of genomic regions. The correlation analysis
is performed for the entire genome by running the program in 'bins' mode,
or for certain regions only in 'BED-file' mode. Pearson or Spearman analyses
are available to compute correlation coefficients. Results are saved to a
heat map file. Further output files are optional.

detailed help:
  %(prog)s bins -h
  %(prog)s BED-file -h

""",
            epilog='example usages:\n%(prog)s bins '
                   '-b file1.bw file2.bw -o heatmap.png --corMethod pearson\n\n'
                   '%(prog)s BED-file -b file1.bw file2.bw -o heatmap.png --corMethod pearson\n'
                   '--BED selection.bed '
                   ' \n\n',
            conflict_handler='resolve')

    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))
    subparsers = parser.add_subparsers(
        title="commands",
        dest='command',
        metavar='')

    # parsers shared by both sub-commands
    parent_parser = parserCommon.getParentArgParse(binSize=False)
    heatmap_parser = heatmap_options()

    # bins mode options
    subparsers.add_parser(
        'bins',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bigwigCorrelateArgs(case='bins'),
                 parent_parser,
                 heatmap_parser],
        help="The correlation is based on arbitrary bins of similar "
             "size (10k bp by default), which consecutively cover the "
             "entire genome. The only exception is the last bin, which "
             "is regularly smaller. This mode is useful to assess the "
             "overall similarity of bigWig files.",
        add_help=False,
        usage='%(prog)s '
              '-b file1.bw file2.bw '
              '-o heatmap.png --corMethod pearson\n')

    # BED file arguments
    subparsers.add_parser(
        'BED-file',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bigwigCorrelateArgs(case='BED-file'),
                 parent_parser,
                 heatmap_parser],
        help="The user provides a BED file that contains all regions "
             "that should be considered for the correlation analysis. A "
             "common use is to compare ChIP-seq coverages between two "
             "different samples for a set of peak regions.",
        usage='%(prog)s '
              '-b file1.bw file2.bw '
              '-o heatmap.png --corMethod pearson --BED selection.bed\n',
        add_help=False)

    args = parser.parse_args(args)

    if args.labels and len(args.bwfiles) != len(args.labels):
        # this is a usage error: report it on stderr and signal the
        # failure to the caller with a non-zero exit status
        sys.stderr.write("The number of labels does not match the number "
                         "of bigWig files.\n")
        sys.exit(1)
    if not args.labels:
        # a real list is needed because the labels are indexed later on
        args.labels = [os.path.basename(path) for path in args.bwfiles]

    return args


def bigwigCorrelateArgs(case='bins'):
    """
    Assemble the options that are specific to bigwigCorrelate.

    The returned parser has no automatic help option (an explicit
    --help/-h is added to the optional group) so that it can be used as a
    parent parser.

    :param case: 'bins' shows --binSize in the help and hides --BED. Any
                 other value hides --binSize and makes --BED a required,
                 readable file.
    :return: argparse.ArgumentParser
    """
    in_bins_mode = case == 'bins'

    # --binSize and --BED exist in both modes, only the way they are
    # presented (and, for --BED, validated) differs
    if in_bins_mode:
        bin_size_kwargs = dict(
            metavar='INT',
            help='Length in base pairs for a window used '
                 'to sample the genome.',
            default=10000,
            type=int)
        bed_kwargs = dict(help=argparse.SUPPRESS, default=None)
    else:
        bin_size_kwargs = dict(
            help=argparse.SUPPRESS,
            default=10000,
            type=int)
        bed_kwargs = dict(
            help='Limits the correlation analysis to '
                 'the regions specified in this file.',
            metavar='bedfile',
            type=argparse.FileType('r'),
            required=True)

    args_parser = argparse.ArgumentParser(add_help=False)

    # ---- required arguments -------------------------------------------
    req_group = args_parser.add_argument_group('Required arguments')

    req_group.add_argument(
        '--bwfiles', '-b',
        metavar='FILE1 FILE2',
        help='List of bigWig files separated by spaces.',
        nargs='+',
        required=True)

    req_group.add_argument(
        '--plotFile', '-o',
        help='File name to save the file containing the heat map '
             'of the correlation. The given file ending will be used '
             ' to determine the image format, for example: '
             'heatmap.pdf will save the heat map in PDF format. '
             'The available options are: .png, .emf, '
             '.eps, .pdf and .svg.',
        type=argparse.FileType('w'),
        metavar='FILE',
        required=True)

    req_group.add_argument(
        '--corMethod', '-c',
        help="Correlation method. For Pearson correlation only, "
             "automatic 99.9 percentile filtering is enabled by "
             "default in order to reduce the influence of extreme "
             "outliers.",
        choices=['pearson', 'spearman'],
        required=True)

    # ---- optional arguments -------------------------------------------
    opt_group = args_parser.add_argument_group('Optional arguments')

    opt_group.add_argument(
        "--help", "-h", action="help",
        help="show this help message and exit")

    opt_group.add_argument(
        '--labels', '-l',
        metavar='sample1 sample2',
        help='User defined labels instead of default labels from '
             'file names. '
             'Multiple labels have to be separated by space, e.g. '
             '--labels sample1 sample2 sample3',
        nargs='+')

    opt_group.add_argument(
        '--filterPercentile',
        metavar='FLOAT',
        help='Percentile used to filter out extreme outliers. '
             'If not specified, it is automatically set to 99.9 in analyses '
             'using Pearson correlation! This means that values above that '
             'threshold, which consistently occur in all datasets, will not be '
             'taken into account for the correlation analysis. This behavior can '
             'be overridden by a user specified value from within the 0.0 to 100.0 '
             'range.',
        default=None,
        type=float)

    # mode dependent options; --BED is registered in the required group
    # in both modes
    opt_group.add_argument('--binSize', '-bs', **bin_size_kwargs)
    req_group.add_argument('--BED', **bed_kwargs)

    opt_group.add_argument(
        '--includeZeros',
        help='Genomic regions that have zero values only '
             'are included. The default behavior is to ignore these '
             'regions. ',
        action='store_true',
        required=False)

    # ---- optional output files ----------------------------------------
    out_group = args_parser.add_argument_group('Output optional options')

    out_group.add_argument(
        '--outFileCorMatrix',
        help='Save correlation matrix to file.',
        metavar='FILE',
        type=argparse.FileType('w'))

    out_group.add_argument(
        '--outRawCounts',
        help='Save raw counts (coverages) to file.',
        metavar='FILE',
        type=argparse.FileType('w'))

    out_group.add_argument(
        '--plotFileFormat',
        metavar='FILETYPE',
        help='Image format type. If given, this option '
             'overrides the image format based on the plotFile '
             'ending. The available options are: png, emf, '
             'eps, pdf and svg.',
        choices=['png', 'pdf', 'svg', 'eps', 'emf'])

    return args_parser


def heatmap_options():
    """
    Options for generating the correlation heat map.

    :return: argparse.ArgumentParser without a help option, meant to be
             used as a parent parser. It defines --zMin/-min, --zMax/-max
             (both parsed as floats, None when not given) and --colorMap
             (default 'Reds').
    """
    from matplotlib import cm

    parser = argparse.ArgumentParser(add_help=False)
    heatmap = parser.add_argument_group('Heat map options')

    # the limits are handed over to matplotlib as numbers, hence they are
    # converted (and validated) by argparse
    heatmap.add_argument('--zMin', '-min',
                         default=None,
                         type=float,
                         help='Minimum value for the heat map intensities. '
                              'If not specified the value is set automatically')
    heatmap.add_argument('--zMax', '-max',
                         default=None,
                         type=float,
                         help='Maximum value for the heat map intensities. '
                              'If not specified the value is set automatically')

    # reversed color maps ('*_r') are left out to keep the list short;
    # sorting makes the help text reproducible
    color_options = "', '".join(sorted(name for name in cm.datad
                                       if not name.endswith('_r')))

    heatmap.add_argument(
        '--colorMap', default='Reds',
        help='Color map to use for the heat map. Available values can be '
             'seen here: '
             'http://www.astro.lsa.umich.edu/~msshin/science/code/'
             'matplotlib_cm/ The available options are: \'' +
             color_options + '\'')

    return parser


def plotCorrelation(corr_matrix, labels, plotFileName, vmax=None,
                    vmin=None, colormap='Reds', image_format=None):
    """
    Plot a clustered heat map of a correlation matrix and save it.

    The rows of the matrix are clustered (complete linkage), the resulting
    dendrogram is drawn on the left and both rows and columns of the heat
    map are ordered like the dendrogram leaves. The correlation value is
    printed inside each cell.

    :param corr_matrix: square 2D numpy array
    :param labels: one label per row/column of corr_matrix
    :param plotFileName: passed to Figure.savefig
    :param vmax: upper limit of the color scale (number or numeric
                 string); 1 if None
    :param vmin: lower limit of the color scale (number or numeric
                 string); if None, 0 when the matrix has no negative
                 value, otherwise -1
    :param colormap: name of a matplotlib color map
    :param image_format: passed as 'format' to Figure.savefig
    """
    import scipy.cluster.hierarchy as sch
    from matplotlib.colors import LinearSegmentedColormap

    num_rows = corr_matrix.shape[0]

    # set the minimum and maximum values. Values that come from the command
    # line may still be strings, so they are converted here.
    vmax = 1 if vmax is None else float(vmax)
    if vmin is None:
        vmin = 0 if corr_matrix.min() >= 0 else -1
    else:
        vmin = float(vmin)

    fig = plt.figure(figsize=(10.5, 9.5))
    try:
        # Compute and plot dendrogram. The dendrogram is drawn on the
        # current axes, which is the one that was just added.
        axdendro = fig.add_axes([0.02, 0.1, 0.1, 0.7])
        axdendro.set_axis_off()
        linkage = sch.linkage(corr_matrix, method='complete')
        dendrogram = sch.dendrogram(linkage, orientation='right',
                                    link_color_func=lambda k: 'black')
        axdendro.set_xticks([])
        axdendro.set_yticks([])

        # Make a new color map from the first 90% of the original one.
        # This is done to avoid colors that are too dark at the end of the
        # range and do not offer a good contrast to the correlation
        # numbers that are plotted in black.
        # from_list is called on the class because not every color map
        # object returned by get_cmap provides it.
        base_cmap = plt.get_cmap(colormap)
        cmap = LinearSegmentedColormap.from_list(
            colormap + "clipped", base_cmap(np.linspace(0, 0.9, 10)))

        # Plot the matrix with rows and columns in dendrogram order.
        axmatrix = fig.add_axes([0.13, 0.1, 0.6, 0.7])
        index = dendrogram['leaves']
        corr_matrix = corr_matrix[index, :]
        corr_matrix = corr_matrix[:, index]
        img_mat = axmatrix.matshow(corr_matrix, aspect='equal',
                                   origin='lower', cmap=cmap,
                                   extent=(0, num_rows, 0, num_rows),
                                   vmax=vmax, vmin=vmin)

        # ticks are placed in the middle of each cell
        tick_positions = np.arange(num_rows) + 0.5
        ordered_labels = np.array(labels).astype('str')[index]

        axmatrix.yaxis.tick_right()
        axmatrix.set_yticks(tick_positions)
        axmatrix.set_yticklabels(ordered_labels, fontsize=14)

        axmatrix.set_xticks(tick_positions)
        axmatrix.set_xticklabels(ordered_labels,
                                 fontsize=14,
                                 rotation=45,
                                 ha='left')

        # Plot colorbar.
        axcolor = fig.add_axes([0.13, 0.065, 0.6, 0.02])
        plt.colorbar(img_mat, cax=axcolor, orientation='horizontal')

        # Write the value into each cell: x runs along the columns and
        # y along the rows of the matrix.
        for row in range(num_rows):
            for col in range(num_rows):
                axmatrix.text(col + 0.5, row + 0.5,
                              "{:.2f}".format(corr_matrix[row, col]),
                              ha='center', va='center')

        fig.savefig(plotFileName, format=image_format)
    finally:
        # release the figure, otherwise pyplot keeps it alive
        plt.close(fig)


def _close_output(handle):
    """
    Close an output file handle, leaving the standard streams open
    (argparse.FileType returns sys.stdout for the file name '-').
    """
    if handle is not sys.stdout and handle is not sys.stderr:
        handle.close()


def main(args):
    """
    1. get the scores at different positions, either for bins that
    all have the same length or for the genomic regions from the BED file

    2. optionally remove outlier bins and compute the all vs. all
    correlation

    3. optionally write the raw scores and the correlation matrix and
    plot the heat map

    :param args: namespace as returned by parseArguments()

    The process exits with status 1 on invalid input (less than two
    bigWig files, filter percentile outside of 0-100, unknown color map).
    """
    if len(args.bwfiles) < 2:
        sys.stderr.write(
            "Please input at least two bigWig (.bw) files to compare\n")
        sys.exit(1)

    skip_zeros = not args.includeZeros

    # None means 'not specified'. An explicit value, including 0.0, is
    # always honored.
    if args.filterPercentile is not None and \
            not 0.0 <= args.filterPercentile <= 100.0:
        sys.stderr.write("ERROR! Invalid filter percentile value: {}\n".format(
            args.filterPercentile))
        sys.exit(1)
    # default for pearson is 99.9 percentile filtering
    if args.filterPercentile is None and args.corMethod == "pearson":
        args.filterPercentile = 99.9
        print("Automatic {} percentile filtering enabled before "
              "Pearson correlation.".format(args.filterPercentile))

    # fail early on an unknown color map, before the scores are computed
    if args.colorMap:
        try:
            plt.get_cmap(args.colorMap)
        except ValueError as error:
            sys.stderr.write(
                "A problem was found. Message: {}\n".format(error))
            sys.exit(1)

    bed_regions = getattr(args, 'BED', None)

    num_reads_per_bin = score_bw.getScorePerBin(
        args.bwfiles,
        args.binSize,
        numberOfProcessors=args.numberOfProcessors,
        skipZeros=skip_zeros,
        verbose=args.verbose,
        region=args.region,
        bedFile=bed_regions)

    # num_reads_per_bin: rows correspond to bins, cols to samples
    num_files = len(args.bwfiles)

    if args.outRawCounts:
        args.outRawCounts.write("'" + "'\t'".join(args.labels) + "'\n")
        # NOTE(review): '%d' drops the fractional part of non-integer
        # values -- confirm that integer output is intended here
        fmt = "\t".join(['%d'] * num_reads_per_bin.shape[1]) + "\n"
        for row in num_reads_per_bin:
            args.outRawCounts.write(fmt % tuple(row))
        _close_output(args.outRawCounts)

    # remove outliers, which will spoil pearson correlation.
    # for each data set (cols) the values above the given
    # percentile are identified. A bin (row) is removed from
    # the num_reads_per_bin matrix only if it is an outlier
    # in all data sets
    if args.filterPercentile is not None:
        unfiltered = num_reads_per_bin.shape[0]
        is_outlier = np.ones(unfiltered, dtype=bool)
        for col in num_reads_per_bin.T:
            is_outlier &= col > np.percentile(col, args.filterPercentile)
        if is_outlier.any():
            num_reads_per_bin = num_reads_per_bin[~is_outlier, :]
            if args.verbose:
                kept = num_reads_per_bin.shape[0]
                print("{} percentile filtering (total/filtered/left): "
                      "{}/{}/{}".format(args.filterPercentile, unfiltered,
                                        unfiltered - kept, kept))

    # initialize correlation matrix
    corr_matrix = np.zeros((num_files, num_files), dtype='float')
    correlate = {'spearman': spearmanr,
                 'pearson': pearsonr}[args.corMethod]
    # do an all vs all correlation using the
    # indices of the upper triangle (diagonal included)
    for row, col in zip(*np.triu_indices(num_files)):
        corr_matrix[row, col] = correlate(num_reads_per_bin[:, row],
                                          num_reads_per_bin[:, col])[0]
    # make the matrix symmetric
    corr_matrix = corr_matrix + np.triu(corr_matrix, 1).T

    if args.outFileCorMatrix:
        args.outFileCorMatrix.write("\t'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(['%.4f'] * num_files) + "\n"
        for label, row in zip(args.labels, corr_matrix):
            args.outFileCorMatrix.write(
                "'%s'\t" % label + fmt % tuple(row))
        _close_output(args.outFileCorMatrix)

    # argparse already opened the plot file; only its name is needed
    # because the figure is saved by file name
    plot_file_name = args.plotFile.name
    args.plotFile.close()
    plotCorrelation(corr_matrix,
                    args.labels,
                    plot_file_name,
                    args.zMax,
                    args.zMin,
                    args.colorMap,
                    image_format=args.plotFileFormat)


if __name__ == "__main__":
    # script entry point: parse sys.argv and run the analysis
    main(parseArguments())
