#!/usr/bin/env python2.6

'''
Created on Jan 27, 2014

@author: Ying Jin
@contact: yjin@cshl.edu
@author: Oliver Tam
@contact: tam@cshl.edu
@status: 
@version: 1.0.0
'''
# python module
import sys
import logging
import re
import copy
import math
import operator
import argparse
import optparse, itertools, warnings, traceback, os.path
import HTSeq
import subprocess
from time import time
from TEToolkit.IO.ReadInputs import read_opts2, read_features
from TEToolkit.TEindex import *
from TEToolkit.EMAlgorithm import *

### Define parameters for program ###
def prepare_parser():
    """Build and return the argparse parser for the TEtranscripts command line."""
    description = "Identifying differential transcription of gene and transposable elements." 

    example = "Example: TEtranscripts --format BAM --mode uniq -t RNAseq1.bam RNAseq2.bam -c CtlRNAseq1.bam CtlRNAseq.bam  "

    parser = argparse.ArgumentParser(description=description, epilog=example)

    # Required inputs: the two sample groups and the two annotation files
    parser.add_argument('-t', '--treatment', dest='tfiles', metavar='treatment sample', nargs='+', required=True,
                        help='Sample files in group 1 (e.g. treatment/mutant)')
    parser.add_argument('-c', '--control', dest='cfiles', metavar='control sample', nargs='+', required=True,
                        help='Sample files in group 2 (e.g. control/wildtype)')
    parser.add_argument('--GTF', dest='gtffile', metavar='genic-GTF-file', type=str, required=True,
                        help='GTF file for gene annotations')
    parser.add_argument('--TE', dest='tefile', metavar='TE-GTF-file', type=str, required=True,
                        help='GTF file for transposable element annotations')

    # Input interpretation
    parser.add_argument('--format', dest='format', metavar='input file format', type=str, nargs='?', default='BAM', choices=['BAM', 'SAM'],
                        help='Input file format: BAM or SAM. DEFAULT: BAM')
    parser.add_argument('--stranded', dest='stranded', metavar='option', nargs='?', type=str, default="yes", choices=['yes', 'no', 'reverse'],
                        help='Is this a stranded library? (yes, no, or reverse). DEFAULT: yes.')
    parser.add_argument('--mode', dest='te_mode', metavar='TE counting mode', nargs='?', type=str, const="multi", default='uniq', choices=['uniq', 'sameEle', 'multi'],
                        help='How to count TE: uniq (unique mappers only), sameEle (assign to dominant TE), or multi (distribute among all alignments).\
                        DEFAULT: uniq')
    parser.add_argument('--no-sort', dest='nosort', action="store_true",
                        help='Input file is not sorted by chromosome position.')

    # Analysis parameters
    parser.add_argument('--project', dest='prj_name', metavar='name', nargs='?', default='TEtranscripts_out',
                        help='Name of this project. DEFAULT: TEtranscripts_out')
    parser.add_argument('-p', '--padj', dest='pval', metavar='pvalue', nargs='?', type=float, const=0.1, default=0.05,
                        help='FDR cutoff for significance. DEFAULT: 0.05')
    parser.add_argument('-f', '--foldchange', dest='fc', metavar='foldchange', nargs='?', type=float, const=2.0, default=1.0,
                        help='Fold-change ratio (absolute) cutoff for differential expression. DEFAULT: 1')
    parser.add_argument('-n', '--norm', dest='norm', metavar='normalization', nargs='?', default='rpm', choices=['rpm', 'quant'],
                        help='Normalization method : rpm (reads per million mapped), quant (quantile normalization). DEFAULT: rpm')
    parser.add_argument('-i', '--iteration', dest='numItr', metavar='iteration', nargs='?', type=int, default=0,
                        help='number of iteration to run the optimization. DEFAULT: 0')
    parser.add_argument('--verbose', dest='verbose', metavar='verbose', type=int, nargs='?', default=2,
                        help='Set verbose level. 0: only show critical message, 1: show additional warning message, 2: show process information, 3: show debug messages. DEFAULT:2')

    return parser



class UnknownChrom(Exception):
    """Raised when an alignment references a chromosome absent from the annotations."""
    pass

