#!/usr/bin/env python
#-*- coding: utf-8 -*-

import argparse
import sys

import numpy as np
from matplotlib import use as mplt_use
mplt_use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

import deeptools.countReadsPerBin as countR
from deeptools import parserCommon
from deeptools._version import __version__


def parseArguments(args=None):
    """Build the bamCorrelate command-line parser and parse *args*.

    Two sub-commands are registered, ``bins`` and ``BED-file``. Each one
    combines the tool-specific options from ``bamCorrelateArgs`` with the
    shared deepTools parent/BAM parsers.

    Post-processing of the parsed namespace:
      * ``extendPairedEnds`` is derived as the negation of
        ``doNotExtendPairedEnds``;
      * ``labels`` defaults to the BAM file names.

    The process exits with status 1 if ``--labels`` is given with a
    different length than ``--bamfiles``. It also exits with status 1 when
    no sub-command is given (Python 3 argparse does not enforce that).

    :param args: list of argument strings; None means ``sys.argv[1:]``.
    :returns: the parsed ``argparse.Namespace``.
    """
    parser = \
        argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="""

bamCorrelate can be run in two modes: bins and BED-file.

In the bins mode, the correlation is computed based on randomly
sampled bins of equal length. The user has to specify how many bins should
be sampled. This is useful to assess the overall similarity of BAM files.

In the BED-file options, the user supplies a list of genomic regions
in BED format in addition to the list of BAM files. bamCorrelate
subsequently uses the BED file to compare the read coverages for these
regions only. This can be used, for example, to compare the ChIP-seq
coverages of two different samples for a set of peak regions.

For detailed help type:

%(prog)s bins -h or
%(prog)s BED-file -h

""",
            # A backslash immediately followed by the newline, so the
            # example works as a shell line continuation when copy-pasted.
            epilog='An example usage is:\n  %(prog)s bins '
            '-b treatment.bam input.bam \\\n'
            '-o correlation.png -f 200 --corMethod pearson\n \n',
            conflict_handler='resolve')

    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))
    subparsers = parser.add_subparsers(
        title="Commands",
        dest='command',
        metavar='')

    parentParser = parserCommon.getParentArgParse(binSize=False)
    bamParser = parserCommon.bam()

    # bins mode options
    subparsers.add_parser(
        'bins',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bamCorrelateArgs(case='bins'),
                 parentParser,
                 bamParser,
                 ],
        help="In the bins mode, the correlation is computed based on randomly chosen "
        " genomic regions of the same size (= bins). The number of bins to be used has to be specified.",
        add_help=False,
        usage='An example usage is:\n  %(prog)s '
        '-b treatment.bam input.bam '
        '-o correlation.png -f 200 --corMethod pearson\n \n')

    # BED file arguments
    subparsers.add_parser(
        'BED-file',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bamCorrelateArgs(case='BED-file'),
                 parentParser,
                 bamParser,
                 ],
        help="In the BED-file options, a list of genomic regions in BED "
        "format has to be given. For each region in the BED file the "
        "number of overlapping reads is counted in each of the BAM "
        "files. Then the correlation is computed.",
        usage='An example usage is:\n %(prog)s '
        '--BED genes.bed -b treatment.bam input.bam '
        '-o correlation.png -f 200 --corMethod pearson\n \n',
        add_help=False)

    args = parser.parse_args(args)

    # Python 3 argparse does not require a sub-command by default. Without
    # one, none of the sub-command options below exist on the namespace.
    if args.command is None:
        parser.print_help()
        sys.exit(1)

    args.extendPairedEnds = not args.doNotExtendPairedEnds

    if args.labels and len(args.bamfiles) != len(args.labels):
        print("The number of labels does not match the number of bam files.")
        sys.exit(1)
    if not args.labels:
        args.labels = args.bamfiles
    return args


