#!/usr/bin/env python
'''
Created on Oct 12, 2011

@author: Ying Jin
@contact: yjin@cshl.edu
@status: 
@version: 
'''
# python module
#import os
import sys
#import logging
#import math
import re
import argparse
#import gzip
import time
from math import floor

#SWDeseq module
from TEToolkit.Constants import *
from TEToolkit.PeakDetect import * 
from TEToolkit.IO.ReadInputs import read_opts,read_chrlen_tbl,read_short_reads,read_te_annotation
from TEToolkit.Normalization import normalize
from TEToolkit.PeakModel import PeakModel,NotEnoughPairsException
import subprocess
#from TEToolkit.IO.WriteOutput import plotHeatmap

# peak-interval bookkeeping, differential-binding detection and the command-line driver

class PeakRegion:
    """Per-chromosome collection of peak intervals built by repeated insert().

    Attributes:
        chroms:    chromosome names in the order they were first inserted.
        peaklist:  dict mapping chromosome name -> list of (start, end) tuples.
        cur_index: position in the current chromosome's list where the next
                   insert() starts looking (the "cursor").
        cur_start, cur_end: bounds of the interval the cursor points at.

    The cursor makes consecutive inserts of position-sorted intervals cheap;
    call reset() before feeding another sorted series for the same chromosome.
    """

    def __init__(self):
        self.cur_start = 0
        self.cur_end = 0
        self.chroms = []
        self.peaklist = {}
        self.cur_index = 0

    def _point_at(self, index, interval):
        # Move the cursor to `index` and remember that interval's bounds.
        self.cur_index = index
        self.cur_start = interval[0]
        self.cur_end = interval[1]

    def reset(self,chr):
        """Move the cursor back to the first interval stored for `chr`.

        Raises KeyError / IndexError when nothing was inserted for `chr` yet.
        """
        self._point_at(0, self.peaklist[chr][0])

    def insert(self,chr,start,end):
        """Merge the interval (start, end) into the list kept for `chr`.

        The interval is compared with the one under the cursor and is either
        added as a new entry, absorbed, merged (when the overlap is at least
        half of either interval, using integer halves) or trimmed so that only
        its non-overlapping part is stored. When the cursor is not at the
        right place the method moves it and calls itself again.

        NOTE(review): some boundary-equality cases (for example
        start == peaks[i][0] with end < peaks[i][1]) match none of the
        branches below, so the call returns without changing anything. This
        mirrors the original behaviour; confirm it is intended.
        """
        if chr not in self.peaklist:
            # first interval seen on this chromosome
            self.chroms.append(chr)
            self.peaklist[chr] = [(start,end)]
            self._point_at(0, (start,end))
            return

        peaks = self.peaklist[chr]
        i = self.cur_index

        # Floor division keeps the integer-halving the Python 2 code performed
        # (assumes integer coordinates).
        half1 = (end - start) // 2
        half2 = (peaks[i][1] - peaks[i][0]) // 2

        if end < peaks[i][0] : #case 1: entirely left of the cursor interval
            if i == 0 or peaks[i-1][1] < start :
                # fits in the gap before the cursor interval
                peaks.insert(i,(start,end))
                self._point_at(i, (start,end))
                return
            # touches the previous interval: step the cursor left and retry
            self._point_at(i-1, peaks[i-1])
            self.insert(chr, start, end)
            return

        if end < peaks[i][1] and start > peaks[i][0]: #case 2: contained in
            self._point_at(i, peaks[i])
            return

        if end < peaks[i][1] and start < peaks[i][0] : #case 3: overlaps the left edge
            ovplen = end - peaks[i][0]
            if ovplen >= half1 or ovplen >= half2 :
                # take union
                peaks[i] = (start,peaks[i][1])
                self._point_at(i, peaks[i])
                return
            # keep only the part left of the cursor interval as a new entry
            trimmed = (start,peaks[i][0]-1)
            peaks.insert(i,trimmed)
            self._point_at(i, trimmed)
            return

        if end > peaks[i][1] and start < peaks[i][0]: #case 4: covers the cursor interval
            if i == len(peaks) - 1 or end < peaks[i+1][0]:
                # covers only this one: replace it
                peaks[i] = (start,end)
                self._point_at(i, (start,end))
                return
            if end > peaks[i+1][0] :
                # also reaches the next interval: drop this one and retry
                # against the next (which now sits at index i)
                peaks.pop(i)
                self._point_at(i, peaks[i])
                self.insert(chr,start,end)
                return
            # end == peaks[i+1][0] falls through to case 5, as in the original

        if start < peaks[i][1] and end > peaks[i][1] : #case 5: overlaps the right edge
            ovplen = peaks[i][1] - start
            if ovplen >= half1 or ovplen >= half2:
                # take union; the merged interval starts where the stored one did
                start = peaks[i][0]
                if i == len(peaks) - 1 :
                    peaks[i] = (start,end)
                    self._point_at(i, (start,end))
                    return
                # the union may reach further intervals: remove the stored one
                # and re-insert the union against its successor
                peaks.pop(i)
                self._point_at(i, peaks[i])
                self.insert(chr,start,end)
                return
            # small overlap: keep only the part right of the cursor interval
            start = peaks[i][1] + 1
            if i < len(peaks) - 1 :
                self._point_at(i+1, peaks[i+1])
                self.insert(chr,start,end)
                return
            peaks.append((start,end))
            self._point_at(i+1, (start,end))
            return

        if start > peaks[i][1] : #case 6: entirely right of the cursor interval
            if i == len(peaks) - 1 :
                peaks.append((start,end))
                self._point_at(i+1, (start,end))
                return
            if end < peaks[i+1][0]:
                # fits in the gap before the next interval
                peaks.insert(i+1,(start,end))
                self._point_at(i+1, (start,end))
                return
            # reaches the next interval: step the cursor right and retry
            self._point_at(i+1, peaks[i+1])
            self.insert(chr, start, end)
            return
    
