#!/usr/bin/env python
#-*- coding: utf-8 -*-

import os
import tempfile
import shutil
import time
import subprocess
import sys

from bx.seq import twobit
import pysam
import multiprocessing
import numpy as np
import argparse

from scipy.stats import binom

from deeptools.utilities import getGC_content, tbitToBamChrName
from deeptools.countReadsPerBin import getFragmentFromRead
from deeptools import config as cfg
from deeptools import writeBedGraph, parserCommon, mapReduce
from deeptools import utilities

# When truthy, writeCorrected_worker prints timing details for each
# processed region.
debug = 0
# Settings for the external programs, read from the 'external_tools'
# section of the deeptools configuration.
samtools = cfg.config.get('external_tools', 'samtools')
bedgraph_to_bigwig = cfg.config.get('external_tools',
                                    'bedgraph_to_bigwig')
# Shared state read by the worker functions. main() and Tester() replace
# this empty dictionary with a populated one before any worker is called.
global_vars = dict()


def parseArguments(args=None):
    """
    Builds the command line parser, parses the arguments and checks that
    the external program needed for the requested output is available.

    The parser combines the options defined in getRequiredArgs() with the
    common deeptools options returned by parserCommon.getParentArgParse().

    args: list of command line tokens; when None, argparse reads sys.argv.

    Returns the parsed argparse namespace. The process exits with status 1
    when the output name ends in 'bam' and the samtools check fails, or
    when it ends in 'bw' and the bedgraph_to_bigwig check fails.
    """
    parentParser = parserCommon.getParentArgParse()
    requiredArgs = getRequiredArgs()
    parser = argparse.ArgumentParser(
        parents=[requiredArgs, parentParser],
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Corrects the GC bias using Benjamini\'s method '
        '[Benjamini & Speed (2012). Nucleic acids research, 40(10)]. '
        'The tool computeGC bias needs to be run first.',
        usage='An example usage is:\n %(prog)s '
        '-b file.bam --effectiveGenomeSize 2150570000 -g mm9.2bit '
        '--GCbiasFrequenciesFile freq.txt -o gc_corrected.bam '
        '[options]',
        conflict_handler='resolve',
        add_help=False)

    args = parser.parse_args(args)

    # sys.exit() is used instead of the exit() helper: the latter is
    # injected by the site module for interactive sessions and is not
    # guaranteed to exist in every interpreter configuration.
    if args.correctedFile.name.endswith('bam'):
        if not cfg.checkProgram(samtools, 'view',
                                'http://samtools.sourceforge.net/'):
            sys.exit(1)
    if args.correctedFile.name.endswith('bw'):
        if not cfg.checkProgram(bedgraph_to_bigwig, '-h',
                                'http://hgdownload.cse.ucsc.edu/admin/exe/'):
            sys.exit(1)

    return args


