#!/usr/bin/env python
#-*- coding: utf-8 -*-

import argparse  # to parse command line arguments
import numpy as np

# my packages
from deeptools import writeBedGraph
from deeptools.SES_scaleFactor import estimateScaleFactor
from deeptools import parserCommon
from deeptools import bamHandler
from deeptools.getRatio import getRatio

# Module-level debug flag, initialised to 0 (off); nothing in the code
# visible in this file reads it.
debug = 0


def parseArguments(args=None):
    """
    Assemble the full command line parser, parse `args` and tidy up the
    resulting namespace.

    The parser is composed from the required and optional argument
    parsers defined in this module plus the shared parsers provided by
    `parserCommon` (parent, bam and output parsers).

    :param args: list of command line tokens; when None, argparse falls
                 back to `sys.argv[1:]`.
    :return: the argparse namespace with these adjustments applied:

        * `extendPairedEnds` is added as the negation of
          `doNotExtendPairedEnds`
        * `missingDataAsZero` is turned from "yes"/"no" into a boolean
        * `ignoreForNormalization` becomes a list of stripped chromosome
          names (empty list when the option was not given)
        * `smoothLength` is reset to None, with a warning, when it is
          not larger than `binSize`
    """
    # the shared parsers are created in the same order as always, only
    # their position in the `parents` list differs
    commonParent = parserCommon.getParentArgParse()
    requiredParser = getRequiredArgs()
    optionalParser = getOptionalArgs()
    commonBam = parserCommon.bam()
    commonOutput = parserCommon.output()

    toolDescription = (
        'This tool compares two BAM files based on the number of '
        'mapped reads. To compare the BAM files, the genome is partitioned '
        'into bins of equal size, then the number of reads found in each BAM '
        'file is counted for such bins and finally a summarizing value is '
        'reported. This value can be the ratio of the number of reads per '
        'bin, the log2 of the ratio or the difference. This tool can '
        'normalize the number of reads on each BAM file using the SES method '
        'proposed by Diaz et al. (2012). "Normalization, bias correction, and '
        'peak calling for ChIP-seq". Statistical applications in genetics '
        'and molecular biology, 11(3). Normalization based on read counts '
        'is also available. The output is either a bedgraph or a bigwig file '
        'containing the bin location and the resulting comparison values. By '
        'default, if reads are mated, the fragment length reported in the BAM '
        'file is used. In the case of paired-end mapping each read mate '
        'is treated independently to avoid a bias when a mixture of concordant '
        'and discordant pairs is present. This means that *each end* will '
        'be extended to match the fragment length.')

    usageExample = ('An example usage is:\n %(prog)s '
                    '-b1 treatment.bam -b2 control.bam -o log2ratio.bw')

    parser = argparse.ArgumentParser(
        parents=[requiredParser, commonOutput, optionalParser, commonParent,
                 commonBam],
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=toolDescription,
        usage=usageExample,
        add_help=False)

    parsed = parser.parse_args(args)

    # translate the command line switches into the values used downstream
    parsed.extendPairedEnds = not parsed.doNotExtendPairedEnds
    parsed.missingDataAsZero = parsed.missingDataAsZero == 'yes'

    # "chrX, chrM" -> ['chrX', 'chrM']
    chromNames = parsed.ignoreForNormalization
    parsed.ignoreForNormalization = \
        [name.strip() for name in chromNames.split(',')] if chromNames else []

    # smoothing over a window that does not exceed one bin is pointless
    if parsed.smoothLength and parsed.smoothLength <= parsed.binSize:
        print("Warning: the smooth length given ({}) is smaller than the bin "
              "size ({}).\n\n No smoothing will be "
              "done".format(parsed.smoothLength,
                            parsed.binSize))
        parsed.smoothLength = None

    return parsed


