#!/usr/bin/env python
import argparse
import collections
import logging
import math
import os
import sys

import pysam
import ngcgh

class Counter:
    """Callable tally: each invocation raises ``mCounts`` by one."""

    # Class-level starting value; the first call shadows it with a
    # per-instance count, so separate instances tally independently.
    mCounts = 0

    def __call__(self, alignment):
        """Register one more call; ``alignment`` is accepted but not inspected."""
        self.mCounts = self.mCounts + 1

def readRegions(regionsFileName):
    """Load genomic intervals from a tab-separated (BED-style) file.

    Each data line must hold at least three tab-separated fields:
    chromosome name, integer start and integer stop.  Any further columns
    are ignored.  Blank lines, ``#`` comment lines and ``track`` /
    ``browser`` header lines are skipped.

    Args:
        regionsFileName: path of the file to read.

    Returns:
        A ``collections.defaultdict(list)`` mapping each chromosome name to
        its ``(start, stop)`` tuples, in file order.

    Raises:
        IOError/OSError: if the file cannot be opened.
        IndexError: if a data line has fewer than three fields.
        ValueError: if a start or stop field is not an integer.
    """
    regions = collections.defaultdict(list)
    # ``with`` guarantees the handle is released even if a line fails to parse.
    with open(regionsFileName, 'r') as regionsfile:
        for line in regionsfile:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            wrds = line.split("\t")
            # Header lines are space-separated, so test only the first token.
            if wrds[0].split(None, 1)[0] in ('track', 'browser'):
                continue
            regions[wrds[0]].append((int(wrds[1]), int(wrds[2])))
    return regions

def parseregion(rgnString):
    """Turn a ``chrN:start-stop`` string into a one-entry region mapping.

    Args:
        rgnString: region text with exactly one ``:`` separating the
            chromosome name from a ``start-stop`` pair of integers.

    Returns:
        ``{chromosome: [(start, stop)]}`` with both bounds as ints.

    Raises:
        ValueError: if the string does not split into exactly one name and
            two integer bounds.
    """
    chromosome, span = rgnString.split(":")
    lower, upper = map(int, span.split("-"))
    return {chromosome: [(lower, upper)]}

def _collectWindowRatios(tfile, nfile, regions, winsize):
    """Return one [chrom, start, end, winsize, tumorCount, log2ratio] row per window."""
    results = []
    for chrN in regions:
        for (startPos, stopPos) in regions[chrN]:
            tfileiterator = tfile.fetch(chrN, startPos, stopPos)
            # next() with a default avoids an uncaught StopIteration when the
            # tumor file yields no reads (or runs out of reads) for the region.
            tread = next(tfileiterator, None)
            if tread is None:
                # Without tumor reads every window would count zero and be
                # dropped, so there is nothing to gain from scanning the normal.
                continue
            n = 0
            startloc = None
            for nread in nfile.fetch(chrN, startPos, stopPos):
                if nread.is_duplicate:
                    continue
                if n == 0:
                    startloc = nread.pos
                n += 1
                if n != winsize:
                    continue
                # The window is full: count the non-duplicate tumor reads that
                # start before the window's last normal read.  Each read is
                # judged (position and duplicate flag) before advancing, so
                # the read being tested is the read being counted.
                j = 0
                while (tread is not None and tread.pos < nread.pos
                       and tread.rname == nread.rname):
                    if not tread.is_duplicate:
                        j += 1
                    tread = next(tfileiterator, None)
                if j != 0:
                    results.append([chrN, startloc, nread.pos, winsize, j,
                                    math.log(float(j) / winsize, 2)])
                n = 0
    return results


