#!/usr/bin/env python
#-*- coding: utf-8 -*-

import os
import tempfile
import shutil
import time

from bx.seq import twobit
import pysam
import multiprocessing
import numpy as np
import argparse

from scipy.stats import binom

from deeptools.utilities import getGC_content, tbitToBamChrName
from deeptools.countReadsPerBin import getFragmentFromRead
from deeptools import config as cfg
from deeptools import writeBedGraph, parserCommon, mapReduce
from deeptools import utilities

# When set to a true value, writeCorrected_worker prints timing information
# for every processed chunk.
debug = 0

# External programs, looked up in the 'external_tools' section of the
# deeptools configuration.
samtools = cfg.config.get('external_tools', 'samtools')
bedgraph_to_bigwig = cfg.config.get('external_tools', 'bedgraph_to_bigwig')

# Settings shared with the worker functions; replaced by main() and by Tester.
global_vars = {}


def parseArguments(args=None):
    """Parse the command line and verify the external tools that are needed.

    The parser combines the options from getRequiredArgs() with the common
    options provided by parserCommon.getParentArgParse().

    :param args: list of command line tokens; ``None`` makes argparse read
                 ``sys.argv``.
    :return: the ``argparse.Namespace`` with the parsed options.

    The process exits with status 1 when the output name ends in 'bam' and
    the samtools check fails, or when it ends in 'bw' and the
    bedgraph-to-bigwig check fails.
    """
    # local import: `exit` is only injected by the `site` module for
    # interactive use, sys.exit is the documented way to leave a program
    import sys

    parentParser = parserCommon.getParentArgParse()
    requiredArgs = getRequiredArgs()
    parser = argparse.ArgumentParser(
        parents=[requiredArgs, parentParser],
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Corrects the GC bias using Benjamini\'s method '
        '[Benjamini & Speed (2012). Nucleic acids research, 40(10)]. '
        'The tool computeGC bias needs to be run first.',
        usage='An example usage is:\n %(prog)s '
        '-b file.bam --effectiveGenomeSize 2150570000 -g mm9.2bit '
        '-l 200 --GCbiasFrequenciesFile freq.txt -o gc_corrected.bam '
        '[options]',
        conflict_handler='resolve',
        add_help=False)

    args = parser.parse_args(args)
    outputName = args.correctedFile.name

    # cfg.checkProgram is expected to return a false value when the program
    # cannot be used -- NOTE(review): confirm against deeptools.config
    if outputName.endswith('bam') and \
            not cfg.checkProgram(samtools, 'view',
                                 'http://samtools.sourceforge.net/'):
        sys.exit(1)
    if outputName.endswith('bw') and \
            not cfg.checkProgram(bedgraph_to_bigwig, '-h',
                                 'http://hgdownload.cse.ucsc.edu/admin/exe/'):
        sys.exit(1)

    return args


