#!/usr/bin/env python
#-*- coding: utf-8 -*-

import os
import random
import sys
import argparse
import numpy as np
from matplotlib import use as mplt_use
mplt_use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

import deeptools.countReadsPerBin as countR
from deeptools import parserCommon
from deeptools._version import __version__


def parseArguments(args=None):
    """
    Build the bamCorrelate command line parser and parse ``args``.

    Two sub-commands are defined, 'bins' and 'BED-file'. Each one combines
    the tool specific options returned by bamCorrelateArgs() with the
    parsers returned by parserCommon.getParentArgParse(),
    parserCommon.read_options() and heatmap_options().

    args: list of command line tokens. None makes argparse read sys.argv.

    Returns the parsed argparse namespace with two derived attributes:
    ``extendPairedEnds`` (the negation of ``doNotExtendPairedEnds``, which
    is presumably defined by one of the parserCommon parsers) and
    ``labels`` (defaults to the base names of the given BAM files).

    The program terminates through parser.error() (exit status 2) when no
    sub-command is given or when the number of labels differs from the
    number of BAM files.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""

bamCorrelate computes the overall similarity between two or more BAM
files based on read coverage (number of reads) within genomic regions.
The correlation analysis is performed for the entire genome by running
the program in 'bins' mode, or for certain user selected regions in 'BED-file'
mode. Pearson or Spearman methods are available to compute correlation
coefficients. Results are saved into a heat map image and further output files
are optional. NOTE: bamCorrelate usually takes a long time to finish, thus it
is advisable to run the tool for a tiny region (using the --region option)
to adjust plotting parameters like colors and labels before running the
whole computation.

detailed help:

  %(prog)s bins -h
  %(prog)s BED-file -h

""",
        epilog='example usages:\n%(prog)s bins '
               '-b file1.bam file2.bam -o heatmap.png --corMethod pearson\n\n'
               '%(prog)s BED-file -b file1.bam file2.bam -o heatmap.png '
               '--corMethod spearman\n'
               '--BED selection.bed '
               ' \n\n',
        conflict_handler='resolve')

    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))
    subparsers = parser.add_subparsers(
        title="commands",
        dest='command',
        metavar='')

    parent_parser = parserCommon.getParentArgParse(binSize=False)
    read_options_parser = parserCommon.read_options()
    heatmap_parser = heatmap_options()

    # bins mode options
    subparsers.add_parser(
        'bins',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bamCorrelateArgs(case='bins'),
                 parent_parser, read_options_parser,
                 heatmap_parser
                 ],
        help="The correlation is based on read coverage over "
             "consecutive bins of equal "
             "size (10k bp by default). This mode is useful to assess the "
             "overall similarity of BAM files. The bin size and the "
             "distance between bins can be adjusted.",
        add_help=False,
        usage='%(prog)s '
              '-b file1.bam file2.bam '
              '-o heatmap.png --corMethod spearman\n')

    # BED file arguments
    subparsers.add_parser(
        'BED-file',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bamCorrelateArgs(case='BED-file'),
                 parent_parser, read_options_parser,
                 heatmap_parser
                 ],
        help="The user provides a BED file that contains all regions "
             "that should be considered for the correlation analysis. A "
             "common use is to compare ChIP-seq coverages between two "
             "different samples for a set of peak regions.",
        usage='%(prog)s '
              '-b file1.bam file2.bam '
              '-o heatmap.png --corMethod pearson --BED selection.bed\n',
        add_help=False)

    args = parser.parse_args(args)

    # Python 3's argparse does not require a sub-command by default; without
    # one none of the options read below would be defined.
    if not getattr(args, 'command', None):
        parser.error("too few arguments: one of the commands "
                     "'bins' or 'BED-file' is required.")

    args.extendPairedEnds = not args.doNotExtendPairedEnds

    if args.labels and len(args.bamfiles) != len(args.labels):
        # a usage error: report it on stderr with a non-zero exit status
        parser.error("The number of labels does not match "
                     "the number of bam files.")
    if not args.labels:
        # a real list (not a lazy map object) so it can be indexed and
        # iterated more than once
        args.labels = [os.path.basename(path) for path in args.bamfiles]

    return args