# Reading files containing alignments (BAM or SAM)
def count_reads(samples, format, features, teIdx,counts,stranded, te_mode, nosort, prj_name,numItr):
    """Count gene/TE reads for every alignment file in `samples`.

    Returns (cnt_tbl, rpm_val):
      cnt_tbl -- {filename: {gene-or-TE-element: count}} merged per file
      rpm_val -- per-library mapped-read totals divided by 1e6 (RPM scale factors)
    Exits the program (status 1) on any I/O or processing error.
    """
    cnt_tbl = {}
    rpm_val = []

    # Route warnings through my_showwarning (message-only, to stderr)
    warnings.showwarning = my_showwarning
    try:
       # Try to open files, which causes the program to fail early if they are not there
        for filename in samples :
            if filename != "-":
                open(filename).close()

    except:
        sys.stderr.write("Error: Cannot open %s\n" % (filename))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)

    try:
        # Processing alignment files for annotation assignment
        
        for filename in samples :
            num_reads = 0
            if format == "BAM" :
                (num_reads,te_instance_counts) = count_SAMformat_v2(filename,format,features,teIdx,counts,stranded,te_mode, nosort, prj_name,numItr)
                    
                '''
                if not nosort:
                    try:
                        subprocess.call('rm TEtranscripts_tmp.bam')
                    except:
                        sys.stderr.write("Failed to remove TEtranscripts_tmp.bam")
                '''
            else : # SAM format (never sorted via samtools, hence nosort="False")
                    (num_reads,te_instance_counts) = count_SAMformat_v2(filename,format,features,teIdx,counts,stranded,te_mode,"False", prj_name,numItr)

            # Library size in millions of mapped reads (RPM denominator)
            librpm = float(num_reads) / 1000000
            rpm_val.append(librpm)
            
            # Collapse per-instance TE counts to per-element totals
            te_ele_counts = groupByEle(te_instance_counts,teIdx)
            # Store count results of the file in the combined table.
            # Python 2 idiom: .items() returns lists, concatenated here;
            # TE element entries override gene entries on key collision.
            cnt_tbl[filename] = dict(counts.items() + te_ele_counts.items()) #copy()
                
            # Reset counts for gene/TE features
            counts = dict(zip(counts.keys(),[0]*len(counts)))
            
    except:
        sys.stderr.write("Error: %s\n" % str(sys.exc_info()[1]))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)
       
    return (cnt_tbl, rpm_val)

#group by TE element

def groupByEle(te_counts, teIdx):
    """Collapse per-instance TE counts into per-element totals.

    te_counts -- list of counts indexed by TE instance index
    teIdx     -- TE index providing getEleName(instance_index)
    Returns {element_name: summed count}; exits on an out-of-range index.
    """
    totals = {}
    for idx, cnt in enumerate(te_counts):
        name = teIdx.getEleName(idx)
        if name is None:
            sys.stderr.write("TE out of index boundary!\n")
            sys.exit(1)
        totals[name] = totals.get(name, 0) + cnt

    return totals
 
#TE annotation
def TE_annotation(iv_seq, teIdx):
    """Return the unique TE indices overlapping any interval in iv_seq.

    iv_seq -- iterable of genomic intervals (objects with chrom/start/end)
    teIdx  -- TE index providing findOvpTE(chrom, start, end)
    Order of first occurrence is preserved; duplicates are dropped.
    """
    hits = []
    for iv in iv_seq:
        overlaps = teIdx.findOvpTE(iv.chrom, iv.start, iv.end)
        if overlaps is None:
            continue
        for te in overlaps:
            if te not in hits:
                hits.append(te)

    return hits
        
#read assignment
def read_assignment(pe_mode, multi_reads, features, teIdx, stranded):
    """Map every alignment of one read to overlapping TEs and genes.

    Parameters:
      pe_mode     -- True for paired-end input
      multi_reads -- all alignments (or mate pairs) of a single read
      features    -- HTSeq genomic array of gene feature sets
      teIdx       -- TE index providing findOvpTE() (via TE_annotation)
      stranded    -- "yes", "no" or "reverse"; controls mate strand flipping

    Returns (annot_gene, annot_TE): one entry per alignment that hit at
    least one gene / TE; entries are lists of gene names / TE indices.

    Fixes vs. original: the gene-collection loop called next(iter(fs))
    inside a counter loop, appending the *first* member of the feature
    set len(fs) times instead of each member once (which defeated the
    0.5/0.5 ambiguous-gene split downstream); mixed tab/space
    indentation is also normalized.
    """
    annot_gene = []
    annot_TE = []

    for aln in multi_reads:
        # Build the sequence of reference intervals covered by this alignment
        iv_seq = []
        if not pe_mode:  # single-end read
            iv_seq.append(aln.iv)
        else:  # paired-end: mate 2 is on the opposite strand of the fragment
            if aln[0] is not None and aln[0].aligned:
                if stranded != "reverse":
                    iv_seq = (co.ref_iv for co in aln[0].cigar if co.type == "M" and co.size > 0)
                else:
                    iv_seq = (invert_strand(co.ref_iv) for co in aln[0].cigar if co.type == "M" and co.size > 0)
            if aln[1] is not None and aln[1].aligned:
                if stranded != "reverse":
                    iv_seq = itertools.chain(iv_seq, (invert_strand(co.ref_iv) for co in aln[1].cigar if co.type == "M" and co.size > 0))
                else:
                    iv_seq = itertools.chain(iv_seq, (co.ref_iv for co in aln[1].cigar if co.type == "M" and co.size > 0))

        try:
            iv_seq = list(iv_seq)

            # TE annotation for this alignment
            TEs = TE_annotation(iv_seq, teIdx)
            if len(TEs) > 0:
                annot_TE.append(TEs)

            # Gene annotation: intersect the feature sets seen along the
            # covered intervals (fs carries over across intervals, as before)
            genes = []
            fs = None
            for iv in iv_seq:
                for iv2, fs2 in features[iv].steps():
                    if len(fs2) > 0:
                        if fs is None:
                            fs = fs2.copy()
                        else:
                            fs = fs.intersection(fs2)
                if fs is not None and len(fs) > 0:
                    # fixed: collect each member of fs once
                    for annotation in fs:
                        genes.append(str(annotation))
            if len(genes) > 0:
                annot_gene.append(genes)
        except:
            sys.stderr.write("Error occurred during read assignments\n")
            raise

    return (annot_gene, annot_TE)
				