def getRequiredArgs():
    """
    Returns an argparse parser (created without the automatic help
    option) that declares the options specific to this tool:

    * 'Required arguments': --bamfile/-b, --effectiveGenomeSize,
      --genome/-g and --GCbiasFrequenciesFile/-freq
    * 'Output options': --correctedFile/-o
    * 'Optional arguments': --help/-h

    The parser is meant to be used as a parent of the final parser.
    """
    # the long help texts are kept apart from the option declarations
    genomeSizeHelp = (
        'The effective genome size is the portion '
        'of the genome that is mappable. Large fractions of '
        'the genome are stretches of NNNN that should be '
        'discarded. Also, if repetitive regions were not '
        'included in the mapping of reads, the effective '
        'genome size needs to be adjusted accordingly. '
        'Common values are: mm9: 2150570000, '
        'hg19:2451960000, dm3:121400000 and ce10:93260000. '
        'See Table 2 of '
        'http://www.plosone.org/article/info:doi/10.1371/journal.pone.0030377 '
        'or http://www.nature.com/nbt/journal/v27/n1/fig_tab/nbt.1518_T1.html '
        'for several effective genome sizes. This value is '
        'needed to detect enriched regions that, if not '
        'discarded can bias the results.')

    twoBitHelp = (
        'Genome in two bit format. Most genomes can be '
        'found here: http://hgdownload.cse.ucsc.edu/gbdb/  '
        'Search for the .2bit ending. Otherwise, fasta '
        'files can be converted to 2bit using the UCSC '
        'programm called faToTwoBit available for different '
        'plattforms at '
        'http://hgdownload.cse.ucsc.edu/admin/exe/')

    frequenciesHelp = (
        'Indicate the output file from '
        'computeGCBias containing, '
        'the observed and expected read frequencies per GC '
        'content.')

    correctedHelp = (
        'Name of the corrected file. The ending will '
        'be used to decide the output file format. The options '
        'are ".bam", ".bw" for a bigWig file, ".bg" for a '
        'bedgraph file.')

    parser = argparse.ArgumentParser(add_help=False)

    # mandatory inputs
    requiredGroup = parser.add_argument_group('Required arguments')
    requiredGroup.add_argument(
        '--bamfile', '-b',
        required=True,
        metavar='bam file',
        help='Sorted Bam file to correct.')
    requiredGroup.add_argument(
        '--effectiveGenomeSize',
        required=True,
        type=int,
        default=None,
        help=genomeSizeHelp)
    requiredGroup.add_argument(
        '--genome', '-g',
        required=True,
        metavar='two bit file',
        help=twoBitHelp)
    requiredGroup.add_argument(
        '--GCbiasFrequenciesFile', '-freq',
        required=True,
        metavar='FILE',
        type=argparse.FileType('r'),
        help=frequenciesHelp)

    # destination of the corrected data
    outputGroup = parser.add_argument_group('Output options')
    outputGroup.add_argument(
        '--correctedFile', '-o',
        required=True,
        metavar='FILE',
        type=argparse.FileType('w'),
        help=correctedHelp)

    # the help option is declared by hand because add_help is disabled
    optionalGroup = parser.add_argument_group('Optional arguments')
    optionalGroup.add_argument(
        "--help", "-h",
        action="help",
        help="show this help message and exit")

    return parser


def getReadGCcontent(tbit, read, fragmentLength, chrNameBit):
    """
    Returns the GC content of the fragment a read belongs to, expressed
    as the (rounded) number of G/C bases in a fragment of length
    `fragmentLength`, or None if the GC content could not be computed.

    The fragments for forward and reverse reads are defined as follows

       |- read.pos       |- read.aend
    ---+=================>-----------------------+---------    Forward strand

       |-fragStart                               |-fragEnd

    ---+-----------------------<=================+---------    Reverse strand
                               |-read.pos        |-read.aend

       |-----------------------------------------|
                        read.tlen

    For proper pairs whose template length is shorter than twice
    `fragmentLength` the fragment is taken from the pair itself. In any
    other case the read is extended to `fragmentLength` in the direction
    of its strand.

    tbit:           opened 2bit file, indexed by chromosome name
    read:           aligned read (pos, aend, tlen, qlen and the pairing
                    flags are used)
    fragmentLength: default fragment length
    chrNameBit:     name of the chromosome in the 2bit file
    """
    fragStart = None
    fragEnd = None

    if read.is_paired and read.is_proper_pair and \
            abs(read.tlen) < 2 * fragmentLength:
        if read.is_reverse and read.tlen < 0:
            fragEnd = read.aend
            fragStart = read.aend + read.tlen
        elif read.tlen >= read.qlen:
            fragStart = read.pos
            fragEnd = read.pos + read.tlen

    # Compare against None and not by truthiness: a fragment that starts
    # at position 0 of the chromosome is valid and must not be replaced
    # by the default-length estimate.
    if fragStart is None:
        if read.is_reverse:
            fragEnd = read.aend
            fragStart = read.aend - fragmentLength
        else:
            fragStart = read.pos
            fragEnd = fragStart + fragmentLength

    try:
        gc = getGC_content(tbit[chrNameBit].get(fragStart, fragEnd),
                           as_fraction=True)
    except Exception:
        # Any failure while fetching the sequence or computing its GC
        # content is reported to the caller as None.
        return None

    # match the gc fraction to the given fragmentLength
    return int(np.round(gc * fragmentLength))


def writeCorrected_wrapper(args):
    """
    Single-argument adapter for writeCorrected_worker: `args` is the
    sequence of positional arguments of the worker. It is the callable
    handed to map()/Pool.map_async() in main().
    """
    bedgraphFile = writeCorrected_worker(*args)
    return bedgraphFile