def bamCorrelateArgs(case='bins'):
    """
    Return an argparse parser (without its own -h handling) holding the
    bamCorrelate specific options.

    case: 'bins' exposes --binSize, --distanceBetweenBins and
          --doNotRemoveOutliers and hides --BED; any other value hides the
          two bin options, omits --doNotRemoveOutliers and makes --BED a
          required, readable file.

    The returned parser is meant to be used as a parent parser.
    """
    bins_mode = (case == 'bins')

    parser = argparse.ArgumentParser(add_help=False)

    # ---- required arguments -------------------------------------------
    required = parser.add_argument_group('Required arguments')

    required.add_argument(
        '--bamfiles', '-b',
        nargs='+',
        required=True,
        metavar='FILE1 FILE2',
        help='List of indexed bam files separated by spaces.')

    plot_file_help = ('File name to save the file containing the heatmap '
                      'of the correlation. The given file ending will be used '
                      ' to determine the image format, for example: '
                      'heatmap.pdf will save the heatmap in PDF format. '
                      'The available options are: .png, .emf, '
                      '.eps, .pdf and .svg.')
    required.add_argument(
        '--plotFile', '-o',
        required=True,
        metavar='FILE',
        type=argparse.FileType('w'),
        help=plot_file_help)

    required.add_argument(
        '--corMethod', '-c',
        required=True,
        choices=['spearman', 'pearson'],
        help="Correlation method.")

    # ---- optional arguments -------------------------------------------
    optional = parser.add_argument_group('Optional arguments')

    optional.add_argument(
        "--help", "-h",
        action="help",
        help="show this help message and exit")

    labels_help = ('User defined labels instead of default labels from '
                   'file names. '
                   'Multiple labels have to be separated by space, e.g. '
                   '--labels sample1 sample2 sample3')
    optional.add_argument(
        '--labels', '-l',
        nargs='+',
        metavar='sample1 sample2',
        help=labels_help)

    # The two bin options are accepted in both modes (same defaults); they
    # are only documented in 'bins' mode.
    if bins_mode:
        bin_size_help = ('Length in base pairs for a window used '
                         'to sample the genome.')
        bin_size_metavar = 'INT'
        distance_help = ('By default, bamCorrelate considers consecutive '
                         'bins of the specified --binSize. However, to '
                         'reduce the computation time, a larger distance '
                         'between bins can by given. Larger distances '
                         'result in less bins being considered.')
    else:
        bin_size_help = argparse.SUPPRESS
        bin_size_metavar = None
        distance_help = argparse.SUPPRESS

    optional.add_argument(
        '--binSize', '-bs',
        type=int,
        default=10000,
        metavar=bin_size_metavar,
        help=bin_size_help)

    optional.add_argument(
        '--distanceBetweenBins', '-n',
        type=int,
        default=0,
        metavar='INT',
        help=distance_help)

    if bins_mode:
        outliers_help = (
            'By default, bins with very large counts are removed. '
            'By setting this option outliers will not be removed. '
            'Bins with abnormally high reads counts artificially increase '
            'pearson correlation; that\'s why, by default, bamCorrelate tries '
            'to remove outliers using the median absolute deviation (MAD) '
            'method applying a threshold of 200 to only consider extremely '
            'large deviations from the median. ENCODE blacklist page '
            '(https://sites.google.com/site/anshulkundaje/projects/blacklists) '
            'contains useful information about regions with unusually high counts.')
        optional.add_argument(
            '--doNotRemoveOutliers',
            action='store_true',
            help=outliers_help)

        # hidden in this mode; keeps the BED attribute defined (None)
        required.add_argument(
            '--BED',
            default=None,
            help=argparse.SUPPRESS)
    else:
        required.add_argument(
            '--BED',
            required=True,
            metavar='bedfile',
            type=argparse.FileType('r'),
            help='Limits the correlation analysis to '
                 'the regions specified in this file.')

    optional.add_argument(
        '--includeZeros',
        action='store_true',
        required=False,
        help='Genomic regions that have zero values only '
             'are included. The default behavior is to '
             'ignore these regions.')

    # ---- optional outputs ---------------------------------------------
    output = parser.add_argument_group('Output optional options')

    output.add_argument(
        '--outFileCorMatrix',
        metavar='FILE',
        type=argparse.FileType('w'),
        help='Save correlation matrix to file.')

    output.add_argument(
        '--outRawCounts',
        metavar='FILE',
        type=argparse.FileType('w'),
        help='Save raw counts (coverages) to file.')

    output.add_argument(
        '--plotFileFormat',
        metavar='FILETYPE',
        choices=['png', 'pdf', 'svg', 'eps', 'emf'],
        help='Image format type. If given, this option '
             'overrides the image format based on the plotFile '
             'ending. The available options are: png, emf, '
             'eps, pdf and svg.')

    return parser