def getRequiredArgs():
    """Return an argparse parser with the options specific to this tool.

    The parser is created with ``add_help=False`` so it can be used as a
    parent parser. It defines three argument groups: the required inputs,
    the output file and an explicit --help/-h option.
    """
    genomeSizeHelp = (
        'The effective genome size is the portion '
        'of the genome that is mappable. Large fractions of '
        'the genome are stretches of NNNN that should be '
        'discarded. Also, if repetitive regions were not '
        'included in the mapping of reads, the effective '
        'genome size needs to be adjusted accordingly. '
        'Common values are: mm9: 2150570000, '
        'hg19:2451960000, dm3:121400000 and ce10:93260000. '
        'See Table 2 of '
        'http://www.plosone.org/article/info:doi/10.1371/journal.pone.0030377 '
        'or http://www.nature.com/nbt/journal/v27/n1/fig_tab/nbt.1518_T1.html '
        'for several effective genome sizes. This value is '
        'needed to detect enriched regions that, if not '
        'discarded can bias the results.')

    genomeHelp = (
        'Genome in two bit format. Most genomes can be '
        'found here: http://hgdownload.cse.ucsc.edu/gbdb/  '
        'Search for the .2bit ending. Otherwise, fasta '
        'files can be converted to 2bit using the UCSC '
        'programm called faToTwoBit available for different '
        'plattforms at '
        'http://hgdownload.cse.ucsc.edu/admin/exe/')

    frequenciesHelp = (
        'Indicate the output file from '
        'computeGCBias containing, '
        'the observed and expected read frequencies per GC '
        'content.')

    correctedFileHelp = (
        'Name of the corrected file. The ending will '
        'be used to decide the output file format. The options '
        'are ".bam", ".bw" for a bigWig file, ".bg" for a '
        'bedgraph file.')

    # (option strings, add_argument keywords), in the order in which they
    # have to appear in the help output
    requiredOptions = [
        (('--bamfile', '-b'),
         dict(metavar='bam file',
              help='Sorted Bam file to correct.',
              required=True)),
        (('--effectiveGenomeSize',),
         dict(help=genomeSizeHelp,
              default=None,
              type=int,
              required=True)),
        (('--genome', '-g'),
         dict(help=genomeHelp,
              metavar='two bit file',
              required=True)),
        (('--GCbiasFrequenciesFile', '-freq'),
         dict(help=frequenciesHelp,
              type=argparse.FileType('r'),
              metavar='FILE',
              required=True)),
    ]

    parser = argparse.ArgumentParser(add_help=False)

    requiredGroup = parser.add_argument_group('Required arguments')
    for optionStrings, settings in requiredOptions:
        requiredGroup.add_argument(*optionStrings, **settings)

    outputGroup = parser.add_argument_group('Output options')
    outputGroup.add_argument('--correctedFile', '-o',
                             help=correctedFileHelp,
                             metavar='FILE',
                             type=argparse.FileType('w'),
                             required=True)

    optionalGroup = parser.add_argument_group('Optional arguments')
    optionalGroup.add_argument("--help", "-h", action="help",
                               help="show this help message and exit")

    return parser


def getReadGCcontent(tbit, read, fragmentLength, chrNameBit):
    """
    Return the GC content of the fragment a read belongs to, scaled to
    ``fragmentLength``.

    The fragments for forward and reverse reads are defined as follows

       |- read.pos       |- read.aend
    ---+=================>-----------------------+---------    Forward strand

       |-fragStart                               |-fragEnd

    ---+-----------------------<=================+---------    Reverse strand
                               |-read.pos        |-read.aend

       |-----------------------------------------|
                        read.tlen

    For properly paired reads whose template length is below twice
    ``fragmentLength`` the fragment is taken from the template length;
    in every other case the read is extended to ``fragmentLength`` in the
    direction of its strand.

    :param tbit: mapping from chromosome name to a 2bit sequence object
                 offering ``get(start, end)``
    :param read: aligned read; the attributes pos, aend, tlen, qlen,
                 is_paired, is_proper_pair and is_reverse are used
    :param fragmentLength: fragment length used for single reads and for
                           scaling the result
    :param chrNameBit: name of the chromosome in the 2bit file
    :return: ``int(round(gc_fraction * fragmentLength))`` or ``None`` when
             the sequence can not be retrieved (the original comment in the
             caller mentions the end of a chromosome as one such case)
    """
    fragStart = None
    fragEnd = None

    if read.is_paired and read.is_proper_pair and \
            abs(read.tlen) < 2 * fragmentLength:
        if read.is_reverse and read.tlen < 0:
            fragEnd = read.aend
            fragStart = read.aend + read.tlen
        elif read.tlen >= read.qlen:
            fragStart = read.pos
            fragEnd = read.pos + read.tlen

    # Compare with None explicitly: 0 is a valid fragment start and must not
    # discard the paired-end fragment computed above.
    if fragStart is None:
        if read.is_reverse:
            fragEnd = read.aend
            fragStart = read.aend - fragmentLength
        else:
            fragStart = read.pos
            fragEnd = fragStart + fragmentLength

    try:
        # as_fraction=True: presumably a value between 0 and 1, it is scaled
        # by fragmentLength below -- NOTE(review): confirm in utilities
        gc = getGC_content(tbit[chrNameBit].get(fragStart, fragEnd),
                           as_fraction=True)
    except Exception:
        # deliberate best effort: any failure to obtain the sequence makes
        # the caller skip the read
        return None

    # match the gc to the given fragmentLength
    return int(np.round(gc * fragmentLength))