def readInAlignment(filename, format, nosort,prj_name):
    """Open a BAM/SAM file and peek at its first alignment record.

    Returns (alignments, pe_mode, first_read):
      alignments -- an HTSeq.SAM_Reader whose first record has already
                    been consumed by the caller-visible `first_read`
      pe_mode    -- True when the first record is paired-end
      first_read -- that first alignment record
    Re-raises any error after logging the filename.
    """
    try:
        if format == "BAM" :
            if nosort  :
            
                # Input already name-sorted: stream `samtools view` output directly
                samtools_in = subprocess.Popen(["samtools", "view", filename], stdout=subprocess.PIPE)
                alignments = HTSeq.SAM_Reader(samtools_in.stdout)
                first_read = iter(alignments).next()  # Python 2 iterator protocol
                pe_mode = first_read.paired_end
     #           alignments = HTSeq.SAM_Reader(filename)
                
            else :
                # Name-sort first, then view the sorted stream.
                # NOTE(review): this is legacy samtools (<1.0) `sort -n -o`
                # usage (output to stdout, last argument = tmp-file prefix);
                # modern samtools interprets -o as the output file — confirm
                # the required samtools version.
                psort = subprocess.Popen(["samtools", "sort", "-n", "-o", filename, "%s_%s_tmp" % (filename, prj_name)], stdout=subprocess.PIPE)
                samtools_in = subprocess.Popen(["samtools", "view", "-"], stdin=psort.stdout, stdout=subprocess.PIPE)
                psort.stdout.close()  # Allow p to receive a SIGPIPE if psort exits.
#            samtools_in = subprocess.Popen(["samtools","view",filename], stdout=subprocess.PIPE)            
                alignments = HTSeq.SAM_Reader(samtools_in.stdout)
                first_read = iter(alignments).next()
                pe_mode = first_read.paired_end
            #alignments = HTSeq.SAM_Reader(samtools_in.stdout)
        else :
                # SAM input: HTSeq reads the text file directly (no sorting here)
                alignments = HTSeq.SAM_Reader(filename)
                first_read = iter(alignments).next()
                pe_mode = first_read.paired_end

        
    except:
        sys.stderr.write("Error occured when reading first line of sample file %s.\n" % filename)
        raise
    return (alignments,pe_mode,first_read)

#get read name of current alignment
def getCurReadName(a, pe_mode):
    """Return (read_name, read_length) for the current alignment (or mate pair).

    a       -- a single alignment (single-end) or a (mate1, mate2) tuple
    pe_mode -- True when `a` is a mate pair
    Unaligned reads yield ("", 0).

    Fix vs. original: the single-end unaligned branch returned a bare ""
    instead of a 2-tuple, which raised ValueError in every caller that
    unpacks (name, length).
    """
    # Single-end reads
    if not pe_mode:
        if not a.aligned:
            return ("", 0)
        return (a.read.name, len(a.read.seq))
    # Paired-end reads: neither mate aligned -> empty marker
    if ((a[0] is None) or not a[0].aligned) and ((a[1] is None) or not a[1].aligned):
        return ("", 0)
    # Prefer mate 1's name/length when available
    if a[0] is not None and a[0].aligned:
        return (a[0].read.name, len(a[0].read.seq))
    # Guard above guarantees mate 2 is aligned here
    return (a[1].read.name, len(a[1].read.seq))
                       