def heatmap_options():
    """
    Options for generating the correlation heat map

    Returns an argparse parser (add_help=False, intended to be used as a
    parent parser) defining --zMin, --zMax and --colorMap.

    --zMin and --zMax are converted to float by argparse so that numeric
    limits, rather than strings, are handed to the plotting code.
    """
    from matplotlib import cm

    parser = argparse.ArgumentParser(add_help=False)
    heatmap = parser.add_argument_group('Heat map options')

    heatmap.add_argument('--zMin', '-min',
                         default=None,
                         type=float,
                         help='Minimum value for the heat map intensities. '
                              'If not specified the value is set automatically')
    heatmap.add_argument('--zMax', '-max',
                         default=None,
                         type=float,
                         help='Maximum value for the heat map intensities. '
                              'If not specified the value is set automatically')

    # list the color maps known to matplotlib, skipping the reversed
    # ('_r') variants; sorted so that the help text is reproducible
    color_options = "', '".join(sorted(name for name in cm.datad
                                       if not name.endswith('_r')))

    heatmap.add_argument(
        '--colorMap', default='Reds',
        metavar='',
        help='Color map to use for the heatmap. Available values can be '
             'seen here: '
             'http://www.astro.lsa.umich.edu/~msshin/science/code/'
             'matplotlib_cm/ The available options are: \'' +
             color_options + '\'')

    return parser


def getOutlierIndices(data, max_deviation=200):
    """
    Return the indices of the outliers found in ``data``.

    The method is based on the median absolute deviation. See
    Boris Iglewicz and David Hoaglin (1993),
    "Volume 16: How to Detect and Handle Outliers",
    The ASQC Basic References in Quality Control:
    Statistical Techniques, Edward F. Mykytka, Ph.D., Editor.

    data: sequence or numpy array of numbers.
    max_deviation: a value is reported as an outlier when its absolute
        distance to the median, divided by the scaled median absolute
        deviation (MAD), is larger than this threshold.

    The max_deviation=200 is like selecting a z-score
    larger than 200, just that it is based on the median
    and the median absolute deviation instead of the
    mean and the standard deviation.

    Returns a numpy array with the (flattened) positions of the outliers.
    The array is empty when the input is empty or when the MAD is not
    positive (e.g. more than half of the values are identical), because no
    deviation score can be computed in those cases.

    >>> list(getOutlierIndices([1, 2, 3, 2, 1, 2, 100000]))
    [6]
    """
    data = np.asarray(data)
    no_outliers = np.array([], dtype=np.intp)
    if data.size == 0:
        return no_outliers

    median = np.median(data)
    b_value = 1.4826  # value set for a normal distribution
    abs_deviation = np.abs(data - median)
    mad = b_value * np.median(abs_deviation)
    if mad > 0:
        return np.flatnonzero(abs_deviation / mad > max_deviation)
    return no_outliers