def writeCorrected_wrapper(args):
    """Run writeCorrected_worker on one task tuple.

    The map functions used in main() hand over a single object per task,
    so the tuple is expanded into positional arguments here.
    """
    bedGraphFile = writeCorrected_worker(*args)
    return bedGraphFile


def writeCorrected_worker(chrNameBam, chrNameBit, start, end, step):
    r"""writes a bedgraph file containing the GC correction of
    a region from the genome

    Every read found in the region contributes ``1 / R_gc[gc]`` to each
    position covered by its fragment, where ``gc`` is the value returned by
    getReadGCcontent. The per-base values are then averaged in bins of
    ``step`` bases; bins with a mean of 0 are not written.

    The function reads the module globals ``R_gc`` and ``global_vars``
    (keys '2bit', 'bam' and 'max_dup_gc').

    :param chrNameBam: chromosome name as used in the bam file
    :param chrNameBit: chromosome name as used in the 2bit file; also the
                       name written to the bedgraph file
    :param start: start of the region
    :param end: end of the region
    :param step: bin size of the bedgraph file
    :return: name of the temporary bedgraph file, or None when no read
             was counted

    >>> test = Tester()
    >>> tempFile = writeCorrected_worker(*test.testWriteCorrectedChunk())
    >>> open(tempFile, 'r').readlines()
    ['chr2L\t200\t225\t31.6\n', 'chr2L\t225\t250\t33.8\n', 'chr2L\t250\t275\t37.9\n', 'chr2L\t275\t300\t40.9\n']
    >>> os.remove(tempFile)
    """
    global R_gc
    fragmentLength = len(R_gc) - 1

    regionLength = end - start
    cvg_corr = np.zeros(regionLength)

    i = 0
    read_repetitions = 0
    startTime = time.time()

    # the 2bit file is binary and only needed while the reads are processed
    with open(global_vars['2bit'], 'rb') as twoBitHandle:
        tbit = twobit.TwoBitFile(twoBitHandle)
        bam = pysam.Samfile(global_vars['bam'])

        # caching seems to be faster
        reads = [r for r in bam.fetch(chrNameBam, start, end)]
        bam.close()

        for r_index, read in enumerate(reads):
            try:
                # calculate GC content of read fragment
                gc = getReadGCcontent(tbit, read, fragmentLength,
                                      chrNameBit)
            except Exception as detail:
                # this exception happens when the end of a
                # chromosome is reached
                print(detail)
                continue
            # skips reads without GC information; note that a GC value of 0
            # is skipped as well
            if not gc:
                continue

            # is this read in the same orientation and position
            # as the previous?
            previous = reads[r_index - 1] if r_index > 0 else None
            if previous is not None and read.pos == previous.pos and \
                    read.is_reverse == previous.is_reverse and \
                    read.pnext == previous.pnext:
                read_repetitions += 1
                if read_repetitions >= global_vars['max_dup_gc'][gc]:
                    # more copies than allowed for this GC value
                    continue
            else:
                read_repetitions = 0

            try:
                fragmentStart, fragmentEnd = \
                    getFragmentFromRead(read,
                                        fragmentLength,
                                        extendPairedEnds=True)
                # translate to positions inside cvg_corr, clipped to the
                # region
                vectorStart = max(fragmentStart - start, 0)
                vectorEnd = min(fragmentEnd - start, regionLength)
            except TypeError:
                # the getFragmentFromRead functions returns None in some
                # cases. Those cases are to be skipped.
                continue

            cvg_corr[vectorStart:vectorEnd] += float(1) / R_gc[gc]
            i += 1

    if debug:
        endTime = time.time()
        # the message and its format() call have to be a single expression,
        # otherwise the unformatted template is printed
        print("{}, processing {} ({:.1f} per sec) "
              "reads @ {}:{}-{}".format(multiprocessing.current_process().name,
                                        i, i / (endTime - startTime),
                                        chrNameBit, start, end))

    if i == 0:
        return None

    tempFileName = utilities.getTempFileName(suffix='.bg')
    # save in bedgraph format
    with open(tempFileName, 'w') as _file:
        for binStart in xrange(0, regionLength, step):
            # bin limits are positions inside cvg_corr, hence they are
            # bounded by the length of the region and not by `end`
            binEnd = min(binStart + step, regionLength)
            value = np.mean(cvg_corr[binStart:binEnd])
            if value > 0:
                _file.write("%s\t%d\t%d\t%.1f\n" % (chrNameBit,
                                                    start + binStart,
                                                    start + binEnd,
                                                    value))

    return tempFileName