def writeCorrected_worker(chrNameBam, chrNameBit, start, end, step):
    r"""writes a bedgraph file containing the GC correction of
    a region from the genome

    Each mapped read overlapping the region adds 1 / R_gc[gc] to the
    positions covered by its fragment, where gc is the value returned by
    getReadGCcontent. Reads for which that value is None or 0 are skipped,
    as well as repeated reads (same position, strand and mate position as
    the previous read) once they reach global_vars['max_dup_gc'][gc].
    The coverage is then averaged in bins of `step` bases and the bins
    with a mean above zero are written.

    chrNameBam: chromosome name in the bam file
    chrNameBit: name of the same chromosome in the 2bit file (also the
                name written to the bedgraph file)
    start, end: region to process
    step:       bin size of the bedgraph file

    Returns the name of the temporary bedgraph file, or None when no
    read contributed to the coverage.

    >>> test = Tester()
    >>> tempFile = writeCorrected_worker(*test.testWriteCorrectedChunk())
    >>> open(tempFile, 'r').readlines()
    ['chr2L\t200\t225\t31.6\n', 'chr2L\t225\t250\t33.8\n', 'chr2L\t250\t275\t37.9\n', 'chr2L\t275\t300\t40.9\n']
    >>> os.remove(tempFile)
    """
    global R_gc
    fragmentLength = len(R_gc) - 1

    cvg_corr = np.zeros(end - start)

    i = 0
    read_repetitions = 0
    removed_duplicated_reads = 0
    startTime = time.time()

    # caching seems to be faster
    # r.flag & 4 == 0 is to skip unmapped
    # reads that nevertheless are asigned
    # to a genomic position
    bam = pysam.Samfile(global_vars['bam'])
    try:
        reads = [r for r in bam.fetch(chrNameBam, start, end)
                 if r.flag & 4 == 0]
    finally:
        bam.close()

    # the 2bit file has to stay open while the reads are processed and
    # is closed afterwards
    tbit_handle = open(global_vars['2bit'])
    try:
        tbit = twobit.TwoBitFile(tbit_handle)
        for r_index, read in enumerate(reads):
            try:
                # calculate GC content of read fragment
                gc = getReadGCcontent(tbit, read, fragmentLength,
                                      chrNameBit)
            except Exception as detail:
                # this exception happens when the end of a
                # chromosome is reached
                print(detail)
                continue
            if not gc:
                continue

            # is this read in the same orientation and position
            # as the previous?
            if r_index > 0 and read.pos == reads[r_index - 1].pos and \
                    read.is_reverse == reads[r_index - 1].is_reverse \
                    and read.pnext == reads[r_index - 1].pnext:
                read_repetitions += 1
                if read_repetitions >= global_vars['max_dup_gc'][gc]:
                    removed_duplicated_reads += 1
                    continue
            else:
                read_repetitions = 0

            try:
                fragmentStart, fragmentEnd = \
                    getFragmentFromRead(read,
                                        fragmentLength,
                                        extendPairedEnds=True)
                vectorStart = max(fragmentStart - start, 0)
                vectorEnd = min(fragmentEnd - start, end - start)
            except TypeError:
                # the getFragmentFromRead functions returns None in some
                # cases. Those cases are to be skiped, hence the continue
                # line.
                continue

            cvg_corr[vectorStart:vectorEnd] += float(1) / R_gc[gc]
            i += 1
    finally:
        tbit_handle.close()

    if debug:
        endTime = time.time()
        # the two literals are a single template; format() has to be
        # applied to the whole message
        print("{}, processing {} ({:.1f} per sec) "
              "reads @ {}:{}-{}".format(multiprocessing.current_process().name,
                                        i, i / (endTime - startTime),
                                        chrNameBit, start, end))

    if i == 0:
        return None

    tempFileName = utilities.getTempFileName(suffix='.bg')
    regionLength = len(cvg_corr)
    # save in bedgraph format
    with open(tempFileName, 'w') as _file:
        for binStart in xrange(0, regionLength, step):
            # the last bin is shorter when the region length is not a
            # multiple of step
            binEnd = min(binStart + step, regionLength)
            value = np.mean(cvg_corr[binStart:binEnd])
            if value > 0:
                _file.write("%s\t%d\t%d\t%.1f\n" % (chrNameBit,
                                                    start + binStart,
                                                    start + binEnd,
                                                    value))

    return tempFileName