def plotCorrelation(corr_matrix, labels, plotFileName, vmax=None,
                    vmin=None, colormap='Reds', image_format=None):
    """
    Plot a hierarchically clustered heat map of a correlation matrix and
    save it to a file.

    corr_matrix: square numpy array with the pairwise correlations.
    labels: one label per row/column of ``corr_matrix``.
    plotFileName: destination passed to Figure.savefig().
    vmax, vmin: limits of the color scale (numbers or numeric strings).
        Defaults: vmax is 1; vmin is 0 if the matrix has no negative
        value, otherwise -1.
    colormap: name of a matplotlib color map.
    image_format: passed as ``format`` to savefig(); None lets matplotlib
        decide based on the file name.

    Rows and columns are reordered following a complete-linkage
    dendrogram that is drawn to the left of the matrix, and every cell is
    annotated with its value. The function updates the global matplotlib
    'font.size' rcParam and returns None.
    """
    import scipy.cluster.hierarchy as sch
    from matplotlib import rcParams
    from matplotlib.colors import LinearSegmentedColormap

    num_rows = corr_matrix.shape[0]

    # set a font size according to figure length
    if num_rows < 6:
        font_size = 14
    elif num_rows > 40:
        font_size = 5
    else:
        font_size = int(15 - 0.25 * num_rows)
    rcParams.update({'font.size': font_size})

    # set the minimum and maximum values. Values coming straight from the
    # command line may be strings, thus they are converted to float.
    vmax = 1 if vmax is None else float(vmax)
    if vmin is None:
        vmin = 0 if corr_matrix.min() >= 0 else -1
    else:
        vmin = float(vmin)

    # Compute and plot dendrogram (drawn on the axes that were added last).
    fig = plt.figure(figsize=(10.5, 9.5))
    axdendro = fig.add_axes([0.02, 0.12, 0.1, 0.66])
    axdendro.set_axis_off()
    linkage = sch.linkage(corr_matrix, method='complete')
    dendrogram = sch.dendrogram(linkage, orientation='right',
                                link_color_func=lambda k: 'black')
    axdendro.set_xticks([])
    axdendro.set_yticks([])

    # Make a new cmap, based on the original colormap, that goes
    # from 0.0 to 0.9. This is done to avoid colors that
    # are too dark at the end of the range that do not offer
    # a good contrast between the correlation numbers that are
    # plotted on black.
    # from_list is called on the class: not every color map instance
    # (e.g. a ListedColormap) offers that method.
    base_cmap = plt.get_cmap(colormap)
    cmap = LinearSegmentedColormap.from_list(
        colormap + "clipped", base_cmap(np.linspace(0, 0.9, 10)))

    # Plot distance matrix, reordered as the leaves of the dendrogram.
    axmatrix = fig.add_axes([0.13, 0.1, 0.6, 0.7])
    index = dendrogram['leaves']
    corr_matrix = corr_matrix[index, :]
    corr_matrix = corr_matrix[:, index]
    img_mat = axmatrix.matshow(corr_matrix, aspect='equal', origin='lower',
                               cmap=cmap, extent=(0, num_rows, 0, num_rows),
                               vmax=vmax, vmin=vmin)

    tick_positions = np.arange(num_rows) + 0.5
    tick_labels = np.array(list(labels)).astype('str')[index]

    axmatrix.yaxis.tick_right()
    axmatrix.set_yticks(tick_positions)
    axmatrix.set_yticklabels(tick_labels)

    axmatrix.set_xticks(tick_positions)
    axmatrix.set_xticklabels(tick_labels, rotation=45, ha='left')

    # hide the tick marks but keep the labels; booleans are used because
    # the 'on'/'off' strings are not honoured by current matplotlib
    axmatrix.tick_params(axis='x', which='both', bottom=False, top=False)
    axmatrix.tick_params(axis='y', which='both', left=False, right=False)

    # Plot colorbar.
    axcolor = fig.add_axes([0.13, 0.065, 0.6, 0.02])
    plt.colorbar(img_mat, cax=axcolor, orientation='horizontal')

    # Write the value inside each cell. matshow draws corr_matrix[row, col]
    # at x=col, y=row, so the column is the x coordinate of the text.
    for row in range(num_rows):
        for col in range(num_rows):
            axmatrix.text(col + 0.5, row + 0.5,
                          "{:.2f}".format(corr_matrix[row, col]),
                          ha='center', va='center')

    fig.savefig(plotFileName, format=image_format)
    # release the figure, pyplot keeps a reference to it otherwise
    plt.close(fig)