def numCopiesOfRead(value):
    """
    Decide, based on the correction value of a read, how many copies
    of the read are kept (none, one, two, ...).

    A false ``value`` (None or 0) keeps exactly one copy. Otherwise the
    integer part of ``value`` is the guaranteed number of copies and one
    further copy is added with a probability equal to the fractional part,
    using one draw from ``np.random.rand``.

    >>> np.random.seed(1)
    >>> numCopiesOfRead(0.8)
    1
    >>> numCopiesOfRead(2.5)
    2
    >>> numCopiesOfRead(None)
    1
    """
    if not value:
        return 1

    guaranteed = int(value)
    # one random draw per call decides about the additional copy
    if np.random.rand() < value % 1:
        return guaranteed + 1
    return guaranteed


def writeCorrectedSam_wrapper(args):
    """Run writeCorrectedSam_worker on one task tuple.

    The map functions used in main() hand over a single object per task,
    so the tuple is expanded into positional arguments here.
    """
    bamFile = writeCorrectedSam_worker(*args)
    return bamFile


def writeCorrectedSam_worker(chrNameBam, chrNameBit, start, end,
                             step=None,
                             tag_but_not_change_number=False,
                             verbose=True):
    r"""
    Writes a SAM file, deleting and adding some reads in order to compensate
    for the GC bias. **This is a stochastic method.**

    Every read gets the tag 'GC' (percentage of GC of its fragment, -1 if
    unknown) and, when the GC content is known, the tags 'CO' (the
    correction value 1 / R_gc rounded to two decimals) and 'CP' (number of
    copies kept). The SAM file is converted to BAM using samtools and
    removed afterwards.

    The function reads the module globals ``R_gc`` and ``global_vars``
    (keys '2bit', 'bam' and 'max_dup_gc').

    :param chrNameBam: chromosome name as used in the bam file
    :param chrNameBit: chromosome name as used in the 2bit file
    :param start: start of the region
    :param end: end of the region
    :param step: not used; main() passes the same task tuples to this
                 worker and to writeCorrected_worker
    :param tag_but_not_change_number: if True every read is written exactly
                                      once, only the tags are added
    :param verbose: print progress information
    :return: name of the temporary bam file

    First, check if samtools can be executed, otherwise the test will fail
    >>> resp = cfg.checkProgram(samtools, 'view', '')
    >>> np.random.seed(1)
    >>> test = Tester()
    >>> args = test.testWriteCorrectedSam()
    >>> tempFile = writeCorrectedSam_worker(*args, \
    ... tag_but_not_change_number=True, verbose=False)
    >>> res = os.system("{} index {}".format(test.samtools, tempFile))
    >>> bam = pysam.Samfile(tempFile)
    >>> [dict(r.tags)['CP'] for r in bam.fetch(args[0], 200, 250)]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
    >>> res = os.remove(tempFile)
    >>> res = os.remove(tempFile+".bai")
    >>> tempFile = \
    ... writeCorrectedSam_worker(*test.testWriteCorrectedSam_paired(),\
    ... tag_but_not_change_number=True, verbose=False)
    >>> res = os.system("{} index {}".format(test.samtools, tempFile))
    >>> bam = pysam.Samfile(tempFile)
    >>> [dict(r.tags)['CP'] for r in bam.fetch('chr2L', 0, 50)]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    >>> res = os.remove(tempFile)
    >>> res = os.remove(tempFile+".bai")
    """
    global R_gc
    fragmentLength = len(R_gc) - 1

    if verbose:
        print("Sam for %s %s %s " % (chrNameBit, start, end))
    i = 0

    bam = pysam.Samfile(global_vars['bam'])
    tempFileName = utilities.getTempFileName(suffix='.sam')

    outfile = pysam.Samfile(tempFileName, 'wh', template=bam)
    startTime = time.time()
    matePairs = {}
    read_repetitions = 0
    removed_duplicated_reads = 0

    # the 2bit file is binary and only needed while the reads are processed
    with open(global_vars['2bit'], 'rb') as twoBitHandle:
        tbit = twobit.TwoBitFile(twoBitHandle)

        # cache data
        # NOTE(review): reads starting exactly at `start` do not pass this
        # filter -- confirm that this is intended
        reads = [r for r in bam.fetch(chrNameBam, start, end)
                 if r.pos > start]

        for r_index, read in enumerate(reads):
            # check if a mate has already been processed
            # to apply the same correction
            try:
                mate = matePairs.pop(read.qname)
            except KeyError:
                # no mate was stored for this read name. This could
                # happen because of removal of the mate
                # by some filtering
                gc = getReadGCcontent(tbit, read, fragmentLength,
                                      chrNameBit)
                if gc:
                    copies = numCopiesOfRead(float(1) / R_gc[gc])
                else:
                    copies = 1
            else:
                copies = mate['copies']
                gc = mate['gc']

            # is this read in the same orientation and position
            # as the previous?
            previous = reads[r_index - 1] if r_index > 0 else None
            if gc and previous is not None \
                    and read.pos == previous.pos \
                    and read.is_reverse == previous.is_reverse \
                    and read.pnext == previous.pnext:
                read_repetitions += 1
                if read_repetitions >= global_vars['max_dup_gc'][gc]:
                    # in other words do not take into account this read
                    copies = 0
                    removed_duplicated_reads += 1
            else:
                read_repetitions = 0

            readName = read.qname
            readTag = read.tags
            if gc:
                GC = int(100 * np.round(float(gc) / fragmentLength,
                                        decimals=2))
                readTag.append(
                    ('CO', float(round(float(1) / R_gc[gc], 2))))
                readTag.append(('CP', copies))
            else:
                GC = -1

            readTag.append(('GC', GC))
            read.tags = readTag

            # remember the decision taken for the forward read of a proper
            # pair, it is applied to the mate when the mate shows up
            if read.is_paired and read.is_proper_pair \
                    and not read.mate_is_unmapped \
                    and not read.is_reverse:
                matePairs[readName] = {'copies': copies,
                                       'gc': gc}

            if tag_but_not_change_number:
                outfile.write(read)
                continue

            for numCop in range(1, copies + 1):
                # the read has to be renamed such that newly
                # formed pairs will match
                if numCop > 1:
                    read.qname = readName + "_%d" % (numCop)
                outfile.write(read)

            if verbose:
                if i % 500000 == 0 and i > 0:
                    endTime = time.time()
                    print("{},  processing {} ({:.1f} per sec) reads "
                          "@ {}:{}-{}".format(
                              multiprocessing.current_process().name,
                              i, i / (endTime - startTime),
                              chrNameBit, start, end))
            i += 1

    outfile.close()
    bam.close()
    if verbose:
        endTime = time.time()
        print("{},  processing {} ({:.1f} per sec) reads "
              "@ {}:{}-{}".format(multiprocessing.current_process().name,
                                  i, i / (endTime - startTime),
                                  chrNameBit, start, end))
        percentage = float(removed_duplicated_reads) * 100 / len(reads) \
            if len(reads) > 0 else 0
        print("duplicated reads removed %d of %d (%.2f) " %
              (removed_duplicated_reads, len(reads), percentage))

    if verbose:
        print(tempFileName)
    # convert sam to bam.
    os.system("{} view -bS {} 2> /dev/null > {}.bam".format(samtools,
                                                            tempFileName,
                                                            tempFileName))

    os.remove(tempFileName)
    return tempFileName + ".bam"