def _count_reads_in_peaks(sample, chrom, peakregions, scale):
    """Return {"chrom:start:end": scaled count} for every region in peakregions.

    Each value returned by sample.get_counts_by_chr(chrom, 0) is decoded as the
    original code did: the integer part is the position and the fractional
    part is the count; a whole number means position - 1 with a count of 1.
    The values are walked once, in step with peakregions, so both are expected
    to be sorted by position (presumably guaranteed by the callers - confirm).
    """
    cnts = sample.get_counts_by_chr(chrom,0)
    total = len(cnts)
    peakcounts = {}
    j = 0
    for peak in peakregions:
        w = 0
        entry = chrom+":"+str(peak[0])+":"+str(peak[1])
        while j < total :
            pos = floor(cnts[j])
            cnt = cnts[j] - pos
            if pos == cnts[j] :
                pos -= 1
                cnt = 1

            if pos < peak[0] :
                # before this region: skip it for good
                j += 1
                continue
            elif pos <= peak[1] :
                w += cnt
                j += 1
            else :
                # belongs to a later region: leave it for the next iteration
                break
        peakcounts[entry] = w * scale
    return peakcounts


def detectDFBS(tsamples,csamples,norm,opt ):
    """Count normalized reads of every sample in the merged peak regions.

    For each chromosome that has peaks in the first treatment or first control
    sample, the peaks of all samples are merged with PeakRegion and every
    sample's reads are counted inside the merged regions.

    Args:
        tsamples: treatment peak-detection results; each must provide
                  .peaks[chrom], .get_counts_by_chr(chrom, 0) and .clean(chrom).
        csamples: control peak-detection results with the same interface.
        norm:     scaling factors; norm[0][i] is applied to treatment sample i
                  and norm[1][i] to control sample i.
        opt:      parsed options (not used here, kept for the callers).

    Returns:
        (tpeak_tbl, cpeak_tbl): dicts keyed by "T<i>" / "C<i>" whose values map
        "chrom:start:end" to the scaled read count, covering all chromosomes.
    """
    tpeak_tbl = {}
    cpeak_tbl = {}

    # chromosomes with at least one peak in the first sample of either condition
    chroms = {}
    for first in (tsamples[0], csamples[0]):
        for chrom in first.peaks.keys() :
            if len(first.peaks[chrom]) > 0 and chrom not in chroms :
                chroms[chrom] = 1

    for chrom in chroms :
        # a fresh region list per chromosome, as the original re-created it
        cmb_peaks = PeakRegion()
        for group in (tsamples, csamples):
            for sample in group:
                peaks = sample.peaks[chrom]
                for peak in peaks:
                    cmb_peaks.insert(chrom,peak[0],peak[1])
                if len(peaks) > 0:
                    # rewind the cursor before the next sample's sorted series
                    cmb_peaks.reset(chrom)

        peakregions = cmb_peaks.peaklist[chrom]
        for prefix, group, factors, table in (("T", tsamples, norm[0], tpeak_tbl),
                                              ("C", csamples, norm[1], cpeak_tbl)):
            for i, sample in enumerate(group):
                counts = _count_reads_in_peaks(sample, chrom, peakregions, factors[i])
                # Accumulate across chromosomes; the original re-assigned the
                # per-sample dict here and kept only the last chromosome.
                table.setdefault(prefix + str(i), {}).update(counts)
                sample.clean(chrom)

    return (tpeak_tbl,cpeak_tbl)