def getRequiredArgs():
    """
    Create the parser that holds the mandatory options of the tool.

    Two options are registered in the 'Required arguments' group, one
    for each of the BAM files that are compared (--bamfile1/-b1 and
    --bamfile2/-b2).

    :return: an argparse.ArgumentParser (without its own help option)
             meant to be used as a parent parser.
    """
    parser = argparse.ArgumentParser(add_help=False)
    group = parser.add_argument_group('Required arguments')

    # (long flag, short flag, help text) for each mandatory BAM file
    bamFileOptions = (
        ('--bamfile1', '-b1',
         'Sorted BAM file 1. Usually the BAM file '
         'for the treatment.'),
        ('--bamfile2', '-b2',
         'Sorted BAM file 2. Usually the BAM '
         'file for the control.'),
    )

    for longFlag, shortFlag, helpText in bamFileOptions:
        group.add_argument(longFlag, shortFlag,
                           metavar='BAM file',
                           help=helpText,
                           required=True)

    return parser


def getOptionalArgs():
    """
    Create the parser that holds the optional arguments specific to
    this tool (scaling method, ratio type, normalization, etc.).

    The parser is created with add_help=False so that it can be used as
    a parent parser; the --help/-h option is registered explicitly so
    that it is listed under the 'Optional arguments' group.

    :return: an argparse.ArgumentParser containing the optional arguments.
    """

    parser = argparse.ArgumentParser(add_help=False)
    optional = parser.add_argument_group('Optional arguments')

    optional.add_argument("--help", "-h", action="help",
                          help="show this help message and exit")

    optional.add_argument('--bamIndex1', '-bai1',
                          help='Index for the bam file1. Default is to '
                          'consider the path of the bam file adding the .bai '
                          'suffix.',
                          metavar='bam file index')

    optional.add_argument('--bamIndex2', '-bai2',
                          help='Index for the bam file2. Default is to '
                          'consider the path of the bam file adding the .bai '
                          'suffix.',
                          metavar='bam file index')

    optional.add_argument('--scaleFactorsMethod',
                          help='Method to use to scale the samples. '
                          'Default "readCount".',
                          choices=['readCount', 'SES'],
                          default='readCount')

    optional.add_argument('--sampleLength', '-l',
                          help='*Only relevant when SES is chosen for the '
                          'scaleFactorsMethod.* To compute the SES, specify '
                          'the length (in bp) of the regions (see --numberOfSamples) '
                          'which will be randomly sampled to calculate the scaling factors. '
                          'If you do not have a good sequencing depth for '
                          'your samples consider increasing the sampling '
                          'regions\' size to minimize the probability '
                          'that zero-coverage regions are used.',
                          default=1000,
                          type=int)

    optional.add_argument('--scaleFactors',
                          help='Set this parameter to avoid the computation of '
                          'scaleFactors. The format is '
                          'scaleFactor1:scaleFactor2. For example 0.7:1 to '
                          'scale the first BAM file by 0.7 while not scaling '
                          'the second BAM file',
                          default=None,
                          required=False)

    optional.add_argument('--pseudocount',
                          help='small number to avoid x/0. Only useful '
                          'when ratio = log2 or ratio',
                          default=1,
                          type=float,
                          required=False)

    optional.add_argument('--ratio',
                          help='The default is to output the log2ratio between the '
                          'two samples. The reciprocal ratio returns '
                          'the negative of the inverse of the ratio '
                          'if the ratio is less than 1. The resulting '
                          'values are interpreted as negative fold changes.',
                          default='log2',
                          choices=['log2', 'ratio', 'subtract', 'add',
                                   'reciprocal_ratio'],
                          required=False)

    optional.add_argument('--normalizeTo1x',
                          help='Report read coverage normalized to 1x '
                          'sequencing depth (also known as Reads Per Genomic '
                          'Content (RPGC)). Sequencing depth is defined as: '
                          '(total number of mapped reads * fragment length) / '
                          'effective genome size.\nThe scaling factor used '
                          'is the inverse of the sequencing depth computed '
                          'for the sample to match the 1x coverage. '
                          'To use this option, the '
                          'effective genome size has to be indicated after the '
                          'command. The effective genome size is the portion '
                          'of the genome that is mappable. Large fractions of '
                          'the genome are stretches of NNNN that should be '
                          'discarded. Also, if repetitive regions were not '
                          'included in the mapping of reads, the effective '
                          'genome size needs to be adjusted accordingly. '
                          'Common values are: mm9: 2150570000, '
                          'hg19:2451960000, dm3:121400000 and ce10:93260000. '
                          'See Table 2 of http://www.plosone.org/article/info:doi/10.1371/journal.pone.0030377 '
                          'or http://www.nature.com/nbt/journal/v27/n1/fig_tab/nbt.1518_T1.html '
                          'for several effective genome sizes.',
                          metavar='EFFECTIVE GENOME SIZE LENGTH',
                          default=None,
                          type=int,
                          required=False)

    optional.add_argument('--normalizeUsingRPKM',
                          help='(*only when --ratio subtract*) Use RPKM to '
                          'normalize the number of reads per bin. The formula '
                          'is: RPKM (per bin)=#reads per bin / ( # of mapped '
                          'reads (millions) * bin length (KB) ). This is the '
                          'default normalization method.',
                          action='store_true',
                          required=False)

    # argparse only applies `type` to defaults that are strings, thus
    # the default has to be an int already (1e5 is a float)
    optional.add_argument('--numberOfSamples', '-n',
                          help='*Only relevant when SES is chosen for the '
                          'scaleFactorsMethod.* Number of samplings taken '
                          'from the genome to compute the scaling factors',
                          default=100000,
                          type=int)

    optional.add_argument('--missingDataAsZero',
                          default="yes",
                          choices=["yes", "no"],
                          help='Default is "yes". This parameter determines '
                          'if missing data should be treated as zeros. '
                          'If set to "no", missing data will be ignored '
                          'and not included in the output file. Missing '
                          'data is defined as those regions for which '
                          'both BAM files have 0 reads.')

    optional.add_argument('--ignoreForNormalization', '-ignore',
                          help='A list of chromosome names '
                          'separated by comma and limited by quotes, '
                          'containing those '
                          'chromosomes that you want to be excluded '
                          'for computing the normalization. For example, '
                          ' --ignoreForNormalization "chrX, chrM" ')

    optional.add_argument('--centerReads',
                          help='By adding this option reads are centered with '
                          'respect to the fragment length. For paired-end data '
                          'the read is centered at the fragment length defined '
                          'by the two fragment ends. For single-end data, the '
                          'given fragment length is used. This option is '
                          'useful to get a sharper signal around enriched '
                          'regions.',
                          action='store_true')
    return parser