def count_SAMformat_v2(filename, format, features, teIdx, gene_counts, stranded, te_mode, nosort, prj_name, numItr):
    """Count reads over genes and TEs for one alignment file.

    Streams alignments grouped by read name (input must be name-sorted),
    assigns each read's alignment bundle to gene and/or TE annotations,
    and optionally refines multi-mapper TE weights with EM.

    Returns (num_reads, te_counts): the number of reads seen and a
    per-TE-instance count list of length teIdx.numTEs().

    Fixes vs. original: `new_te_multi_counts` was unbound (NameError)
    when numItr > 0 but no multi-mapping TE reads were collected; the
    `map()` results are materialized with list() so the later sum() does
    not exhaust them on Python 3; average read length guards against an
    empty sample.
    """
    (alignments, pe_mode, first_read) = readInAlignment(filename, format, nosort, prj_name)

    sys.stderr.write("Parsing sample file %s\n" % filename)

    empty = 0                # reads with no gene/TE annotation
    nonunique = 0            # reads with more than one alignment
    i = 0                    # alignments processed (for progress logging)
    pre_read_name = ''
    num_reads = 0
    multi_reads = []         # per-read TE index lists, input to EM
    alignments_per_read = [] # bundle of alignments for the current read
    leftOver_gene = []       # deferred ambiguous gene assignments
    leftOver_te = []         # deferred ambiguous TE assignments
    avgReadLength = 0
    tmp = []                 # sample of read lengths (first 10000) for EM
    te_counts = [0.0] * teIdx.numTEs()
    te_multi_counts = [0.0] * len(te_counts)

    try:
        if pe_mode:
            # The first record was consumed by readInAlignment; pair it with
            # the next one by hand, then switch to paired iteration.
            alignments_pe_file = alignments
            second_read = next(iter(alignments))
            algn_first_pair = (first_read, second_read)
            (curr_read_name, curr_read_len) = getCurReadName(algn_first_pair, pe_mode)
            pre_read_name = curr_read_name
            alignments_per_read.append(algn_first_pair)
            num_reads += 1
            alignments = HTSeq.pair_SAM_alignments(alignments)

        for a in alignments:
            i += 1
            (curr_read_name, curr_read_len) = getCurReadName(a, pe_mode)

            if len(tmp) < 10000:
                tmp.append(curr_read_len)

            # first read of the file
            if pre_read_name == '':
                pre_read_name = curr_read_name
            # another alignment of the same (multi-mapping) read
            if pre_read_name == curr_read_name and curr_read_name != '':
                alignments_per_read.append(a)
            # new read: process the bundle collected for the previous read
            else:
                num_reads += 1

                if len(alignments_per_read) > 1:
                    nonunique += 1
                    if te_mode == 'uniq':
                        # unique mode drops multi-mapping reads entirely
                        empty += 1
                        alignments_per_read = []
                        pre_read_name = curr_read_name
                        alignments_per_read.append(a)
                        continue

                try:
                    # Assign the bundle; prefer whichever annotation type
                    # (gene vs TE) covers more of its alignments.
                    (annot_gene, annot_TE) = read_assignment(pe_mode, alignments_per_read, features, teIdx, stranded)
                    if len(annot_TE) < len(annot_gene):
                        no_annot1 = parse_annotations_gene(annot_gene, gene_counts, leftOver_gene)
                        if no_annot1:
                            no_annot2 = parse_annotations_TE(multi_reads, annot_TE, te_counts, te_multi_counts, teIdx, te_mode, leftOver_te)
                            if no_annot2:
                                empty += 1
                    else:
                        no_annot1 = parse_annotations_TE(multi_reads, annot_TE, te_counts, te_multi_counts, teIdx, te_mode, leftOver_te)
                        if no_annot1:
                            no_annot2 = parse_annotations_gene(annot_gene, gene_counts, leftOver_gene)
                            if no_annot2:
                                empty += 1

                except:
                    sys.stderr.write("Error occurred when processing annotations of %s in file %s.\n" % (pre_read_name, filename))
                    raise

                if i % 1000000 == 0:
                    sys.stderr.write("%d %s processed.\n" % (i, "alignments "))

                alignments_per_read = []
                pre_read_name = curr_read_name
                alignments_per_read.append(a)

        # the last read's bundle has no successor to flush it; do it here
        try:
            if pre_read_name != '':
                num_reads += 1
                # NOTE(review): this also bumps `nonunique` for a unique
                # (single-alignment) last read — kept as in the original,
                # it only affects the logged statistic.
                if (len(alignments_per_read) > 1 and 'uniq' not in te_mode) or len(alignments_per_read) == 1:
                    nonunique += 1
                    (annot_gene, annot_TE) = read_assignment(pe_mode, alignments_per_read, features, teIdx, stranded)
                    if len(annot_TE) < len(annot_gene):
                        no_annot1 = parse_annotations_gene(annot_gene, gene_counts, leftOver_gene)
                        if no_annot1:
                            no_annot2 = parse_annotations_TE(multi_reads, annot_TE, te_counts, te_multi_counts, teIdx, te_mode, leftOver_te)
                            if no_annot2:
                                empty += 1
                    else:
                        no_annot1 = parse_annotations_TE(multi_reads, annot_TE, te_counts, te_multi_counts, teIdx, te_mode, leftOver_te)
                        if no_annot1:
                            no_annot2 = parse_annotations_gene(annot_gene, gene_counts, leftOver_gene)
                            if no_annot2:
                                empty += 1

            # Resolve deferrals caused by overlapping genes or TEs
            if len(leftOver_gene) > 0 and te_mode == 'uniq':
                resolve_annotation_ambiguity(gene_counts, leftOver_gene)
            if len(leftOver_te) > 0 and te_mode == 'uniq':
                resolve_annotation_ambiguity(te_counts, leftOver_te)
            if te_mode == 'uniq':
                empty += len(leftOver_gene) + len(leftOver_te)

            ss = sum(te_counts)
            sys.stderr.write("uniq te counts = %s \n" % (str(ss)))
            te_tmp_counts = [0] * len(te_counts)
            if numItr > 0:
                # fallback keeps the raw multi-mapper weights if EM cannot run
                # (fixes a NameError when multi_reads is empty)
                new_te_multi_counts = te_multi_counts
                try:
                    # iterative optimization on TE reads
                    sys.stderr.write(".......start iterative optimization ..........\n")
                    avgReadLength = int(sum(tmp) / len(tmp)) if tmp else 0
                    sys.stderr.write("...average read length %d \n" % (avgReadLength))
                    if len(multi_reads) > 0:
                        new_te_multi_counts = EMestimate(teIdx, multi_reads, te_tmp_counts, te_multi_counts, numItr, avgReadLength)

                except:
                    sys.stderr.write("Error in optimization\n")
                    raise
                te_counts = list(map(operator.add, te_counts, new_te_multi_counts))
            else:
                te_counts = list(map(operator.add, te_counts, te_multi_counts))

        except:
            sys.stderr.write("Error occurred when assigning read gene/TE annotations in file %s.\n" % (filename))
            raise

        ss = sum(te_counts)
        sys.stderr.write("te_counts total %s\n" % (ss))
        sys.stderr.write("\nIn library %s:\n" % (filename))
        sys.stderr.write("Total mapped reads = %s\n" % (str(num_reads)))
        sys.stderr.write("Total non-uniquely mapped reads = %s\n" % (str(int(nonunique))))
        sys.stderr.write("Total unannotated reads = %s\n\n" % (str(int(empty))))

    except:
        if not pe_mode:
            sys.stderr.write("Error occured in %s.\n" % alignments.get_line_number_string())
        else:
            sys.stderr.write("Error occured in %s.\n" % alignments_pe_file.get_line_number_string())
        raise

    return (num_reads, te_counts)
				