def main():
    """Command-line driver: call peaks per sample and optionally compare conditions.

    Steps, in order:
      1. parse the options with read_opts(prepare_parser()) and echo them;
      2. load treatment/control samples and their inputs, exiting with
         status 1 when a checked library is empty;
      3. compute scaling factors with normalize() and write them with
         output_norm();
      4. run PeakDetect on every sample against the first input of its
         condition and write <project>T<i>_peaks.xls / <project>C<i>_peaks.xls
         (plus .wig files when --wig is given);
      5. with --diff, write the merged count table and run the generated
         DESeq script through Rscript.
    """
    args=read_opts(prepare_parser())

    info = args.info
    error = args.error
    # The auto-model branch needs a warning logger; fall back to info when
    # read_opts did not attach one.
    warn = getattr(args, "warn", info)
    #0 output arguments
    info("\n"+args.argtxt+"\n")

    if args.TEannotation is not None :
        info("reading TE annotation file ... \n")
        # NOTE(review): the result is not used further in this function.
        TEidx = read_te_annotation(args.TEannotation)

    #read sample files
    info("reading sample files ...\n")
    tsamples = read_short_reads(args.tfiles,args.parser)
    tinputs = read_short_reads(args.tinputs,args.parser)

    csamples = []
    cinputs = []
    if args.cfiles is not None :
        csamples = read_short_reads(args.cfiles,args.parser)
        cinputs = read_short_reads(args.cinputs,args.parser)
        if cinputs[0].libsize() == 0:
            error("Library size of sample %s is 0 !\n " % (args.cinputs[0]))
            sys.exit(1)
        for i in range(len(csamples)) :
            # check the control sample itself (the original tested tsamples[i])
            if csamples[i].libsize() == 0 :
                error("Library size of sample %s is 0 !\n " % (args.cfiles[i]))
                sys.exit(1)

    for i in range(len(tsamples)) :
        if tsamples[i].libsize() == 0 :
            error("Library size of sample %s is 0 !\n " % (args.tfiles[i]))
            sys.exit(1)

    if tinputs[0].libsize() == 0:
        error("Library size of sample %s is 0 !\n " % (args.tinputs[0]))
        sys.exit(1)

    args.tsize = tsamples[0].tsize()
    info("tag size %s \n" %(args.tsize))
    info("normalization ...\n")
    #compute normalization factors
    scal_factors = normalize(args.norm,tsamples,tinputs,csamples,cinputs,args.species[0],args.prj_name)

    output_norm(scal_factors,args.prj_name,error)

    #call binding sites
    info("call binding sites (peaks) .......\n")
    if not args.auto:
        shiftsize = int(args.fragsize/2)
        args.scanwindow = args.fragsize * 2
        info("#2 Use %d as shiftsize, %d as fragment length" % (shiftsize,args.fragsize))

    peaks1 = []
    peaks2 = []

    for i in range(len(tsamples)) :

        if i == 0 and args.auto :
            # estimate the fragment length from the first treatment sample
            try:
                peakmodel = PeakModel(treatment = tsamples[0],
                                      max_pairnum = MAX_PAIRNUM,
                                      opt = args
                                      )
                info("#2 predicted fragment length is %d bps" % peakmodel.d)
                info("#2.2 Generate R script for model : %s" % (args.modelR))
                model2r_script(peakmodel,args.modelR,args.prj_name)
                args.fragsize = peakmodel.d
                args.scanwindow= 2*args.fragsize
                if args.fragsize <= 2*args.tsize:
                    args.fragsize=args.tsize*2
                    args.scanwindow=2*args.fragsize
                    warn("#2 Since the d calculated from paired-peaks are smaller than 2*tag length, it may be influenced by unknown sequencing problem. MACS will use %d as shiftsize, %d as fragment length" % (int(args.fragsize/2),args.fragsize))

            except NotEnoughPairsException:
                # not enough paired peaks to build a model: use fixed sizes
                args.fragsize = 200
                args.scanwindow = args.fragsize * 2

        # scaling factor of this sample and of the input, which normalize()
        # stores right after the samples of the condition
        sf = [scal_factors[0][i], scal_factors[0][len(tsamples)]]
        peakdetect1 = PeakDetect(treat = tsamples[i],
                            control = tinputs[0],
                            opt = args,
                            scal_factors=sf
                            )
        peakdetect1.call_peaks()
        with open(args.prj_name+"T"+str(i)+"_peaks.xls",'w') as opfhd:
            opfhd.write(peakdetect1.toxls())
        peaks1.append(peakdetect1)
        if args.wig :
            wigfile = args.prj_name+"T"+str(i)+"_peaks.wig"
            peakdetect1.towig(wigfile,100)

    # drop the reference to the treatment inputs; they are not needed any more
    tinputs = []

    for i in range(len(csamples)) :

        sf = [scal_factors[1][i], scal_factors[1][len(csamples)]]
        peakdetect2 = PeakDetect(treat = csamples[i],
                            control = cinputs[0],
                            opt = args,
                            scal_factors = sf
                            )
        peakdetect2.call_peaks()
        with open(args.prj_name+"C"+str(i)+"_peaks.xls",'w') as opfhd:
            opfhd.write(peakdetect2.toxls())
        peaks2.append( peakdetect2)
        if args.wig :
            wigfile = args.prj_name+"C"+str(i)+"_peaks.wig"
            peakdetect2.towig(wigfile,100)

    cinputs = []
    #detect differential binding sites
    if args.diffAna :
        if args.cfiles is None :
            error("Cannot do differential analysis without control samples !\n")
            # NOTE(review): this removes every hidden .bed and every *.r file
            # in the working directory, not only files created by this run.
            subprocess.call(["rm -f .*.bed  *.r"],shell=True)
            info("Done \n")
            return

        info("detect differentially enriched binding sites ...\n")
        tmpfile = args.prj_name+"_peaks_counts.rc"
        resfile = args.prj_name+".diff"
        (tpeaks,cpeaks) = detectDFBS(tsamples=peaks1,csamples=peaks2,norm=scal_factors,opt = args)
        output_peaks_tbl(tpeaks,cpeaks,tmpfile,args)

        # DESeq dispersion settings depend on how many replicates are available
        cnd = 'per-condition'
        sharingMode = 'maximum'
        fitType = 'local'
        n_tsample = len(tsamples)
        n_csample = len(csamples)
        if (n_tsample == 1 and n_csample > 1) or (n_csample == 1 and n_tsample > 1):
            cnd = 'pooled'
        elif n_tsample == 1 and n_csample == 1 :
            cnd = 'blind'
            sharingMode = 'fit-only'
            fitType = 'local'

        rscript = "diffPeaks.r"
        try:
            diff2r_script(rscript)
            status = subprocess.call(["Rscript", rscript, cnd, fitType,sharingMode, tmpfile, resfile ])
        except Exception as exc:
            # e.g. the script could not be written or Rscript is not installed
            error("differential analysis error! (%s)\n" % (exc))
            sys.exit(1)
        if status != 0 :
            error("Rscript exited with status %d, %s may be missing or incomplete.\n" % (status,resfile))

    # NOTE(review): removes every *.r file in the working directory.
    subprocess.call(["rm   *.r"],shell=True)
    info("Done \n")
        
    