def numCopiesOfRead(value):
    """
    Translates a correction value into the number of copies of a read
    that should be kept (0 deletes the read, 1 keeps it, 2 duplicates
    it and so on).

    The integer part of `value` is always kept; one extra copy is added
    when a random draw from [0, 1) is smaller than the fractional part.
    A `value` that evaluates to False (None, 0) results in one copy.

    >>> np.random.seed(1)
    >>> numCopiesOfRead(0.8)
    1
    >>> numCopiesOfRead(2.5)
    2
    >>> numCopiesOfRead(None)
    1
    """
    if not value:
        return 1

    wholeCopies = int(value)
    randomDraw = np.random.rand()
    if randomDraw < value % 1:
        return wholeCopies + 1
    return wholeCopies


def writeCorrectedSam_wrapper(args):
    """
    Single-argument adapter for writeCorrectedSam_worker: `args` is the
    sequence of positional arguments of the worker. It is the callable
    handed to map()/Pool.map_async() in main().
    """
    bamFile = writeCorrectedSam_worker(*args)
    return bamFile


def writeCorrectedSam_worker(chrNameBam, chrNameBit, start, end,
                             step=None,
                             tag_but_not_change_number=False,
                             verbose=True):
    r"""
    Writes a SAM file, deleting and adding some reads in order to compensate
    for the GC bias. **This is a stochastic method.**

    Every read gets a 'GC' tag (-1 when the GC content is unknown) and,
    when the GC content is known, the tags 'CO' (1 / R_gc[gc]) and 'CP'
    (number of copies). The SAM file is converted to bam using samtools
    and the name of that temporary bam file is returned.

    chrNameBam: chromosome name in the bam file
    chrNameBit: name of the same chromosome in the 2bit file
    start, end: region to process
    step:       not used; present so that the same argument tuples built
                in main() work for both workers
    tag_but_not_change_number: if True every read is written exactly
                once and only the tags are added
    verbose:    print progress information

    First, check if samtools can be executed, otherwise the test will fail
    >>> resp = cfg.checkProgram(samtools, 'view', '')
    >>> np.random.seed(1)
    >>> test = Tester()
    >>> args = test.testWriteCorrectedSam()
    >>> tempFile = writeCorrectedSam_worker(*args, \
    ... tag_but_not_change_number=True, verbose=False)
    >>> res = os.system("{} index {}".format(test.samtools, tempFile))
    >>> bam = pysam.Samfile(tempFile)
    >>> [dict(r.tags)['CP'] for r in bam.fetch(args[0], 200, 250)]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
    >>> res = os.remove(tempFile)
    >>> res = os.remove(tempFile+".bai")
    >>> tempFile = \
    ... writeCorrectedSam_worker(*test.testWriteCorrectedSam_paired(),\
    ... tag_but_not_change_number=True, verbose=False)
    >>> res = os.system("{} index {}".format(test.samtools, tempFile))
    >>> bam = pysam.Samfile(tempFile)
    >>> [dict(r.tags)['CP'] for r in bam.fetch('chr2L', 0, 50)]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    >>> res = os.remove(tempFile)
    >>> res = os.remove(tempFile+".bai")
    """
    # local import: closing() makes sure the pysam handles are released
    from contextlib import closing

    global R_gc
    fragmentLength = len(R_gc) - 1

    if verbose:
        print("Sam for %s %s %s " % (chrNameBit, start, end))
    i = 0

    tempFileName = utilities.getTempFileName(suffix='.sam')
    startTime = time.time()
    matePairs = {}
    read_repetitions = 0
    removed_duplicated_reads = 0

    with open(global_vars['2bit']) as tbit_handle, \
            closing(pysam.Samfile(global_vars['bam'])) as bam, \
            closing(pysam.Samfile(tempFileName, 'wh',
                                  template=bam)) as outfile:
        tbit = twobit.TwoBitFile(tbit_handle)

        # cache data
        # r.flag & 4 == 0 is to filter unmapped reads that
        # have a genomic position
        # NOTE(review): `r.pos > start` also leaves out the reads that
        # start exactly at `start`; with the half open chunks built in
        # main() such reads do not seem to be written by any chunk.
        # `>=` looks like the intended test -- confirm against the
        # doctest data before changing it.
        reads = [r for r in bam.fetch(chrNameBam, start, end)
                 if r.pos > start and r.flag & 4 == 0]

        for r_index, read in enumerate(reads):
            # check if a mate has already been procesed
            # to apply the same correction
            try:
                mate = matePairs.pop(read.qname)
            except KeyError:
                # no mate has been seen before. This is the case for the
                # first read of a pair but could also happen because of
                # removal of the mate by some filtering
                gc = getReadGCcontent(tbit, read, fragmentLength,
                                      chrNameBit)
                if gc:
                    copies = numCopiesOfRead(float(1) / R_gc[gc])
                else:
                    copies = 1
            else:
                copies = mate['copies']
                gc = mate['gc']

            # is this read in the same orientation and position
            # as the previous?
            if gc and r_index > 0 and read.pos == reads[r_index - 1].pos \
                    and read.is_reverse == reads[r_index - 1].is_reverse \
                    and read.pnext == reads[r_index - 1].pnext:
                read_repetitions += 1
                if read_repetitions >= global_vars['max_dup_gc'][gc]:
                    # in other words do not take into account this read
                    copies = 0
                    removed_duplicated_reads += 1
            else:
                read_repetitions = 0

            readName = read.qname
            readTag = read.tags
            if gc:
                GC = int(100 * np.round(float(gc) / fragmentLength,
                                        decimals=2))
                readTag.append(
                    ('CO', float(round(float(1) / R_gc[gc], 2))))
                readTag.append(('CP', copies))
            else:
                GC = -1

            readTag.append(('GC', GC))
            read.tags = readTag

            # remember the decision taken for the forward read of a
            # proper pair such that its mate gets the same treatment
            if read.is_paired and read.is_proper_pair \
                    and not read.mate_is_unmapped \
                    and not read.is_reverse:
                matePairs[readName] = {'copies': copies,
                                       'gc': gc}

            if tag_but_not_change_number:
                outfile.write(read)
                continue

            for numCop in range(1, copies + 1):
                # the read has to be renamed such that newly
                # formed pairs will match
                if numCop > 1:
                    read.qname = readName + "_%d" % (numCop)
                outfile.write(read)

            if verbose:
                if i % 500000 == 0 and i > 0:
                    endTime = time.time()
                    print("{},  processing {} ({:.1f} per sec) reads "
                          "@ {}:{}-{}".format(
                              multiprocessing.current_process().name,
                              i, i / (endTime - startTime),
                              chrNameBit, start, end))
            i += 1

    if verbose:
        endTime = time.time()
        print("{},  processing {} ({:.1f} per sec) reads "
              "@ {}:{}-{}".format(multiprocessing.current_process().name,
                                  i, i / (endTime - startTime),
                                  chrNameBit, start, end))
        percentage = float(removed_duplicated_reads) * 100 / len(reads) \
            if len(reads) > 0 else 0
        print("duplicated reads removed %d of %d (%.2f) " %
              (removed_duplicated_reads, len(reads), percentage))

    # convert sam to bam.
    command = '{0} view -bS {1} 2> /dev/null > {1}.bam'.format(samtools,
                                                               tempFileName)
    if verbose:
        sys.stderr.write("running {}\n".format(command))

    run_shell_command(command)

    os.remove(tempFileName)
    return tempFileName + ".bam"