# Invert strand of alignment
def invert_strand(iv):
    """Return a copy of the genomic interval `iv` with its strand flipped.

    Raises ValueError for any strand other than "+" or "-".
    """
    iv2 = iv.copy()
    if iv2.strand == "+":
        iv2.strand = "-"
    elif iv2.strand == "-":
        iv2.strand = "+"
    else:
        # call form replaces the Python-2-only `raise ValueError, "..."` syntax
        raise ValueError("Illegal strand")
    return iv2

    
def parse_annotations_gene(annot_gene, gene_counts, leftOver_gene):
    """Tally the genic assignment of one read.

    annot_gene    -- per-alignment lists of gene names for this read
    gene_counts   -- mutable {gene: count}, updated in place
    leftOver_gene -- ambiguous multi-location reads deferred for later

    Returns True when the read has no gene annotation, False otherwise.
    """
    if len(annot_gene) > 1:
        # several aligned locations hit genes: defer with unit weight
        leftOver_gene.append((annot_gene, 1.0))
        return False
    if len(annot_gene) == 0:
        return True

    genes = annot_gene[0]
    if len(genes) == 1 or genes[0] == genes[1]:
        # unambiguous single gene gets the whole read
        gene_counts[genes[0]] += 1
    else:
        # two distinct overlapping genes split the read evenly
        gene_counts[genes[0]] += 0.5
        gene_counts[genes[1]] += 0.5
    return False

# Assign ambiguous genic reads mapped to multiple locations
def resolve_annotation_ambiguity(counts, leftOvers):
    """Distribute deferred ambiguous read weights across their candidates.

    counts    -- mutable counts (dict keyed by name, or list keyed by index),
                 updated in place
    leftOvers -- list of (annlist, weight) where annlist holds per-alignment
                 candidate lists collected earlier

    Each leftover's weight is shared among its distinct candidates in
    proportion to their current counts; with no prior evidence it is split
    evenly (direct assignment, as in the original).

    Fix vs. original: mixed tab/space indentation placed the distribution
    step *inside* the candidate-collection loop (a tab counts as 8 columns
    in Python 2), so weight was handed out once per annotation group
    instead of once per leftover read.
    """
    for (annlist, w) in leftOvers:
        candidates = {}
        total = 0.0

        # collect each distinct candidate and its current count
        for ann in annlist:
            for a in ann:
                if a not in candidates:
                    candidates[a] = counts[a]
                    total += counts[a]

        if total > 0.0:
            # weight proportional to existing evidence
            for a in candidates:
                counts[a] += w * candidates[a] / total
        else:
            # no prior counts: split the weight evenly
            for a in candidates:
                counts[a] = w / len(candidates)

               