def output_res(res,ann,smps,prj):
    """Placeholder for the heat-map report; currently writes nothing.

    The image name is still derived from `prj`, but the plotting call that
    would consume it is disabled.
    """
    heatmap_name = prj+".png"

    # plotHeatmap(res,ann,smps,heatmap_name)
    return None

def output_peaks_tbl(t_tbl, c_tbl, fname,opt):
    """Write the per-region read counts of all samples as a tab-separated table.

    Args:
        t_tbl: dict sample name -> {region: count} for the treatment samples.
        c_tbl: the same for the control samples.
        fname: path of the table to create.
        opt:   options object; opt.minread is the smallest maximum count a
               region needs to be written, and opt.error (when present) is
               used to report a file that cannot be created.

    The header is "genomeLoc" followed by "<sample>.T" / "<sample>.C" columns;
    each following line is a region and its counts in the same column order.
    Exits with status 1 when the file cannot be created.
    """
    # Build the header and the region -> [counts...] rows first.
    header = ["genomeLoc"]
    cnt_tbl = dict()
    for table, suffix in ((t_tbl, ".T"), (c_tbl, ".C")):
        for smp in table.keys():
            cnts = table[smp]
            header.append(smp + suffix)
            for gene in sorted(cnts.keys()):
                cnt_tbl.setdefault(gene, []).append(cnts[gene])

    try:
        f = open(fname, 'w')
    except IOError:
        msg = "Cannot create report file %s !\n" % (fname)
        # `error` is not defined in this scope; use the logger carried by the
        # options when there is one, otherwise write to stderr.
        report = getattr(opt, "error", None)
        if report is None:
            sys.stderr.write(msg)
        else:
            report(msg)
        sys.exit(1)

    with f:
        f.write("\t".join(header) + "\n")
        for gene in cnt_tbl.keys() :
            vals = cnt_tbl[gene]
            # a region is reported only if some sample reaches opt.minread
            if max([0] + vals) >= opt.minread :
                f.write("\t".join([gene] + [str(v) for v in vals]) + "\n")

    return
    