def run_shell_command(command):
    """
    Runs the given shell command. Report
    any errors found.

    command: string that is executed through the shell (the callers
             rely on shell features like redirections).

    Returns None when the command finishes with exit status 0. If the
    command returns a non-zero status (subprocess.CalledProcessError) or
    can not be run at all, the error is written to stderr as
    'Error: <detail>' and the process exits with status 1.
    """
    try:
        subprocess.check_call(command, shell=True)
    except Exception as error:
        # CalledProcessError is an Exception as well, thus a single
        # handler reports every failure with the same message layout
        sys.stderr.write('Error: {}\n'.format(error))
        sys.exit(1)

#############
def main(args):
    """
    Applies the GC correction to the bam file and writes the result in
    the format given by the ending of args.correctedFile.name ('bam',
    'bg' or 'bw').

    The three columns of the frequencies file are stored in the module
    globals F_gc, N_gc and R_gc and the shared settings in global_vars.
    The genome (or the region given by the user) is split into chunks
    that are processed by writeCorrectedSam_worker (bam output) or
    writeCorrected_worker (bedgraph/bigwig output), and the intermediate
    files are merged into the final file.

    args: namespace returned by parseArguments(). The attributes used
          are GCbiasFrequenciesFile, genome, bamfile, effectiveGenomeSize,
          region, binSize, numberOfProcessors and correctedFile.
    """
    global F_gc, N_gc, R_gc
    global global_vars

    data = np.loadtxt(args.GCbiasFrequenciesFile.name)

    F_gc = data[:, 0]
    N_gc = data[:, 1]
    R_gc = data[:, 2]

    global_vars = {}
    global_vars['2bit'] = args.genome
    global_vars['bam'] = args.bamfile

    # compute the probability to find more than one read (a redundant read)
    # at a certain position based on the gc of the read fragment
    # the binomial function is used for that
    max_dup_gc = [binom.isf(1e-7, F_gc[x], 1.0 / N_gc[x])
                  if F_gc[x] > 0 and N_gc[x] > 0 else 1
                  for x in range(len(F_gc))]

    global_vars['max_dup_gc'] = max_dup_gc

    bit = twobit.TwoBitFile(open(global_vars['2bit']))
    bam = pysam.Samfile(global_vars['bam'])

    global_vars['genome_size'] = sum([bit[x].size for x in bit.index])
    global_vars['total_reads'] = bam.mapped
    global_vars['reads_per_bp'] = \
        float(global_vars['total_reads']) / args.effectiveGenomeSize

    # ################ apply correction ##################
    print("applying correction")
    # divide the genome in fragments containing about 4e5 reads.
    # This amount of reads takes about 20 seconds
    # to process per core (48 cores, 256 Gb memory)
    chunkSize = int(4e5 / global_vars['reads_per_bp'])

    # chromSizes: list of tuples
    chromSizes = [(bam.references[i], bam.lengths[i])
                  for i in range(len(bam.references))]

    regionStart = 0
    if args.region:
        chromSizes, regionStart, regionEnd, chunkSize = \
            mapReduce.getUserRegion(chromSizes, args.region,
                                    max_chunk_size=chunkSize)

    print("genome partition size for multiprocessing: %s" % (chunkSize))
    print("using region {}".format(args.region))
    mp_args = []
    bedGraphStep = args.binSize

    chrNameBitToBam = tbitToBamChrName(bit.index.keys(), bam.references)
    chrNameBamToBit = dict([(v, k) for k, v in chrNameBitToBam.iteritems()])
    print("{} {}".format(chrNameBitToBam, chrNameBamToBit))

    for chrom, size in chromSizes:
        if chrom not in chrNameBamToBit:
            # reported once per chromosome, with the complete message
            print("no sequence information for "
                  "chromosome {} in 2bit file".format(chrom))
            print("Reads in this chromosome will be skipped")
            continue
        for i in xrange(regionStart, size, chunkSize):
            length = min(size, i + chunkSize)
            mp_args.append((chrom, chrNameBamToBit[chrom], i, length,
                            bedGraphStep))

    pool = multiprocessing.Pool(args.numberOfProcessors)

    if args.correctedFile.name.endswith('bam'):
        if len(mp_args) > 1 and args.numberOfProcessors > 1:
            print("using {} processors for {} "
                  "number of tasks".format(args.numberOfProcessors,
                                           len(mp_args)))

            # the huge timeout given to get() is kept from the original
            # code
            res = pool.map_async(
                writeCorrectedSam_wrapper, mp_args).get(9999999)
        else:
            res = map(writeCorrectedSam_wrapper, mp_args)

        if len(res) == 1:
            command = "cp {} {}".format(res[0], args.correctedFile.name)
        else:
            print("concatenating intermediary bams")
            command = "{} cat -o {} {} ".format(samtools,
                                                args.correctedFile.name,
                                                " ".join(res))

        run_shell_command(command)

        print("indexing bam")
        run_shell_command("{} index {} ".format(samtools,
                                                args.correctedFile.name))
        for tempFileName in res:
            os.remove(tempFileName)

    if args.correctedFile.name.endswith('bg') or \
            args.correctedFile.name.endswith('bw'):

        _temp_bg_file_name = utilities.getTempFileName(suffix='_all.bg')
        if len(mp_args) > 1 and args.numberOfProcessors > 1:

            res = pool.map_async(writeCorrected_wrapper, mp_args).get(9999999)
        else:
            res = map(writeCorrected_wrapper, mp_args)

        # concatenate all intermediate tempfiles into one bedgraph file.
        # The data is copied byte by byte, thus both files are opened in
        # binary mode. Workers return None for regions without reads.
        with open(_temp_bg_file_name, 'wb') as _temp_bg_file:
            for tempFileName in res:
                if tempFileName:
                    with open(tempFileName, 'rb') as chunkFile:
                        shutil.copyfileobj(chunkFile, _temp_bg_file)
                    os.remove(tempFileName)
        args.correctedFile.close()

        if args.correctedFile.name.endswith('bg'):
            shutil.move(_temp_bg_file_name, args.correctedFile.name)

        else:
            chromSizes = [(x, bit[x].size) for x in bit.keys()]
            writeBedGraph.bedGraphToBigWig(chromSizes, _temp_bg_file_name,
                                           args.correctedFile.name)
            # os.remove needs the file name, not the (closed) file object
            os.remove(_temp_bg_file_name)

    # all the tasks have finished at this point; release the workers
    pool.close()
    pool.join()