def parse_annotations_TE(multi_reads,annot_TE,uniq_counts,multi_counts,te_features, te_mode,leftOver_list):
    if len(annot_TE) == 0 :
        return True
    
    if len(annot_TE) == 1 :
        if len(annot_TE[0]) == 1 :
            uniq_counts[annot_TE[0][0]] += 1    
        else :
            leftOver_list.append((annot_TE,1.0))
        return False
    
    if te_mode == 'multi':
            multi_algn = []
            for i in range(len(annot_TE)):
                for te in annot_TE[i] :
                    multi_counts[te] += 1.0 / (len(annot_TE) * len(annot_TE[i]))
                    multi_algn.append(te)
            
            multi_reads.append(multi_algn)
    		
    if te_mode == 'uniq':
        if len(annot_TE[0]) == 1 :
            uniq_counts[annot_TE[0][0]] += 1    
    	else :
    		leftOver_list.append((annot_TE,1.0))
        
    if 'sameEle' in te_mode:    # Assign to top two/one
        elements = {}
        for TEs in annot_TE:
            for t in TEs :
                    t_ele = te_features.getEleName(t)
                    if t_ele is not None :
                        if t_ele in elements :
                            elements[t_ele].append(t)
                        else :                            
                            elements[t_ele] = [t] 
                    else :
                        sys.stderr.write("this TE does not exist. \n")
                        raise 
      
        (top_fam1, top_fam2, cnt1, cnt2) = majority(elements)
        multi_algn = []
        if cnt2 == 0 or cnt1 / cnt2 > 2.5 :
            TEs = elements[top_fam1]
            w = 1.0/len(TEs)
            for t in TEs :
                multi_counts[t] += w
            multi_reads.append(TEs)
                
        else :
            multi_algn = []
            TEs1 = elements[top_fam1]
            TEs2 = elements[top_fam2]
            for t in TEs1 :
                multi_counts[t] += (1.0 * (cnt1/(cnt1 + cnt2)))/len(TEs1)
                multi_algn.append(t)
            for t in TEs2 :
                multi_counts[t] += (1.0 * (cnt2/(cnt1 + cnt2)))/len(TEs2)
                multi_algn.append(t)
            multi_reads.append(multi_algn)
            
    return False

def majority(te_counts):
    """Return the two TE elements with the most supporting alignments.

    te_counts -- {element_name: list of supporting TE instance indices}
    Returns (top_fam1, top_fam2, top_cnt1, top_cnt2); missing ranks are
    ("", 0).

    Fix vs. original: when a new maximum was found, the previous leader
    was simply discarded instead of being demoted to second place, so
    top_fam2 could stay empty or hold a smaller element.
    """
    top_fam1 = ""
    top_fam2 = ""
    top_cnt1 = 0
    top_cnt2 = 0

    for feature in te_counts:
        feat_cnt = len(te_counts[feature])
        if feat_cnt > top_cnt1:
            # demote the current leader before installing the new one
            top_fam2, top_cnt2 = top_fam1, top_cnt1
            top_fam1, top_cnt1 = feature, feat_cnt
        elif feat_cnt > top_cnt2:
            top_fam2, top_cnt2 = feature, feat_cnt

    return (top_fam1, top_fam2, top_cnt1, top_cnt2)

def my_showwarning(message, category, filename, lineno=None, line=None):
    """warnings.showwarning replacement: emit only the message to stderr."""
    text = "Warning: %s\n" % message
    sys.stderr.write(text)


def output_res(res, ann, smps, prj):
    """Render the result heatmap to "<prj>.png".

    NOTE(review): plotHeatmap is not defined in this file — presumably
    provided by the wildcard TEToolkit imports; confirm.
    """
    png_name = prj + ".png"
    plotHeatmap(res, ann, smps, png_name)
    return

def output_count_tbl(tfiles, cfiles, t_tbl, c_tbl, fname):
    """Write the combined gene/TE count table to `fname`.

    tfiles/cfiles -- sample names in output column order
    t_tbl/c_tbl   -- {sample: {gene/TE: count}} from count_reads()
    fname         -- output path; exits(1) if it cannot be created

    Columns are tab-separated: "gene/TE", then one "<sample>.T" column per
    treatment file, then one "<sample>.C" column per control file; counts
    are truncated to int.

    Fix vs. original: `dict.has_key()` (removed in Python 3) replaced by
    the `in` operator; the duplicated treatment/control loops are factored
    into one helper.
    """
    try:
        f = open(fname, 'w')
    except IOError:
        error("Cannot create report file %s !\n" % (fname))
        sys.exit(1)
    else:
        cnt_tbl = {}
        header = "gene/TE"

        def _collect(samples, tbl, suffix):
            # Append each sample's counts as a new column; returns header part
            hdr = ""
            for smp in samples:
                cnts = tbl[smp]
                hdr += "\t" + smp + suffix
                for gene in sorted(cnts.keys()):
                    if gene in cnt_tbl:
                        cnt_tbl[gene].append(int(cnts[gene]))
                    else:
                        cnt_tbl[gene] = [int(cnts[gene])]
            return hdr

        header += _collect(tfiles, t_tbl, ".T")
        header += _collect(cfiles, c_tbl, ".C")

        # output
        f.write(header + "\n")
        for gene in cnt_tbl.keys():
            vals = cnt_tbl[gene]
            vals_str = gene
            for v in vals:
                vals_str += "\t" + str(v)
            f.write(vals_str + "\n")

        f.close()

    return
        