def output_norm(sf,name,error):
    """Write the scaling factors to "<name>.norm".

    The file gets two tab-separated lines: "treat" followed by the values of
    sf[0] and "control" followed by the values of sf[1], each rounded to two
    decimals. When the file cannot be created, `error` is called with a
    message and the program exits with status 1.
    """
    fname = name + ".norm"
    try:
        f = open(fname,'w')
    except IOError:
        error("Cannot create report file %s !\n" % (fname))
        sys.exit(1)
    else:
        for label, row in (("treat", 0), ("control", 1)):
            factors = sf[row]
            fields = [label]
            for k in range(len(factors)) :
                fields.append(str(round(factors[k],2)))
            f.write("\t".join(fields)+"\n")

        f.close()


def diff2r_script(filename):
    """Write the R script that runs the DESeq comparison to `filename`.

    The generated script takes five positional arguments: dispersion method,
    fit type, sharing mode, input count table and output file.
    """
    script = (
        "#!/usr/bin/env Rscript\n",
        "args = commandArgs(TRUE)\n",
        "input_file = args[4]\n",
        "output_file = args[5]\n",
        "method = args[1]\n",
        "fitType = args[2]\n",
        "sharingMode = args[3]\n",
        "data = read.table(input_file, header=TRUE, row.names=1)\n",
        # the condition is the trailing T or C of each column name
        "conds = gsub(\"^\\\\S+([TC])$\", \"\\\\1\", colnames(data),perl=TRUE)\n",
        # validate the conditions
        "condition_types = unique(conds)\n",
        "condition_types = condition_types[order(condition_types)]\n",
        "write(paste(condition_types,collapse=\",\"),stderr()); \n",
        "if (length(condition_types) != 2) {\n",
        "stop(paste(\"Input error: expecting exactly two conditions types (e.g. T,T,T,C,C,C) in file\",\n",
        "    input_file, \"( but got:\", paste(condition_types,collapse=\",\"), \")\" ));\n",
        "}\n",
        "suppressPackageStartupMessages(library(DESeq,warn.conflicts = FALSE,quietly=TRUE))\n",
        # build the data set from rounded counts and fit it
        "d<-round(data,digits=0)\n",
        "cds = newCountDataSet(d,conds)\n",
        "cds = estimateSizeFactors(cds)\n",
        "cds = estimateDispersions(cds,method,fitType=fitType,sharingMode=sharingMode)\n",
        # binomial test between the two conditions
        "output = nbinomTest(cds,condition_types[1], condition_types[2])\n",
        # rename "baseMeanA"/"baseMeanB" after the actual conditions so the
        # output is easier to read
        "if (colnames(output)[1] != \"id\" || colnames(output)[3] != \"baseMeanA\" || colnames(output)[4] != \"baseMeanB\") { \n",
        "stop(paste(\"Internal error: output columns from DESeq's nbinomText changed. New columns are:\",\n",
        "paste(colnames(output),collapse=\",\"))); }\n",
        "colnames(output)[1] = \"genomeLoc\"\n",
        "colnames(output)[3] = paste(\"baseMean\",condition_types[1],sep=\"\")\n",
        "colnames(output)[4] = paste(\"baseMean\",condition_types[2],sep=\"\")\n",
        "write.table(output, output_file, quote=FALSE,sep=\"\\t\",row.names=FALSE,col.names=TRUE) \n",
    )

    rfhd = open(filename,'w')
    for line in script:
        rfhd.write(line)
    rfhd.close()