def doNormalComparisonCGH(opts):
    """Compute windowed tumor/normal log2 read-count ratios and write them.

    The normal BAM is walked in windows of ``opts.windowsize`` non-duplicate
    reads.  For each window the non-duplicate tumor reads positioned before
    the window's last normal read are counted and
    ``log2(tumorCount / windowsize)`` is recorded; windows with no tumor
    reads are dropped.  The ratios are then median-centered and written as
    tab-separated rows: chromosome, window start, window end, window size,
    tumor read count, centered log2 ratio.  Nothing is written when no
    window qualifies.

    Args:
        opts: parsed options providing ``tumorbam``, ``normalbam``,
            ``windowsize``, ``regions`` (an existing file name, a
            ``chrN:start-stop`` string, or None to use every reference of
            the normal BAM) and ``outfile`` (None means stdout).
    """
    tfile = pysam.Samfile(opts.tumorbam, 'rb')
    nfile = None
    outfile = sys.stdout
    try:
        nfile = pysam.Samfile(opts.normalbam, 'rb')

        if opts.regions is None:
            regions = dict(zip(nfile.references,
                               [[(0, xx)] for xx in nfile.lengths]))
        elif os.path.exists(opts.regions):
            regions = readRegions(opts.regions)
        else:
            regions = parseregion(opts.regions)

        # Opened before the scan so an unwritable path fails fast.
        if opts.outfile is not None:
            outfile = open(opts.outfile, 'w')

        results = _collectWindowRatios(tfile, nfile, regions, opts.windowsize)
        if not results:
            # A median of nothing is undefined; leave the output empty.
            return

        # median-center data
        med = median([row[5] for row in results])
        for row in results:
            row[5] = row[5] - med
            outfile.write("%s\t%d\t%d\t%d\t%d\t%f\n" % tuple(row))
    finally:
        # Close only what this function opened; never close stdout.
        if outfile is not sys.stdout:
            outfile.close()
        if nfile is not None:
            nfile.close()
        tfile.close()

def median(vect):
    """Return the median of the numbers in ``vect``.

    For an odd number of values this is the middle value of the sorted
    data; for an even number it is the mean of the two middle values
    (always a float).  The input sequence is not modified.

    Args:
        vect: a non-empty iterable of numbers.

    Returns:
        The median value.

    Raises:
        ValueError: if ``vect`` is empty.
    """
    ordered = sorted(vect)  # sort a copy so the caller's data keeps its order
    count = len(ordered)
    if count == 0:
        raise ValueError("median() requires at least one value")
    mid = count // 2  # floor division: a valid integer index on Python 2 and 3
    if count % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2.0
            

def main():
    """Parse the command line, run the tumor-vs-normal comparison, then exit.

    Positional arguments are the normal BAM followed by the tumor BAM.
    Options select the window size, output file, log level and an optional
    region restriction.  The process exits via ``sys.exit()`` once the
    comparison has finished.
    """
    logging.basicConfig(level=10)
    logger = logging.getLogger('CGH')
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-w','--windowsize',dest='windowsize',type=int,default=1000,
                        help='The number of reads captured from the normal sample for calculation of copy number')
    parser.add_argument('-o','--outfile',dest='outfile',
                        help='Output filename, default <stdout>')
    # An explicit default makes the generated "(default: ...)" help text
    # agree with the "[10]" advertised in the description.
    parser.add_argument('-l','--loglevel',dest='loglevel',type=int,default=10,
                        help='Logging Level, 1-15 with 1 being minimal logging and 15 being everything [10]')
    parser.add_argument('normalbam',
                        help='The name of the bamfile for the normal comparison')
    parser.add_argument('tumorbam',
                        help='The name of the tumor sample bamfile')
    parser.add_argument('-r','--regions',dest='regions',
                        help='regions to which analysis should be restricted, either a bed file name or a single region in format chrN:XXX-YYY')
    opts = parser.parse_args()
    logger.setLevel(opts.loglevel)
    # normalbam is a required positional, so argparse has already rejected
    # the command line if it was missing; no None check is needed.
    doNormalComparisonCGH(opts)
    # sys.exit() rather than the interactive-only exit() helper from ``site``.
    sys.exit()
    
# Run only when executed as a script, so importing this module does not
# parse sys.argv and exit.
if __name__ == "__main__":
    main()
