#!/usr/bin/env python
import argparse
import logging
import sys,math,os
import collections
import itertools
import multiprocessing

import pysam
import ngcgh

class Counter:
    """Callable tally: every invocation raises ``mCounts`` by one."""

    # Class-level starting value. The first call shadows it with an
    # instance attribute, so separate instances count independently.
    mCounts = 0

    def __call__(self, alignment):
        # The alignment argument is accepted but not inspected; only the
        # number of calls is recorded.
        self.mCounts = self.mCounts + 1

def readRegions(regionsFileName):
    """Read a tab-delimited, BED-style file into a per-chromosome interval map.

    Each data line must carry at least three tab-separated fields:
    chromosome name, integer start and integer stop.  Any further columns
    are ignored.  Blank lines, ``#`` comment lines and ``track``/``browser``
    header lines are skipped instead of aborting the parse.

    Parameters:
        regionsFileName: path of the regions file to read.

    Returns:
        A ``collections.defaultdict(list)`` mapping each chromosome name to
        its list of ``(start, stop)`` integer tuples, in file order.

    Raises:
        IOError/OSError: if the file cannot be opened.
        IndexError, ValueError: if a data line has fewer than three fields
            or non-integer coordinates.
    """
    regions = collections.defaultdict(list)
    # The context manager closes the file even when a malformed line raises.
    with open(regionsFileName, 'r') as regionsfile:
        for line in regionsfile:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            # Header lines are recognised by their first whitespace-delimited
            # token, so a chromosome merely *starting* with "track" is kept.
            if line.split(None, 1)[0] in ('track', 'browser'):
                continue
            wrds = line.split("\t")
            regions[wrds[0]].append((int(wrds[1]), int(wrds[2])))
    return regions

def parseregion(rgnString):
    """Parse a single ``chrN:start-stop`` region string.

    The string is split on its *last* colon, so reference names that
    themselves contain colons are accepted.

    Parameters:
        rgnString: region in the form ``name:START-STOP`` with integer
            coordinates.

    Returns:
        A dict with one entry, ``{name: [(start, stop)]}`` -- the same shape
        as the mapping produced by ``readRegions``.

    Raises:
        ValueError: if the string has no colon, the boundaries are not
            exactly two ``-``-separated values, or they are not integers.
    """
    chrN, boundaries = rgnString.rsplit(":", 1)
    startPos, stopPos = [int(xx) for xx in boundaries.split('-')]
    return {chrN: [(startPos, stopPos)]}

def calculateRegion(args):
    """Compute tumor/normal read-count windows for one genomic region.

    Intended as a ``multiprocessing.Pool`` worker, hence the single packed
    argument.

    Parameters:
        args: a two-item sequence ``(opts, region)``.  ``opts`` must provide
            ``loglevel``, ``tumorbam``, ``normalbam`` and ``windowsize``;
            ``region`` is ``(chrN, startPos, stopPos)``.

    Returns:
        A list of rows ``[chrN, window_start, window_end, winsize,
        tumor_count, log2(tumor_count / winsize)]``.  A window closes every
        ``winsize`` non-duplicate normal reads; windows in which no tumor
        read was counted are omitted, and a trailing incomplete window is
        discarded.  The list is empty when the tumor file yields no reads
        for the region.
    """
    opts = args[0]
    region = args[1]

    # log_to_stderr() attaches a *new* stderr handler on every call, which
    # would repeat each message once per region handled by this worker.
    # Reuse the multiprocessing logger when it is already configured.
    logger = multiprocessing.get_logger()
    if logger.handlers:
        if opts.loglevel:
            logger.setLevel(opts.loglevel)
    else:
        logger = multiprocessing.log_to_stderr(opts.loglevel)

    results = []
    winsize = opts.windowsize
    chrN = region[0]
    startPos = region[1]
    stopPos = region[2]

    tfile = pysam.Samfile(opts.tumorbam, 'rb')
    try:
        nfile = pysam.Samfile(opts.normalbam, 'rb')
        try:
            logger.info('Working on %s:%d-%d' % (chrN, startPos, stopPos))
            tfileiterator = tfile.fetch(chrN, startPos, stopPos)
            try:
                # next() builtin works on both Python 2.6+ and Python 3.
                tread = next(tfileiterator)
            except StopIteration:
                # No reads in the tumor in this region; the finally clauses
                # below still close both files.
                return results

            n = 0
            for nread in nfile.fetch(chrN, startPos, stopPos):
                if nread.is_duplicate:
                    continue
                if n == 0:
                    # First normal read of a new window marks its start.
                    startloc = nread.pos
                n += 1
                if n == winsize:
                    # Window complete: advance through the tumor reads that
                    # lie before the closing normal read, counting each
                    # newly fetched non-duplicate read.
                    j = 0
                    while tread.pos < nread.pos and tread.rname == nread.rname:
                        try:
                            tread = next(tfileiterator)
                        except StopIteration:
                            # Tumor reads ran out before reaching nread.pos.
                            break
                        if tread.is_duplicate:
                            continue
                        j += 1
                    if j != 0:
                        # j == 0 is skipped because log2(0) is undefined.
                        results.append([chrN, startloc, nread.pos, winsize, j,
                                        math.log(float(j) / winsize, 2)])
                    n = 0
        finally:
            nfile.close()
    finally:
        tfile.close()
    return results