def model2r_script(model,filename,name):
    """Write an R script that plots the peak model to "<name>_model.pdf".

    The plus, minus and shifted tag profiles of `model` are converted to
    percentages of their own totals and embedded in the script together with
    the model's d value.
    """
    rfhd = open(filename,"w")
    plus = model.plus_line
    minus = model.minus_line
    shifted = model.shifted_line
    frag_len = model.d
    total_plus = sum(plus)
    total_minus = sum(minus)
    total_shifted = sum(shifted)

    # percentages, position by position over the length of the plus profile
    pct_plus = []
    pct_minus = []
    pct_shifted = []
    for idx in range(len(plus)):
        pct_plus.append(float(plus[idx])*100/total_plus)
        pct_minus.append(float(minus[idx])*100/total_minus)
        pct_shifted.append(float(shifted[idx])*100/total_shifted)

    rfhd.write("# R script for Peak Model\n")
    rfhd.write("#  -- generated by TEpeaks refered to MACS\n")

    template = """p <- c(%s)
m <- c(%s)
s <- c(%s)
x <- seq.int((length(p)-1)/2*-1,(length(p)-1)/2)
pdf("%s_model.pdf",height=6,width=6)
plot(x,p,type="l",col=c("red"),main="Peak Model",xlab="Distance to the middle",ylab="Percentage")
lines(x,m,col=c("blue"))
lines(x,s,col=c("black"))
abline(v=median(x[p==max(p)]),lty=2,col=c("red"))
abline(v=median(x[m==max(m)]),lty=2,col=c("blue"))
abline(v=median(x[s==max(s)]),lty=2,col=c("black"))
legend("topleft",c("forward tags","reverse tags","shifted tags"),lty=c(1,1,1),col=c("red","blue","black"))
legend("right",c("d=%d"),bty="n")
dev.off()
"""
    as_csv = lambda values: ','.join(map(str,values))
    rfhd.write(template % (as_csv(pct_plus),as_csv(pct_minus),as_csv(pct_shifted),name,frag_len))
    rfhd.close()
                