#############
def _runTasks(wrapper, mp_args, numberOfProcessors):
    """Apply ``wrapper`` to every task, in parallel when that is worthwhile.

    A process pool is only created for more than one task and more than one
    processor and it is always shut down before returning. The pool has to
    be created after the module globals used by the workers are set.

    :return: list with one result per task, in task order
    """
    if len(mp_args) > 1 and numberOfProcessors > 1:
        pool = multiprocessing.Pool(numberOfProcessors)
        try:
            # get() with a timeout: presumably to keep the main process
            # interruptible while waiting -- NOTE(review): confirm
            return pool.map_async(wrapper, mp_args).get(9999999)
        finally:
            pool.terminate()
            pool.join()

    return [wrapper(task) for task in mp_args]


def main(args):
    """Correct the GC bias of a bam file and write the corrected file.

    The genome is divided into chunks that are processed by
    writeCorrectedSam_worker (output name ending in 'bam') or by
    writeCorrected_worker (output name ending in 'bg' or 'bw'); the
    intermediate files are then merged into the output file.

    The module globals ``F_gc``, ``N_gc`` and ``R_gc`` are set from the first
    three columns of the frequencies file and ``global_vars`` is replaced.

    :param args: namespace as returned by parseArguments(). The attributes
                 GCbiasFrequenciesFile, genome, bamfile,
                 effectiveGenomeSize, region, binSize, numberOfProcessors
                 and correctedFile are used.
    """
    global F_gc, N_gc, R_gc
    global global_vars

    data = np.loadtxt(args.GCbiasFrequenciesFile.name)

    F_gc = data[:, 0]
    N_gc = data[:, 1]
    R_gc = data[:, 2]

    global_vars = {}
    global_vars['2bit'] = args.genome
    global_vars['bam'] = args.bamfile

    # compute the probability to find more than one read (a redundant read)
    # at a certain position based on the gc of the read fragment
    # the binomial function is used for that
    max_dup_gc = [binom.isf(1e-7, F_gc[x], 1.0 / N_gc[x])
                  if F_gc[x] > 0 and N_gc[x] > 0 else 1
                  for x in range(len(F_gc))]

    global_vars['max_dup_gc'] = max_dup_gc

    # `bit` is queried again when a bigWig file is written, hence the file
    # stays open for the duration of main()
    bit = twobit.TwoBitFile(open(global_vars['2bit'], 'rb'))
    bam = pysam.Samfile(global_vars['bam'])

    global_vars['genome_size'] = sum([bit[x].size for x in bit.index])
    global_vars['total_reads'] = bam.mapped
    global_vars['reads_per_bp'] = \
        float(global_vars['total_reads']) / args.effectiveGenomeSize

    ################# apply correction ##################
    print("applying correction")
    # divide the genome in fragments containing about 4e5 reads.
    # This amount of reads takes about 20 seconds
    # to process per core (48 cores, 256 Gb memory)
    chunkSize = int(4e5 / global_vars['reads_per_bp'])

    # chromSizes: list of tuples
    chromSizes = [(bam.references[i], bam.lengths[i])
                  for i in range(len(bam.references))]

    regionStart = 0
    if args.region:
        chromSizes, regionStart, regionEnd, chunkSize = \
            mapReduce.getUserRegion(chromSizes, args.region,
                                    max_chunk_size=chunkSize)

    print("genome partition size for multiprocessing: %s" % (chunkSize))
    print("using region {}".format(args.region))

    mp_args = []
    bedGraphStep = args.binSize

    chrNameBitToBam = tbitToBamChrName(bit.index.keys(), bam.references)
    chrNameBamToBit = dict((v, k) for k, v in chrNameBitToBam.items())

    for chrom, size in chromSizes:
        # chromosomes without sequence can not be corrected; report them
        # once instead of once per chunk
        if chrom not in chrNameBamToBit:
            print("no sequence information for "
                  "chromosome {} in 2bit file".format(chrom))
            print("Reads in this chromosome will be skipped")
            continue
        for i in xrange(regionStart, size, chunkSize):
            length = min(size, i + chunkSize)
            mp_args.append((chrom, chrNameBamToBit[chrom], i, length,
                            bedGraphStep))

    outputName = args.correctedFile.name

    if outputName.endswith('bam'):
        if len(mp_args) > 1 and args.numberOfProcessors > 1:
            print("using {} processors for {} "
                  "number of tasks".format(args.numberOfProcessors,
                                           len(mp_args)))

        res = _runTasks(writeCorrectedSam_wrapper, mp_args,
                        args.numberOfProcessors)

        if len(res) == 1:
            os.system("cp {} {}".format(res[0], outputName))
        else:
            print("concatenating intemerdiary bams")
            os.system("{} cat -o {} {} ".format(samtools,
                                                outputName,
                                                " ".join(res)))

        print("indexing bam")
        os.system("{} index {} ".format(samtools, outputName))
        for tempFileName in res:
            os.remove(tempFileName)

    if outputName.endswith('bg') or outputName.endswith('bw'):

        _temp_bg_file_name = utilities.getTempFileName(suffix='_all.bg')
        res = _runTasks(writeCorrected_wrapper, mp_args,
                        args.numberOfProcessors)

        # concatenate all intermediate tempfiles into one bedgraph file;
        # workers return None for chunks without reads
        with open(_temp_bg_file_name, 'wb') as _temp_bg_file:
            for tempFileName in res:
                if tempFileName:
                    with open(tempFileName, 'rb') as chunkFile:
                        shutil.copyfileobj(chunkFile, _temp_bg_file)
                    os.remove(tempFileName)
        args.correctedFile.close()

        if outputName.endswith('bg'):
            shutil.move(_temp_bg_file_name, outputName)

        else:
            chromSizes = [(x, bit[x].size) for x in bit.keys()]
            writeBedGraph.bedGraphToBigWig(chromSizes, _temp_bg_file_name,
                                           outputName)
            # os.remove needs the file name, not the (closed) file object
            os.remove(_temp_bg_file_name)