def doNormalComparisonCGH(opts,logger):
    """Run the tumor-vs-normal comparison and write median-centered windows.

    Steps: read the reference names and lengths from the normal BAM header,
    decide which regions to process, run ``calculateRegion`` for every
    region in a process pool, subtract the median log2 ratio from every
    window and write the rows as tab-separated text.

    Parameters:
        opts: parsed options providing ``tumorbam``, ``normalbam``,
            ``windowsize``, ``regions`` (an existing file path, a single
            ``chrN:start-stop`` string, or None for every reference in the
            normal BAM), ``outfile`` (path, or None for stdout) and
            ``processes``.
        logger: logger used for progress messages.

    Output columns: chromosome, window start, window end, window size,
    tumor read count, median-centered log2 ratio.
    """
    logger.info('Starting run')
    # Opened only so that an unreadable tumor BAM fails before any work is
    # dispatched; the workers open their own handles.
    pysam.Samfile(opts.tumorbam,'rb').close()
    logger.info('tumor file: %s' % opts.tumorbam)
    nfile = pysam.Samfile(opts.normalbam,'rb')
    try:
        lengths = nfile.lengths
        refnames = nfile.references
    finally:
        # Only the header is needed here; close before the pool is created.
        nfile.close()
    logger.info('normal file: %s' % opts.normalbam)
    logger.info('windowsize: %d' % opts.windowsize)
    logger.info('found %d chromosomes' % (len(lengths)))

    if opts.regions is not None:
        if os.path.exists(opts.regions):
            regions = readRegions(opts.regions)
        else:
            regions = parseregion(opts.regions)
    else:
        regions = dict(zip(refnames,[[(0,xx)] for xx in lengths]))

    # One task per interval.  Every interval listed for a chromosome is
    # processed, not just the first one.
    regionlist = [(opts,(chrN,start,stop))
                  for chrN,locs in regions.items()
                  for (start,stop) in locs]

    outfile = sys.stdout
    if opts.outfile is not None:
        outfile = open(opts.outfile,'w')
    try:
        pool = multiprocessing.Pool(processes=opts.processes)
        # Progress marker emitted just before work is handed to the pool.
        logger.info('here')
        try:
            # map_async().get(timeout) rather than map(), as in the original
            # code; the large timeout keeps the wait interruptible.
            tmpresults = pool.map_async(calculateRegion, regionlist).get(99999999)
        finally:
            # All results are already collected on success, so terminating
            # the idle workers is safe; on error it stops them promptly.
            pool.terminate()
            pool.join()
        results = list(itertools.chain(*tmpresults))
        if not results:
            # The median of nothing is undefined; emit no rows.
            logger.info('No windows produced; nothing to write')
            return
        # median-center data
        med = median([row[5] for row in results])
        logger.info('Writing output')
        for row in results:
            row[5] = row[5]-med
            outfile.write("%s\t%d\t%d\t%d\t%d\t%f\n" % tuple(row))
    finally:
        # Close only a file opened here; sys.stdout belongs to the process.
        if outfile is sys.stdout:
            outfile.flush()
        else:
            outfile.close()

def median(vect):
    """Return the median of a non-empty sequence of numbers.

    For an odd number of values the middle value is returned unchanged; for
    an even number the mean of the two middle values is returned as a float.
    The input is not modified.

    Parameters:
        vect: iterable of mutually comparable numbers.

    Raises:
        ValueError: if ``vect`` is empty.
    """
    # Sort a copy so the caller's list keeps its order.
    ordered = sorted(vect)
    count = len(ordered)
    if count == 0:
        raise ValueError("median() requires at least one value")
    # Floor division keeps the index an int on Python 3 as well.
    mid = count // 2
    if count % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2.0
            

def main():
    """Command-line entry point: parse arguments and run the comparison.

    Positional arguments are the normal BAM followed by the tumor BAM.  The
    ``-f``/``-F`` flag values are converted from their textual form (hex,
    decimal or octal, as accepted by ``int(value, 0)``) to integers unless
    they are the default ``'0'``.

    NOTE(review): the converted ``filter``/``required`` values are stored on
    ``opts`` but are not read by the functions visible in this file --
    confirm whether they are meant to be applied.

    Exits the interpreter via ``sys.exit()`` once the comparison finishes.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-w','--windowsize',dest='windowsize',type=int,default=1000,
                        help='The number of reads captured from the normal sample for calculation of copy number')
    parser.add_argument('-o','--outfile',dest='outfile',
                        help='Output filename, default <stdout>')
    parser.add_argument('-l','--loglevel',dest='loglevel',type=int,default=10,
                        help='Logging Level, 1-15 with 1 being minimal logging and 15 being everything [10]')
    parser.add_argument('-f','--filter',default='0',
                        help='Like samtools, filter out all reads that are included by this flag value, 0 for unset [0]; hex (0x...), decimal, and octal (e.g., 0777) are accepted')
    parser.add_argument('-F','--required',default='0',
                        help='Like samtools, include only reads that are included by this flag value, 0 for unset [0]; hex (0x...), decimal, and octal (e.g., 0777) are accepted')
    parser.add_argument('normalbam',
                        help='The name of the bamfile for the normal comparison')
    parser.add_argument('tumorbam',
                        help='The name of the tumor sample bamfile')
    parser.add_argument('-r','--regions',dest='regions',
                        help='regions to which analysis should be restricted, either a bed file name or a single region in format chrN:XXX-YYY')
    parser.add_argument('-t','--threads',dest='processes',default=1,type=int,
                        help='parallelize over regions (or chromosomes)')

    opts = parser.parse_args()
    logger = multiprocessing.log_to_stderr(opts.loglevel)
    # Compare by value: identity ("is not") on strings only works by the
    # accident of interning.
    if opts.filter != '0':
        opts.filter = int(opts.filter,0)
    if opts.required != '0':
        opts.required = int(opts.required,0)
    if opts.normalbam is not None:
        doNormalComparisonCGH(opts,logger)
        # sys.exit() instead of the site-provided exit(), which is absent
        # when the interpreter runs without the site module.
        sys.exit()

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__=='__main__':
    main()