class Tester():
    """
    Prepares the module globals (global_vars, R_gc, debug) and the
    argument tuples used by the doctests of the worker functions.
    The test data is read from the 'test_corrGC' folder below the
    'test_root' path of the deeptools configuration.
    """
    def __init__(self):
        self.root = cfg.config.get('general', 'test_root') + "/test_corrGC/"
        self.tbitFile = self.root + "sequence.2bit"
        self.bamFile = self.root + "test.bam"
        self.chrNameBam = '2L'
        self.chrNameBit = 'chr2L'
        self.samtools = cfg.config.get('external_tools', 'samtools')

        # only two numbers are needed from the files, which are closed
        # right after reading them
        bam = pysam.Samfile(self.bamFile)
        total_reads = bam.mapped
        bam.close()
        with open(self.tbitFile) as tbit_handle:
            bit = twobit.TwoBitFile(tbit_handle)
            genome_size = sum([bit[x].size for x in bit.index])

        global debug
        debug = 0
        global global_vars
        global_vars = {'2bit': self.tbitFile,
                       'bam': self.bamFile,
                       'filter_out': None,
                       'extra_sampling_file': None,
                       'max_reads': 5,
                       'min_reads': 0,
                       'reads_per_bp': 0.3,
                       'total_reads': total_reads,
                       'genome_size': genome_size}

    def _loadCorrectionTable(self):
        """ load R_gc and the duplicates threshold used by every test
        """
        global R_gc
        R_gc = np.loadtxt(self.root + "R_gc_paired.txt")

        global_vars['max_dup_gc'] = np.ones(301)

    def testWriteCorrectedChunk(self):
        """ prepare arguments for test
        """
        self._loadCorrectionTable()

        start = 200
        end = 300
        bedGraphStep = 25
        return (self.chrNameBam,
                self.chrNameBit, start, end, bedGraphStep)

    def testWriteCorrectedSam(self):
        """ prepare arguments for test
        """
        self._loadCorrectionTable()

        start = 200
        end = 250
        return (self.chrNameBam,
                self.chrNameBit, start, end)

    def testWriteCorrectedSam_paired(self):
        """ prepare arguments for test. The paired end bam file
        is used instead of the default one.
        """
        # 'max_dup_gc' is set here as well such that this method does
        # not depend on another test method having been called before
        self._loadCorrectionTable()

        start = 0
        end = 500
        global_vars['bam'] = self.root + "paired.bam"
        return ('chr2L', 'chr2L', start, end)


# Script entry point: parse the command line (sys.argv) and run the
# correction with the resulting options.
if __name__ == "__main__":
    args = parseArguments()
    main(args)