class Tester(object):
    """Prepares the module globals and the arguments used by the doctests.

    The test files are expected in the directory ``test_corrGC`` below the
    'test_root' entry of the 'general' section of the deeptools
    configuration. Creating an instance resets ``debug`` and replaces
    ``global_vars``.
    """

    def __init__(self):
        global debug
        global global_vars

        self.root = cfg.config.get('general', 'test_root') + "/test_corrGC/"
        self.tbitFile = self.root + "sequence.2bit"
        self.bamFile = self.root + "test.bam"
        self.chrNameBam = '2L'
        self.chrNameBit = 'chr2L'
        self.samtools = cfg.config.get('external_tools', 'samtools')

        # both files are only needed to derive the numbers below,
        # so they are closed right away
        bam = pysam.Samfile(self.bamFile)
        total_reads = bam.mapped
        bam.close()
        with open(self.tbitFile, 'rb') as twoBitHandle:
            bit = twobit.TwoBitFile(twoBitHandle)
            genome_size = sum([bit[x].size for x in bit.index])

        debug = 0
        global_vars = {'2bit': self.tbitFile,
                       'bam': self.bamFile,
                       'filter_out': None,
                       'extra_sampling_file': None,
                       'max_reads': 5,
                       'min_reads': 0,
                       'reads_per_bp': 0.3,
                       'total_reads': total_reads,
                       'genome_size': genome_size}

    def _prepareWorkerGlobals(self):
        """Set the globals read by the workers: R_gc and max_dup_gc."""
        global R_gc
        R_gc = np.loadtxt(self.root + "R_gc_paired.txt")

        global_vars['max_dup_gc'] = np.ones(301)

    def testWriteCorrectedChunk(self):
        """ prepare arguments for test

        :return: arguments for writeCorrected_worker
        """
        self._prepareWorkerGlobals()

        start = 200
        end = 300
        bedGraphStep = 25
        return (self.chrNameBam,
                self.chrNameBit, start, end, bedGraphStep)

    def testWriteCorrectedSam(self):
        """ prepare arguments for test

        :return: positional arguments for writeCorrectedSam_worker
        """
        self._prepareWorkerGlobals()

        start = 200
        end = 250
        return (self.chrNameBam,
                self.chrNameBit, start, end)

    def testWriteCorrectedSam_paired(self):
        """ prepare arguments for test.

        Switches the bam file in global_vars to the paired-end test file.
        max_dup_gc is set here as well, so the method does not depend on
        another test method having been called before.

        :return: positional arguments for writeCorrectedSam_worker
        """
        self._prepareWorkerGlobals()

        start = 0
        end = 500
        global_vars['bam'] = self.root + "paired.bam"
        return ('chr2L', 'chr2L', start, end)


if __name__ == "__main__":
    # script entry point: parse the command line and run the correction
    main(parseArguments())