def main(args):
    """
    1. get read counts at different positions either
    all of same length or from genomic regions from the BED file

    2. compute  correlation

    3. optionally save the raw counts and the correlation matrix and
    plot the heat map of the correlation matrix

    args: namespace as returned by parseArguments(). The attributes read
    here are: bamfiles, labels, includeZeros, colorMap, BED, binSize,
    distanceBetweenBins, fragmentLength, numberOfProcessors, verbose,
    region, extendPairedEnds, minMappingQuality, ignoreDuplicates,
    command, doNotRemoveOutliers, corMethod, outRawCounts,
    outFileCorMatrix, plotFile, plotFileFormat, zMax and zMin.

    The program exits with a non-zero status if fewer than two BAM files
    are given, if the color map is unknown or if fewer than two bins with
    counts are found.
    """
    if len(args.bamfiles) < 2:
        sys.exit("Please input at least two bam files to compare")

    skip_zeros = not args.includeZeros

    if args.colorMap:
        # fail early, before the time consuming counting step
        try:
            plt.get_cmap(args.colorMap)
        except ValueError as error:
            sys.exit("A problem was found. Message: {}".format(error))

    bed_regions = getattr(args, 'BED', None)

    stepsize = args.binSize + args.distanceBetweenBins
    num_reads_per_bin = countR.getNumReadsPerBin(
        args.bamfiles,
        args.binSize,
        0,
        args.fragmentLength,
        numberOfProcessors=args.numberOfProcessors,
        skipZeros=skip_zeros,
        verbose=args.verbose,
        region=args.region,
        bedFile=bed_regions,
        extendPairedEnds=args.extendPairedEnds,
        minMappingQuality=args.minMappingQuality,
        ignoreDuplicates=args.ignoreDuplicates,
        stepSize=stepsize)

    sys.stderr.write("Number of non zero bins "
                     "found: {}\n".format(num_reads_per_bin.shape[0]))

    if num_reads_per_bin.shape[0] < 2:
        sys.exit("ERROR: too few non zero bins found.\n"
                 "If using --region please check that this "
                 "region is covered by reads.\n")

    # num_reads_per_bin: rows correspond to  bins, cols to  samples
    num_files = len(args.bamfiles)

    if args.outRawCounts:
        args.outRawCounts.write("'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(['%d'] * num_reads_per_bin.shape[1]) + "\n"
        for counts in num_reads_per_bin:
            args.outRawCounts.write(fmt % tuple(counts))
        args.outRawCounts.close()

    # remove outliers, otherwise outliers will produce a very
    # high pearson correlation. Only for the bins option.
    # The chosen sub-command is stored in args.command; testing
    # "'bins' in args" would look for an attribute called 'bins',
    # which does not exist, and the filter would never run.
    if getattr(args, 'command', None) == 'bins' and \
            not getattr(args, 'doNotRemoveOutliers', False):
        unfiltered = num_reads_per_bin.shape[0]
        to_remove = None
        for sample_counts in num_reads_per_bin.T:
            # get the outliers per column using the median absolute
            # deviation method
            outliers = set(getOutlierIndices(sample_counts))
            # only set to remove those bins in which
            # the outliers are present in all cases (columns)
            # that's why the intersection is used
            if to_remove is None:
                to_remove = outliers
            else:
                to_remove &= outliers
            if not to_remove:
                # the intersection can not grow again
                break
        if to_remove:
            keep = np.ones(unfiltered, dtype=bool)
            keep[sorted(to_remove)] = False
            num_reads_per_bin = num_reads_per_bin[keep, :]
            if args.verbose:
                sys.stderr.write(
                    "total/filtered/left: "
                    "{}/{}/{}\n".format(unfiltered,
                                        len(to_remove),
                                        num_reads_per_bin.shape[0]))

    # initialize correlation matrix
    corr_matrix = np.zeros((num_files, num_files), dtype='float')
    correlation = {'spearman': spearmanr,
                   'pearson': pearsonr}[args.corMethod]
    # do an all vs all correlation using the
    # indices of the upper triangle (diagonal included)
    for row, col in zip(*np.triu_indices(num_files)):
        corr_matrix[row, col] = correlation(
            num_reads_per_bin[:, row],
            num_reads_per_bin[:, col])[0]
    # make the matrix symmetric
    corr_matrix = corr_matrix + np.triu(corr_matrix, 1).T

    if args.outFileCorMatrix:
        args.outFileCorMatrix.write("\t'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(['%.4f'] * num_files) + "\n"
        for label, values in zip(args.labels, corr_matrix):
            args.outFileCorMatrix.write(
                "'%s'\t" % label + fmt % tuple(values))
        args.outFileCorMatrix.close()

    # argparse already opened the plot file; only its name is needed
    # because the figure is written by matplotlib
    plot_file_name = args.plotFile.name
    args.plotFile.close()
    plotCorrelation(corr_matrix,
                    args.labels,
                    plot_file_name,
                    args.zMax,
                    args.zMin,
                    args.colorMap,
                    image_format=args.plotFileFormat)

if __name__ == "__main__":
    # Script entry point: parse the command line (sys.argv) and run the
    # correlation analysis with the resulting options.
    main(parseArguments())