########################################
# MAIN


def _getMappedReadCount(bamFileName, chromsToSkip):
    """
    Sum the read counts reported by `pysam.idxstats` for `bamFileName`,
    leaving out the chromosomes listed in `chromsToSkip`.

    Each idxstats record is split on tabs; field 0 is compared against
    the chromosome names to skip and field 2 is the count that is added.

    NOTE(review): this assumes that pysam.idxstats returns an iterable
    of tab separated lines, as the code this helper was extracted from
    did -- confirm against the installed pysam version.

    :param bamFileName: path of the (indexed) BAM file.
    :param chromsToSkip: collection of chromosome names to leave out.
    :return: the summed count as an int.
    """
    import pysam
    records = (line.split("\t") for line in pysam.idxstats(bamFileName))
    return sum(int(fields[2]) for fields in records
               if fields[0] not in chromsToSkip)


def main(args):
    """
    The algorithm is composed of two parts.

    1. Using the SES or mapped reads method, appropriate scaling
       factors are determined (unless they are given by the user
       with --scaleFactors).

    2. The genome is traversed, scaling the BAM files, and computing
       the log ratio/ratio/difference for bins of fixed width
       given by the user.

    :param args: namespace as returned by `parseArguments`.
    :return: None. The result is written to `args.outFileName` by
             `writeBedGraph.writeBedGraph`.
    """

    bam1 = bamHandler.openBam(args.bamfile1, args.bamIndex1)
    bam2 = bamHandler.openBam(args.bamfile2, args.bamIndex2)

    # if a list of chromosomes to remove from normalization
    # is given, remove their counts
    if args.ignoreForNormalization:
        bam1_mapped = _getMappedReadCount(bam1.filename,
                                          args.ignoreForNormalization)
        bam2_mapped = _getMappedReadCount(bam2.filename,
                                          args.ignoreForNormalization)
    else:
        bam1_mapped = bam1.mapped
        bam2_mapped = bam2.mapped

    ################# find scaling factor ##################

    if args.scaleFactors:
        # user given factors in the format scaleFactor1:scaleFactor2
        scaleFactors = args.scaleFactors.split(":")
        scaleFactors = [float(scaleFactors[0]), float(scaleFactors[1])]
    else:
        if args.scaleFactorsMethod == 'SES':
            scaleFactorsDict = estimateScaleFactor(
                [bam1.filename, bam2.filename],
                args.sampleLength, args.numberOfSamples,
                args.fragmentLength, 1,
                numberOfProcessors=args.numberOfProcessors,
                verbose=args.verbose,
                chrsToSkip=args.ignoreForNormalization)

            scaleFactors = scaleFactorsDict['size_factors']

            if args.verbose:
                print("Size factors using SES: {}".format(scaleFactors))
                print("%s regions of size %s where used " %
                      (scaleFactorsDict['sites_sampled'],
                       args.sampleLength))

                # for comparison, report the factors that the
                # 'readCount' method would have produced
                print("size factor if the number of mapped "
                      "reads would have been used:")
                print(tuple(
                    float(min(bam1_mapped, bam2_mapped)) /
                    np.array([bam1_mapped, bam2_mapped])))

        elif args.scaleFactorsMethod == 'readCount':
            # the sample with fewer reads gets a factor of 1 and the
            # other sample is scaled down to match it
            scaleFactors = \
                float(min(bam1_mapped, bam2_mapped)) / \
                np.array([bam1_mapped, bam2_mapped])
            if args.verbose:
                print("Size factors using total number "
                      "of mapped reads: {}".format(scaleFactors))

    # in case the subtract method is used, the final difference
    # would be normalized according to the given method
    if args.ratio == 'subtract':
        # The next lines identify which of the samples is not scaled down.
        # The normalization using RPKM or normalize to 1x would use
        # as reference such sample. Since the other sample would be
        # scaled to match the un-scaled one, the normalization factor
        # for both samples should be based on the unscaled one.
        # For example, if sample A is unscaled and sample B is scaled by 0.5,
        # then normalizing factor for A to report RPKM read counts
        # is also applied to B.
        if scaleFactors[0] == 1:
            mappedReads = bam1_mapped
        else:
            mappedReads = bam2_mapped

        # user given scale factors are used as they are
        if args.scaleFactors is None:
            if args.normalizeTo1x:
                current_coverage = \
                    float(mappedReads * args.fragmentLength) / \
                    args.normalizeTo1x
                # the scale factor is 1 / coverage,
                scaleFactor = 1.0 / current_coverage
                scaleFactors = np.array(scaleFactors) * scaleFactor
                if args.verbose:
                    print("Estimated current coverage "
                          "{}".format(current_coverage))
                    print("Scale factor to convert "
                          "current coverage to 1: {}".format(scaleFactor))
            else:
                # by default normalize using RPKM
                # the RPKM is:
                # Num reads per tile/(total reads (in millions)*tile length in Kb)
                millionReadsMapped = float(mappedReads) / 1e6
                tileLengthInKb = float(args.binSize) / 1000
                scaleFactor = 1.0 / (millionReadsMapped * tileLengthInKb)
                scaleFactors = np.array(scaleFactors) * scaleFactor

                if args.verbose:
                    print("scale factor for "
                          "RPKM is {0}".format(scaleFactor))

    if args.verbose:
        print("Individual scale factors are {0}".format(scaleFactors))

    ################# compute log2ratio ##################

    print("The scaling factors are:")
    print(scaleFactors)

    # the getRatio function is called and receives
    # the funcArgs per each tile that is considered
    FUNC = getRatio
    funcArgs = {'missingDataAsZero': args.missingDataAsZero,
                'valueType': args.ratio,
                'scaleFactors': scaleFactors,
                'pseudocount': args.pseudocount
                }

    writeBedGraph.writeBedGraph(
        [bam1.filename, bam2.filename],
        args.outFileName, args.fragmentLength, FUNC,
        funcArgs, tileSize=args.binSize, region=args.region,
        numberOfProcessors=args.numberOfProcessors,
        format=args.outFileFormat,
        zerosToNans=True,
        smoothLength=args.smoothLength,
        extendPairedEnds=args.extendPairedEnds,
        minMappingQuality=args.minMappingQuality,
        ignoreDuplicates=args.ignoreDuplicates,
        centerRead=args.centerReads)

if __name__ == "__main__":
    # script entry point: parse the command line and run the comparison
    main(parseArguments())