def prepare_parser():
    """Build the argparse parser with every option this pipeline accepts."""
    desc = "Identifying differential transcription binding/histone modification sites."
    exmp = "Example: TEpeaks -t ChIP1.bed ChIP2.bed -c Ctl1.bed Ctl2.bed --tinput input1.bed --cinput input2.bed -s hg "
    parser = argparse.ArgumentParser(description=desc, epilog=exmp)

    # (flags, keyword arguments) in the order the options appear in --help
    option_table = (
        (('-t', '--treatment'),
         dict(metavar='treatment sample', dest='tfiles', nargs='+', required=True,
              help='treatment samples (BAM or BED files)')),
        (('-c', '--control'),
         dict(metavar='control sample', dest='cfiles', nargs='+',
              help='control samples (BAM or BED files)')),
        (('--tinput',),
         dict(metavar='treatment input', dest='tinputs', nargs=1, required=True,
              help='input files of treatment')),
        (('--cinput',),
         dict(metavar='control input', dest='cinputs', nargs=1,
              help='input files of control')),
        (('-s', '--species'),
         dict(metavar='species', dest='species', nargs=1, required=True,
              help='species : hg (Human hg19),  mm (Mouse mm9), dm (fruit fly dm3)')),
        (('-w', '--windowsize'),
         dict(metavar='window size', dest='wsize', nargs='?', type=int,
              default=WIN_SIZE, help='size of a sliding window. DEFAULT: 1000')),
        (('-r', '--step'),
         dict(metavar='step', dest='step', nargs='?', type=int,
              default=STEP, help='size of a step. DEFAULT: 100')),
        (('-a', '--auto'),
         dict(dest='auto', action="store_true",
              help='auto detect shiftsize. DEFAULT=False', default=False)),
        (('-d', '--diff'),
         dict(dest='diffAna', action='store_true',
              help="require differential analysis", default=False)),
        (('-g', '--gap'),
         dict(metavar='gap', dest='gap', nargs='?', type=int,
              default=GAP, help='max allowed gap. DEFAULT: 1000')),
        (('--minread',),
         dict(metavar='min support', dest='minread', nargs='?', type=int,
              default=CUTOFF, help='minimal reads of a window [0,20]. DEFAULT: 5')),
        (('--project',),
         dict(metavar='project name', dest='prj_name', nargs='?', default=PRJ_NAME,
              help='name of this project. DEFAULT: TEpeak_out')),
        (('-f', '--fragsize'),
         dict(metavar='fragsize', dest='fragsize', nargs='?', type=int,
              default=FRAG_SIZE, help='size of a fragment. DEFAULT: 200')),
        (('-p', '--padj'),
         dict(metavar='pvalue', dest='pval', nargs='?', type=float, default=P_VAL,
              help='p-value cutoff. DEFAULT: 1e-5')),
        (('-n', '--norm'),
         dict(metavar='normalization', dest='norm', nargs='?', default=NORM_METHOD,
              help='normalization method : sd (library size), bc (bin correlation). DEFAULT: sd')),
        (('--verbose',),
         dict(metavar='verbose', dest='verbose', type=int, nargs='?', default=2,
              help='Set verbose level. 0: only show critical message, 1: show additional warning message, 2: show process information, 3: show debug messages. DEFAULT:2')),
        (('--format',),
         dict(metavar='format', dest='format', type=str, nargs='?', default="BAM",
              help='input file format BAM or BED. DEFAULT: BAM')),
        (('--lmfold',),
         dict(metavar='lmfold', dest='lmfold', nargs='?', type=int, default=10)),
        (('--umfold',),
         dict(metavar='umfold', dest='umfold', nargs='?', type=int, default=30)),
        (('--TE',),
         dict(dest='TEannotation', nargs='?',
              help='TE annotation file (coordinates)')),
        (('--wig',),
         dict(dest='wig', action="store_true", default=False,
              help="output peaks into wiggle format.")),
        (('--TEmode',),
         dict(dest='TEmode', nargs='?', default='multi', type=str,
              help="TE multi-mapper mode: 'multi' or 'sameFamily'")),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    return parser


if __name__ == '__main__':
    # Script entry point: run the pipeline and report the wall-clock time in
    # minutes on stderr; Ctrl-C ends the run with exit status 0.
    try:
        began = time.time()
        main()
        finished = time.time()
        minutes = round((finished-began)/60,2)
        sys.stderr.write("Elapsed time was " + str(minutes)+" minutes. \n")
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt !\n")
        sys.exit(0)