def output_norm(sf, name, error):
    """Write normalization factors to "<name>.norm", one tab-separated row
    per sample, values rounded to 2 decimals."""
    fname = name + ".norm"
    try:
        f = open(fname, 'w')
    except IOError:
        error("Cannot create report file %s !\n" % (fname))
        sys.exit(1)
    else:
        for idx, row in enumerate(sf, 1):
            fields = ["treat" + str(idx)]
            for v in row:
                fields.append(str(round(v, 2)))
            f.write("\t".join(fields) + "\n")
        f.close()


def write_R_code(f_cnt_tbl, tfiles, cfiles, prj_name, norm, pval, fc, rpm_val):
    """Assemble the R script that runs the DESeq differential analysis.

    Parameters:
      f_cnt_tbl -- path of the count table written by output_count_tbl()
      tfiles/cfiles -- treatment / control sample lists (define group sizes)
      prj_name  -- prefix for the output tables
      norm      -- 'rpm' or 'quant'
      pval      -- adjusted p-value (FDR) cutoff for the significant table
      fc        -- fold-change cutoff (absolute, linear scale)
      rpm_val   -- per-library RPM size factors (used when norm == 'rpm')

    Returns the R code as one string.

    Fixes vs. original: the quantile branch emitted
    `data.q.norm[,a:b,na.rm=T)` — a missing `]` that is an R syntax
    error — and computed the control column range one column past the
    end of the table.
    """
    # Load counts and define the two comparison groups
    rscript = ''
    rscript += '\n'
    rscript += 'data <- read.table("%s",header=T,row.names=1)\n' % (f_cnt_tbl)
    rscript += 'groups <- factor(c(rep("TGroup",%s),rep("CGroup",%s)))\n' % (len(tfiles), len(cfiles))

    # Quantile normalization to calculate fold change
    if norm == 'quant':
        rscript += 'colnum <- length(data)\n'
        rscript += 'rownum <- length(data[,1])\n'
        rscript += 'ordMatrix <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'ordIdx <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'for (i in 1:colnum){\n'
        rscript += '  a.sort <- sort(data[,i],index.return=T)\n'
        rscript += '  ordMatrix[,i] <- a.sort$x\n'
        rscript += '  ordIdx[,i] <- a.sort$ix\n'
        rscript += '}\n'
        rscript += 'rowAvg <- rowMeans(ordMatrix)\n'
        rscript += 'data.q.norm <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'for (i in 1:colnum){\n'
        rscript += '  data.q.norm[,i] <- rowAvg[order(ordIdx[,i])]\n'
        rscript += '}\n'
        rscript += 'colnames(data.q.norm) <- colnames(data)\n'
        rscript += 'rownames(data.q.norm) <- rownames(data)\n'
        if len(tfiles) > 1:
            rscript += 'sample1Mean <- rowMeans(data.q.norm[,1:%s],na.rm=T)\n' % (len(tfiles))
        else:
            rscript += 'sample1Mean <- data.q.norm[,1]\n'
        group2_start = len(tfiles) + 1
        # last control column is the table's total column count
        # (the original over-ran by one: group2_start + len(cfiles))
        group2_stop = len(tfiles) + len(cfiles)
        if len(cfiles) > 1:
            # fixed: original emitted '[,a:b,na.rm=T)' — the missing ']' broke the R script
            rscript += 'sample2Mean <- rowMeans(data.q.norm[,%s:%s],na.rm=T)\n' % (group2_start, group2_stop)
        else:
            rscript += 'sample2Mean <- data.q.norm[,%s]\n' % (group2_start)
        rscript += 'FoldChange <- (sample2Mean/sample1Mean)\n'
        rscript += 'log2FoldChange <- log2(FoldChange)\n'

    # Normalize by RPM (reads per million mapped)
    else :
        rpm_vec = ','.join(str(x) for x in rpm_val)
        rscript += 'rpm <- c(%s)\n' % (rpm_vec)

    # Performing differential analysis using DESeq
    rscript += 'library(DESeq, quietly=T)\n'
    rscript += 'cds <- newCountDataSet(data,groups)\n'
    if norm == 'rpm':
        rscript += 'cds$sizeFactor = rpm\n'
    else:
        rscript += 'cds <- estimateSizeFactors(cds)\n'
    # single replicate per group requires blind / fit-only dispersion estimation
    if (len(tfiles) == 1 and len(cfiles) == 1):
        rscript += 'cds <- estimateDispersions(cds,method="blind",sharingMode="fit-only",fitType="local")\n'
    else:
        rscript += 'cds <- estimateDispersions(cds)\n'
    rscript += 'res <- nbinomTest(cds,"CGroup","TGroup")\n'

    # Generating output table
    if norm == 'quant':
        rscript += 'res_fc <- cbind(res$id,sample1Mean,sample2Mean,FoldChange,log2FoldChange,res$pval,res$padj)\n'
        rscript += 'colnames(res_fc) = c("id","sample1Mean","sample2Mean","FoldChange","log2FoldChange","pval", "padj")\n'
    else:
        rscript += 'res_fc <- res\n'
    rscript += 'write.table(res_fc, file="%s_gene_TE_analysis.txt", sep="\\t",quote=F,row.names=F)\n' % (prj_name)

    # Generating table of "significant" results (padj < pval, |log2FC| > log2(fc))
    l2fc = math.log(fc, 2)
    if norm == 'quant':
        rscript += 'resSig <- res_fc[(!is.na(res_fc[,7]) & (res_fc[,7] < %f) & (abs(as.numeric(res_fc[,5])) > %f)), ]\n' % (pval, l2fc)
    else:
        rscript += 'resSig <- res_fc[(!is.na(res_fc$padj) & (res_fc$padj < %f) & (abs(res_fc$log2FoldChange)> %f)), ]\n' % (pval, l2fc)
    rscript += 'write.table(resSig, file="%s_sigdiff_gene_TE.txt",sep="\\t", quote=F, row.names=F)\n' % (prj_name)

    return rscript

# Main function of script
def main():
    """Start TEtranscripts: parse options, build the TE index, count reads
    in both sample groups, and run the DESeq differential analysis.

    NOTE(review): in the original file everything below this docstring was
    dedented to module level, so the whole pipeline executed at *import*
    time and main() itself was empty. The body is now indented into main(),
    which `if __name__ == '__main__'` already calls.
    """
    args = read_opts2(prepare_parser())

    info = args.info
    #warn = args.warn
    #debug = args.debug
    error = args.error

    # Output arguments used for program
    info("\n" + args.argtxt + "\n")

    info("Processing GTF files ...\n")
    (features, counts) = read_features(args.gtffile, args.stranded, "exon", "gene_id", args.te_mode)

    # Build the TE index from a position-sorted copy of the TE GTF
    try:
        # `from time import time` at the top of the file makes `time.time()`
        # ambiguous (it may be shadowed by the wildcard TEToolkit imports);
        # use an unambiguous local module import instead.
        import time as _time
        teIdx = TEfeatures()
        cur_time = _time.time()
        te_tmpfile = '.' + str(cur_time) + '.te.gtf'
        subprocess.call(['sort -k 1,1 -k 4,4g ' + args.tefile + ' >' + te_tmpfile], shell=True)
        info("\nBuilding TE index .......\n")
        teIdx.build(te_tmpfile, args.te_mode)
        subprocess.call(['rm -f ' + te_tmpfile], shell=True)
    except:
        sys.stderr.write("Error in building TE index \n")
        sys.exit(1)

    # Read sample files and make the count table
    info("\nReading sample files ...\n")
    (tsamples_tbl, tsamples_rpm) = count_reads(args.tfiles, args.parser, features, teIdx, counts, args.stranded, args.te_mode, args.nosort, args.prj_name, args.numItr)
    (csamples_tbl, csamples_rpm) = count_reads(args.cfiles, args.parser, features, teIdx, counts, args.stranded, args.te_mode, args.nosort, args.prj_name, args.numItr)

    info("Finished processing samples files")
    info("Generating counts table")

    f_cnt_tbl = args.prj_name + ".cntTable"
    output_count_tbl(args.tfiles, args.cfiles, tsamples_tbl, csamples_tbl, f_cnt_tbl)
    rpm_val = tsamples_rpm + csamples_rpm

    # Obtaining R-code for differential analysis
    rscript = write_R_code(f_cnt_tbl, args.tfiles, args.cfiles, args.prj_name, args.norm, args.pval, args.fc, rpm_val)
    f_rscript = args.prj_name + '_DESeq.R'
    rcode = open(f_rscript, 'w')
    rcode.write(rscript)
    rcode.close()

    # Running R-code for differential analysis
    try:
        subprocess.call(['Rscript', f_rscript])
    except:
        error("Error in running differential analysis!\n")
        error("Error: %s\n" % str(sys.exc_info()[1]))
        error("[Exception type: %s, raised in %s:%d]\n" %
              (sys.exc_info()[1].__class__.__name__,
               os.path.basename(traceback.extract_tb(sys.exc_info()[2])[-1][0]),
               traceback.extract_tb(sys.exc_info()[2])[-1][1]))
        sys.exit(1)
    info("Done \n")


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: report the interrupt and exit with success status
        sys.stderr.write("User interrupt !\n")
        sys.exit(0)