######
def bamCorrelateArgs(case='bins'):
    """Return a help-less parent parser with the bamCorrelate options.

    :param case: ``'bins'`` exposes ``--binSize`` / ``--numberOfSamples``
        and hides ``--BED`` (default None). Any other value hides
        ``--binSize`` / ``--numberOfSamples`` and makes ``--BED`` (opened
        for reading) a required argument.
    :returns: an ``argparse.ArgumentParser`` created with
        ``add_help=False``, meant to be used in ``parents=[...]``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    required = parser.add_argument_group('Required arguments')

    # define the arguments
    required.add_argument('--bamfiles', '-b',
                          metavar='list of bam files',
                          help='List of indexed bam files separated by space',
                          nargs='+',
                          required=True)

    required.add_argument('--plotFile', '-o',
                          help='File name to save the file containing the heatmap '
                          'of the correlation. The given file ending will be used '
                          ' to determine the image format, for example: '
                          'correlation.pdf will save the heatmap in pdf format. '
                          'The available options are: .png, .emf, '
                          '.eps, .pdf and .svg.',
                          type=argparse.FileType('w'),
                          metavar='file name',
                          required=True)

    required.add_argument('--corMethod', '-c',
                          help='correlation method to use',
                          choices=['pearson', 'spearman'],
                          required=True)

    optional = parser.add_argument_group('Optional arguments')

    optional.add_argument("--help", "-h", action="help",
                          help="show this help message and exit")
    optional.add_argument('--labels', '-l',
                          metavar='',
                          help='List of labels to use in the image. '
                          'If no labels are given, the file names will be used instead. '
                          'Separate the labels by space, e.g. --labels sample1 sample2 sample3',
                          nargs='+')
    if case == 'bins':
        optional.add_argument('--binSize', '-bs',
                              metavar='',
                              help='Length in base pairs for a window used '
                              'to sample the genome.',
                              default=10000,
                              type=int)

        optional.add_argument('--numberOfSamples', '-n',
                              metavar='',
                              help='Number of random genome places to count reads for '
                              'the correlation computation.',
                              default=int(1e6),
                              type=int)

        # Hidden placeholder so both sub-commands expose args.BED.
        required.add_argument('--BED',
                              help=argparse.SUPPRESS,
                              default=None)
    else:
        # Hidden in BED-file mode: the BED regions replace the sampled bins.
        optional.add_argument('--binSize', '-bs',
                              help=argparse.SUPPRESS,
                              default=10000,
                              type=int)

        optional.add_argument('--numberOfSamples', '-n',
                              help=argparse.SUPPRESS,
                              default=int(1e5),
                              type=int)

        required.add_argument('--BED',
                              help='*If* the correlation of read counts should be limited to '
                              'certain regions, a BED file can be given. If this '
                              'is the case, then the correlation is computed for '
                              'the number of reads that overlap such regions and '
                              'the values set for binSize and numberOfSamples are '
                              'ignored.',
                              metavar='FILE',
                              type=argparse.FileType('r'),
                              required=True)

    optional.add_argument('--includeZeros',
                          help='If set, then zero counts that happen for '
                          '*all* bam files given are included. The default '
                          'behavior is to ignore those cases ',
                          action='store_true',
                          required=False)

    heatmap = parser.add_argument_group('Heatmap options')

    # type=float: these values are passed straight to matplotlib as the
    # colour-scale limits, so they must be numbers, not strings.
    heatmap.add_argument('--zMin', '-min',
                         metavar='',
                         default=None,
                         type=float,
                         help='Minimum value for the heatmap intensities. '
                         'If not specified the value is set automatically')
    heatmap.add_argument('--zMax', '-max',
                         metavar='',
                         default=None,
                         type=float,
                         help='Maximum value for the heatmap intensities. '
                         'If not specified the value is set automatically')

    from matplotlib import cm
    # Reversed ('_r') variants are omitted from the list; sorted so the
    # help text is stable between runs.
    color_options = "', '".join(sorted(m for m in cm.datad
                                       if not m.endswith('_r')))

    heatmap.add_argument(
        '--colorMap', default='Reds',
        metavar='',
        help='Color map to use for the heatmap. Available values can be '
        'seen here: '
        'http://www.astro.lsa.umich.edu/~msshin/science/code/'
        'matplotlib_cm/ The available options are: \'' +
        color_options + '\'')

    group = parser.add_argument_group('Output optional options')

    group.add_argument('--outFileCorMatrix',
                       help='Output file name for the correlation matrix.',
                       metavar='',
                       type=argparse.FileType('w'))

    group.add_argument('--outRawCounts',
                       help='Output file name to save the bin counts',
                       metavar='',
                       type=argparse.FileType('w'))

    group.add_argument('--plotFileFormat',
                       metavar='',
                       help='image format type. If given, this option overrides the '
                       'image format based on the plotFile ending. '
                       'The available options are: png, emf, '
                       'eps, pdf and svg.',
                       choices=['png', 'pdf', 'svg', 'eps', 'emf'])
    return parser

def plotCorrelation(corr_matrix, labels, plotFileName, vmax=None,
                    vmin=None, colormap='Reds', image_format=None):
    """Save a clustered heatmap of a square correlation matrix.

    Rows and columns are reordered by the leaf order of a complete-linkage
    dendrogram. The dendrogram is computed with ``scipy.cluster.hierarchy``
    on the rows of *corr_matrix* and drawn on the left. Every heatmap cell
    is annotated with its value to two decimals.

    :param corr_matrix: square (M x M) numpy array of correlation values.
    :param labels: M labels, in the same order as the rows/columns of
        *corr_matrix*.
    :param plotFileName: destination passed to ``Figure.savefig``.
    :param vmax: upper colour-scale limit; defaults to 1.
    :param vmin: lower colour-scale limit. Defaults to 0 when all values
        are non-negative, otherwise to -1.
    :param colormap: name of a matplotlib colour map.
    :param image_format: explicit ``savefig`` format. None lets matplotlib
        infer it from the file name.
    """
    import scipy.cluster.hierarchy as sch
    from matplotlib.colors import LinearSegmentedColormap

    M = corr_matrix.shape[0]

    # set the minimum and maximum values
    if vmax is None:
        vmax = 1
    if vmin is None:
        vmin = 0 if corr_matrix.min() >= 0 else -1

    fig = plt.figure(figsize=(10.5, 9.5))

    # Dendrogram on the left. sch.dendrogram draws into the current axes,
    # so it is called right after axdendro is created.
    axdendro = fig.add_axes([0.02, 0.1, 0.1, 0.7])
    axdendro.set_axis_off()
    Y = sch.linkage(corr_matrix, method='complete')
    Z = sch.dendrogram(Y, orientation='right',
                       link_color_func=lambda k: 'black')
    axdendro.set_xticks([])
    axdendro.set_yticks([])

    # Use only the 0.0-0.8 part of the requested colour map. The darkest
    # colours at the top of the range give poor contrast with the numbers
    # printed in black on each cell. The range is sampled densely so that
    # multi-hue maps keep their intermediate colours; two end points would
    # collapse them into a plain two-colour ramp. from_list is called on
    # LinearSegmentedColormap directly because ListedColormap instances
    # (e.g. 'viridis') do not provide it.
    base_cmap = plt.get_cmap(colormap)
    cmap = LinearSegmentedColormap.from_list(
        colormap + "clipped", base_cmap(np.linspace(0.0, 0.8, 256)))

    # Reorder rows, columns and labels by the dendrogram leaf order.
    index = Z['leaves']
    corr_matrix = corr_matrix[index, :][:, index]
    ordered_labels = np.array(labels).astype('str')[index]

    # Plot the correlation matrix. With origin='lower' and this extent,
    # cell (row, col) covers x in [col, col+1] and y in [row, row+1].
    axmatrix = fig.add_axes([0.13, 0.1, 0.6, 0.7])
    im = axmatrix.matshow(corr_matrix, aspect='equal', origin='lower',
                          cmap=cmap, extent=(0, M, 0, M),
                          vmax=vmax, vmin=vmin)
    tick_positions = np.arange(M) + 0.5
    axmatrix.yaxis.tick_right()
    axmatrix.set_yticks(tick_positions)
    axmatrix.set_yticklabels(ordered_labels, fontsize=14)
    axmatrix.set_xticks(tick_positions)
    axmatrix.set_xticklabels(ordered_labels, fontsize=14,
                             rotation=45, ha='left')

    # Colour bar below the matrix.
    axcolor = fig.add_axes([0.13, 0.065, 0.6, 0.02])
    fig.colorbar(im, cax=axcolor, orientation='horizontal')

    for row in range(M):
        for col in range(M):
            # text() takes (x, y): x follows the column, y the row.
            axmatrix.text(col + 0.5, row + 0.5,
                          "{:.2f}".format(corr_matrix[row, col]),
                          ha='center', va='center')

    fig.savefig(plotFileName, format=image_format)
    # Release the figure; pyplot otherwise keeps a reference to it.
    plt.close(fig)

def main(args):
    """Compute and plot pairwise read-count correlations between BAM files.

    1. Get read counts, either for randomly sampled bins of equal length
       or for the regions of the BED file, via
       ``countReadsPerBin.getNumReadsPerBin``.
    2. Optionally write the raw counts (``--outRawCounts``).
    3. Drop rows whose count is above the 99.9th percentile in every
       sample.
    4. Compute the all-vs-all Pearson or Spearman correlation.
    5. Optionally write the correlation matrix (``--outFileCorMatrix``).
       Then save the heatmap with ``plotCorrelation``.

    Exits with status 1 if fewer than two BAM files are given or the
    requested colour map cannot be resolved.

    :param args: namespace produced by ``parseArguments``.
    """
    if len(args.bamfiles) < 2:
        print("Please input at least two bam files to compare")
        sys.exit(1)

    # Per the --includeZeros help: zero counts shared by all BAM files are
    # ignored unless the flag is set.
    skipZeros = not args.includeZeros

    if args.colorMap:
        # Validate the name up front, before the expensive counting step.
        try:
            plt.get_cmap(args.colorMap)
        except ValueError as e:
            print(e)
            sys.exit(1)

    BED_regions = getattr(args, 'BED', None)

    num_reads_per_bin = countR.getNumReadsPerBin(args.bamfiles,
                                                 args.binSize, args.numberOfSamples,
                                                 args.fragmentLength,
                                                 numberOfProcessors=args.numberOfProcessors,
                                                 skipZeros=skipZeros,
                                                 verbose=args.verbose,
                                                 region=args.region,
                                                 bedFile=BED_regions)

    if args.outRawCounts:
        args.outRawCounts.write("'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(np.repeat('%d', num_reads_per_bin.shape[1])) + "\n"
        for row in num_reads_per_bin:
            args.outRawCounts.write(fmt % tuple(row))
        args.outRawCounts.close()

    # Remove outliers, which would spoil the Pearson correlation. A row is
    # dropped only when its value is above its column's 99.9th percentile
    # in *every* column (sample) at the same time.
    thresholds = np.percentile(num_reads_per_bin, 99.9, axis=0)
    outlier_rows = np.all(num_reads_per_bin > thresholds, axis=1)
    if outlier_rows.any():
        num_reads_per_bin = num_reads_per_bin[~outlier_rows, :]

    # num_reads_per_bin: rows correspond to bins, cols to samples
    M = len(args.bamfiles)
    corr_matrix = np.zeros((M, M), dtype='float')
    corr_func = {'spearman': spearmanr,
                 'pearson': pearsonr}[args.corMethod]
    # Fill the upper triangle (diagonal included), then mirror it.
    for row, col in zip(*np.triu_indices(M)):
        corr_matrix[row, col] = corr_func(num_reads_per_bin[:, row],
                                          num_reads_per_bin[:, col])[0]
    corr_matrix = corr_matrix + np.triu(corr_matrix, 1).T

    if args.outFileCorMatrix:
        args.outFileCorMatrix.write("\t'" + "'\t'".join(args.labels) + "'\n")
        fmt = "\t".join(np.repeat('%.4f', num_reads_per_bin.shape[1])) + "\n"
        for label, row in zip(args.labels, corr_matrix):
            args.outFileCorMatrix.write("'%s'\t" % label + fmt % tuple(row))
        args.outFileCorMatrix.close()

    # argparse opened --plotFile for writing; matplotlib needs only the
    # name, so close the handle before saving.
    plotFileName = args.plotFile.name
    args.plotFile.close()
    plotCorrelation(corr_matrix, args.labels, plotFileName, args.zMax, args.zMin,
                    args.colorMap, image_format=args.plotFileFormat)

if __name__ == "__main__":
    # Script entry point: parse sys.argv and run the correlation.
    main(parseArguments())
