#!/usr/bin/env python2.6

'''
Created on Jan 27, 2014

@author: Ying Jin
@contact: yjin@cshl.edu
@author: Oliver Tam
@contact: tam@cshl.edu
@status: 
@version: 1.0.0
'''
# python module
import sys
import logging
import re
import copy
import math
import operator
import argparse
import optparse, itertools, warnings, traceback, os.path

import subprocess
from time import time
from TEToolkit.IO.ReadInputs import read_opts2, read_features, SAM_Reader
from TEToolkit.TEindex import *
from TEToolkit.EMAlgorithm import *
from TEToolkit.IntervalTree import *

### Define parameters for program ###
def prepare_parser():
    """Construct and return the argparse.ArgumentParser for TEtranscripts."""
    summary = "Identifying differential transcription of gene and transposable elements."
    usage_example = "Example: TEtranscripts --format BAM --mode uniq -t RNAseq1.bam RNAseq2.bam -c CtlRNAseq1.bam CtlRNAseq.bam  "

    parser = argparse.ArgumentParser(prog='TEtranscripts', description=summary, epilog=usage_example)

    # Required inputs: the two sample groups and both annotation files.
    parser.add_argument('-t', '--treatment', metavar='treatment sample', dest='tfiles', nargs='+', required=True,
                        help='Sample files in group 1 (e.g. treatment/mutant)')
    parser.add_argument('-c', '--control', metavar='control sample', dest='cfiles', nargs='+', required=True,
                        help='Sample files in group 2 (e.g. control/wildtype)')
    parser.add_argument('--GTF', metavar='genic-GTF-file', dest='gtffile', type=str, required=True,
                        help='GTF file for gene annotations')
    parser.add_argument('--TE', metavar='TE-GTF-file', dest='tefile', type=str, required=True,
                        help='GTF file for transposable element annotations')
    # Input-handling options.
    parser.add_argument('--format', metavar='input file format', dest='format', type=str, nargs='?', default='BAM', choices=['BAM','SAM'],
                        help='Input file format: BAM or SAM. DEFAULT: BAM')
    parser.add_argument('--stranded', metavar='option', dest='stranded', nargs='?', type=str, default="yes", choices=['yes','no','reverse'],
                        help='Is this a stranded library? (yes, no, or reverse). DEFAULT: yes.')
    parser.add_argument('--mode', metavar='TE counting mode', dest='te_mode', nargs='?', type=str, const="multi", default='uniq', choices=['uniq','sameEle','multi'],
                        help='How to count TE: uniq (unique mappers only), sameEle (assign to dominant TE), or multi (distribute among all alignments).\
                        DEFAULT: uniq')
    parser.add_argument('--project', metavar='name', dest='prj_name', nargs='?', default='TEtranscripts_out',
                        help='Name of this project. DEFAULT: TEtranscripts_out')
    # Statistical thresholds and normalization.
    parser.add_argument('-p', '--padj', metavar='pvalue', dest='pval', nargs='?', type=float, const=0.1, default=0.05,
                        help='FDR cutoff for significance. DEFAULT: 0.05')
    parser.add_argument('-f', '--foldchange', metavar='foldchange', dest='fc', nargs='?', type=float, const=2.0, default=1.0,
                        help='Fold-change ratio (absolute) cutoff for differential expression. DEFAULT: 1')
    parser.add_argument('--minread', metavar='min_read', dest='min_read', nargs='?', type=int, default=1,
                        help='read count cutoff. genes/TEs with reads less than the cutoff will not be considered.')
    parser.add_argument('-n', '--norm', metavar='normalization', dest='norm', nargs='?', default='DESeq_default', choices=['DESeq_default','TC','quant'],
                        help='Normalization method : DESeq_default (DEseq default normalization method), TC (total annotated counts), quant (quantile normalization). DEFAULT: DESeq_default')
    parser.add_argument('--sortByPos', dest='sortByPos', action="store_true",
                        help='Alignment files are sorted by chromosome position.')
    # Optimization and verbosity.
    parser.add_argument('-i', '--iteration', metavar='iteration', dest='numItr', nargs='?', type=int, default=0,
                        help='number of iteration to run the optimization. DEFAULT: 0')
    parser.add_argument('--verbose', metavar='verbose', dest='verbose', type=int, nargs='?', default=2,
                        help='Set verbose level. 0: only show critical message, 1: show additional warning message, 2: show process information, 3: show debug messages. DEFAULT:2')
    parser.add_argument('--version', action='version', version='%(prog)s 1.2.3')

    return parser



class UnknownChrom(Exception):
    """Exception type for a chromosome name missing from the annotations.

    NOTE(review): not raised anywhere in this file's visible code; presumably
    used by the TEToolkit modules imported above -- verify before removing.
    """
    pass

# Reading files containing alignments (BAM or SAM)
def count_reads(samples, format, features, teIdx,gene_counts,stranded, te_mode, sortByPos, prj_name,numItr):
    """Count reads over genes and TEs for every alignment file in `samples`.

    samples : list of SAM/BAM paths ("-" allowed, skipped by the open check).
    format : "BAM" or "SAM".
    features, teIdx : gene interval lookup and TE index passed through to
        count_SAMformat().
    gene_counts : template dict of gene ids; only its keys are used here.
    stranded, te_mode, sortByPos, prj_name, numItr : forwarded to
        count_SAMformat().

    Returns (cnt_tbl, rpm_val) where cnt_tbl maps filename -> merged
    {feature: count} dict and rpm_val lists each file's total mapped-read
    count (as float), in input order.
    """
    cnt_tbl = {}
    rpm_val = []

    warnings.showwarning = my_showwarning
    try:
       # Try to open files, which causes the program to fail early if they are not there
        for filename in samples :
            if filename != "-":
                open(filename).close()

    except:
        sys.stderr.write("Error: Cannot open %s\n" % (filename))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)

    try:
        # Processing alignment files for annotation assignment

        # Fresh zeroed count dict sharing gene_counts' keys.
        counts = dict(zip(gene_counts.keys(),[0]*len(gene_counts)))

        for filename in samples :
            num_reads = 0
            if format == "BAM" :
                (num_reads,te_instance_counts) = count_SAMformat(filename,format,features,teIdx,counts,stranded,te_mode, sortByPos, prj_name,numItr)

            else : # SAM format
                    # NOTE(review): sortByPos is passed as the *string* "False"
                    # (truthy); harmless here because readInAlignment ignores
                    # sortByPos for non-BAM input.
                    (num_reads,te_instance_counts) = count_SAMformat(filename,format,features,teIdx,counts,stranded,te_mode,"False", prj_name,numItr)

           
            librpm = float(num_reads)
            rpm_val.append(librpm)

            # Collapse per-instance TE counts to per-element totals.
            te_ele_counts = groupByEle(te_instance_counts,teIdx)
            # Store count results of the file in the table.
            # (Python 2 only: dict.items() returns lists, so "+" merges them.)
            cnt_tbl[filename] = dict(counts.items() + te_ele_counts.items()) #copy()
                
            # Reset counts for gene/TE features
            counts = dict(zip(counts.keys(),[0]*len(counts)))
            
    except:
        sys.stderr.write("Error: %s\n" % str(sys.exc_info()[1]))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)
       
    return (cnt_tbl, rpm_val)


#group by TE element

def groupByEle(te_counts,teIdx) :
    """Collapse per-instance TE counts into per-element totals.

    te_counts : sequence of counts indexed by TE instance.
    teIdx : TE index providing getEleName(instance_index).
    Returns {element_name: summed count}; exits if an index has no name.
    """
    totals = {}
    for idx, cnt in enumerate(te_counts):
        name = teIdx.getEleName(idx)
        if name is None:
            # Index without a name means te_counts disagrees with the index.
            sys.stderr.write("TE out of index boundary!\n")
            sys.exit(1)
        totals[name] = totals.get(name, 0) + cnt
    return totals
 
#TE annotation
def TE_annotation(iv_seq,teIdx):
    """Return the distinct TE indices overlapped by a read's intervals.

    iv_seq : iterable of (chrom, start, end) reference intervals.
    teIdx : TE index providing findOvpTE(chrom, start, end) -> list or None.
    Returns the overlapping TE indices, de-duplicated, in first-seen order.
    """
    hits = []
    for iv in iv_seq :
        overlaps = teIdx.findOvpTE(iv[0], iv[1], iv[2])
        if overlaps is None:
            continue
        for idx in overlaps:
            if idx not in hits:
                hits.append(idx)
    return hits

# Splits a CIGAR string into alternating length/operation tokens,
# e.g. "10M5I" -> ['10', 'M', '5', 'I', ''].
_re_cigar_codes = re.compile( '([A-Z])' )

def cigar_interval(a):
    """Convert a SAM alignment's CIGAR into reference intervals.

    a : parsed SAM fields; a[2] is the chromosome, a[3] the leftmost
        reference position (int), a[5] the CIGAR string.
    Returns a list of (chrom, start, end) tuples (end inclusive), one per
    aligned segment.

    Bug fix: the old code advanced the reference position for *every*
    non-M operation, but insertions (I), soft clips (S), hard clips (H)
    and padding (P) do not consume reference bases -- only M/=/X/D/N do
    (SAM specification). The old behavior shifted every interval that
    followed an I or S operation.
    """
    sc = re.split('([A-Z])', a[5])
    pos = a[3]
    i_list = []
    # Odd indices hold the operation letter, the preceding token its length.
    for i in range(1, len(sc), 2):
        op = sc[i]
        length = int(sc[i - 1])
        if op in ('M', '=', 'X'):
            # Aligned segment: consumes reference and yields an interval.
            i_list.append((a[2], pos, pos + length - 1))
            pos += length
        elif op in ('D', 'N'):
            # Deletion / skipped region: consumes reference, no interval.
            pos += length
        # I, S, H, P consume no reference positions.
    return i_list
        
#read assignment
def read_assignment(pe_mode,multi_reads,features,teIdx,stranded) :
    """Map each alignment of one read onto overlapping genes and TEs.

    pe_mode : truthy for paired-end input (each item is a (mate1, mate2) pair).
    multi_reads : all alignments belonging to a single read.
    features : per-chromosome lookup with a find_gene(start, end) method.
    teIdx : TE index providing findOvpTE() (used via TE_annotation).
    stranded : accepted but not used in this routine.

    Returns (annot_gene, annot_TE): one entry per alignment that overlapped
    at least one gene (a de-duplicated gene list) or TE (a TE index list).
    """
    annot_gene = []
    annot_TE = []

    for aln in multi_reads :
        iv_seq = []
        inv_iv_seq = []  # unused
        if not pe_mode : #single end read
            iv_seq += cigar_interval(aln)
        else : #paired end read
            # Only mapped mates (FLAG 0x4 clear) contribute intervals.
            if aln[0] is not None and not (aln[0][1] & 0x0004):
                iv_seq += cigar_interval(aln[0])
            if aln[1] is not None and not (aln[1][1] & 0x0004):
                iv_seq += cigar_interval(aln[1])
        try:
            try:
                iv_seq = list(iv_seq)
                # TE overlap first; genes are looked up below regardless.
                TEs = TE_annotation(iv_seq,teIdx)
                if len(TEs) > 0   :
                    annot_TE.append(TEs)
            
            except:
                raise
            genes = []
            fs = None
            for iv in iv_seq:
                try:
                    if iv[0] in features:
                        fs = features[iv[0]].find_gene(iv[1], iv[2])
                    else:
                        # Chromosome absent from the gene annotation.
                        fs = []
                    if fs is not None and len(fs) > 0:
                        genes = genes + fs
                except:
                    raise
            if len(genes) > 0 :
                # De-duplicate gene hits across this alignment's intervals.
                annot_gene.append(list(set(genes)))
        except:
            sys.stderr.write("Error occurred during read assignments\n")
            raise

    return (annot_gene,annot_TE)
				

def readInAlignment(filename, format, sortByPos,prj_name):
    """Open an alignment file and read its first record.

    filename : SAM/BAM path.
    format : "BAM" (read through `samtools view`) or anything else for SAM.
    sortByPos : truthy when the BAM is position-sorted and must be re-sorted
        by read name first; ignored for SAM input.
    prj_name : used to build the temporary sort file name.

    Returns (alignments, pe_mode, first_read, iter_obj) where pe_mode is the
    paired-end bit (0x1) of the first record's FLAG.

    NOTE: uses the Python 2 iterator protocol (iter_obj.next()).
    """
    try:
        if format == "BAM" :
            if not sortByPos  :
            
                samtools_in = subprocess.Popen(["samtools", "view", filename], stdout=subprocess.PIPE)
                alignments = SAM_Reader(samtools_in.stdout)
                iter_obj = iter(alignments)
                first_read = iter_obj.next()
                pe_mode = first_read[1] & 0x0001
            else :
                # NOTE(review): this relies on legacy samtools (0.1.x) syntax
                # where "sort -o <in> <prefix>" writes the sorted records to
                # stdout; with modern samtools, -o names the output file and
                # this pipe would hang -- verify the supported samtools version.
                psort = subprocess.Popen(["samtools", "sort", "-n", "-o", filename, "%s_%s_tmp" % (filename, prj_name)], stdout=subprocess.PIPE)
                samtools_in = subprocess.Popen(["samtools", "view", "-"], stdin=psort.stdout, stdout=subprocess.PIPE)
                psort.stdout.close()  # Allow p to receive a SIGPIPE if psort exits.  
                alignments = SAM_Reader(samtools_in.stdout)
                iter_obj = iter(alignments)
                first_read = iter_obj.next()
                pe_mode = first_read[1] & 0x0001
        else :
                # SAM text file: read directly.
                alignments = SAM_Reader(filename)
                iter_obj = iter(alignments)
                first_read = iter_obj.next()
                pe_mode = first_read[1] & 0x0001

        
    except:
        sys.stderr.write("Error occured when reading first line of sample file %s.\n" % filename)
        raise
    return (alignments,pe_mode,first_read,iter_obj)

#get read name of current alignment
def getCurReadName(a,pe_mode):
    """Return (read_name, read_length) for the current alignment.

    a : a parsed SAM record (single-end) or a (mate1, mate2) pair
        (paired-end); field [1] is the FLAG, field [8] the sequence.
    pe_mode : truthy when alignments are paired-end.

    Returns ("", 0) when no mate of the read is aligned.
    """
    # Single-end reads
    if not pe_mode:
        if a[1] & 0x0004:  # not aligned
            # Bug fix: previously returned a bare "" here, which broke the
            # callers' "(name, length) = getCurReadName(...)" unpacking.
            return ("", 0)
        return (a[0], len(a[8]))
    # Paired-end reads: both mates missing/unmapped -> no usable read.
    if ((a[0] is None) or (a[0][1] & 0x0004)) and ((a[1] is None) or (a[1][1] & 0x0004)):
        return ("", 0)
    # Take the name/length from whichever mate is present and aligned.
    if a[0] is not None and not (a[0][1] & 0x0004):
        return (a[0][0], len(a[0][8]))
    if a[1] is not None and not (a[1][1] & 0x0004):
        return (a[1][0], len(a[1][8]))

def pe_which(flag):
    """Classify which end of a pair a SAM FLAG describes.

    Returns one of the interned strings "first", "second", "unknown"
    (paired but neither mate bit set) or "not_paired_end".
    """
    if not (flag & 0x0001):
        return intern("not_paired_end")
    if flag & 0x0040:
        return intern("first")
    if flag & 0x0080:
        return intern("second")
    return intern("unknown")

def mate_aligned(flag):
    """True when the FLAG marks a paired read whose mate is mapped."""
    # Paired bit (0x1) set and mate-unmapped bit (0x8) clear.
    return bool(flag & 0x0001) and not (flag & 0x0008)

def get_iv_info(a):
    """Return (chrom, pos) for a mapped alignment, (None, None) otherwise."""
    if not (a[1] & 0x0004):
        # Intern the chromosome name so duplicates share one object.
        return (intern(a[2]), a[3])
    return (None, None)

def get_mate_info(a):
    """Return (chrom, pos) recorded for this alignment's mate.

    Returns (None, None) when the mate is absent or unmapped. A mate
    chromosome of "=" is replaced with the read's own chromosome.

    NOTE(review): the position returned is a[3], the read's own POS; if the
    fields follow raw SAM column order the mate's position would be a[7]
    (PNEXT) -- verify against SAM_Reader's field layout.
    """
    if not mate_aligned(a[1]):
        return (None, None)
    chrom = a[6]
    if chrom == "=":
        # "=" in the MRNM column means "same chromosome as the read".
        own_chrom = get_iv_info(a)[0]
        if own_chrom is None:
            warnings.warn( "Malformed SAM line: MRNM == '=' although read is not aligned." )
        else:
            chrom = own_chrom
    return (chrom, a[3])

def pair_SAM_alignments( alignments, bundle=False ):
   """Generator pairing adjacent SAM records of the same read into mates.

   alignments : iterable of parsed paired-end SAM records grouped by read
       name (i.e. name-sorted input); field [0] is the name, [1] the FLAG.
   bundle : when True, yield one list of (first, second) pairs per read
       name instead of yielding each pair individually.

   A mate that cannot be located among the adjacent lines is replaced by
   None in the yielded tuple; the first such case triggers a warning and a
   summary count is warned at the end.

   NOTE: Python 2 only (uses the old "raise Exception, message" syntax).
   """

   mate_missing_count = [0]

   def process_list( almnt_list ):
      # Repeatedly pop one alignment and scan the remainder for its mate.
      while len( almnt_list ) > 0:
         a1 = almnt_list.pop( 0 )
         # Find its mate
         for a2 in almnt_list:
            # Mates must come from opposite ends of the fragment.
            if pe_which(a1[1]) == pe_which(a2[1]):
               continue
            # NOTE(review): this compares the unmapped-flag *bit* (an int)
            # against mate_aligned()'s boolean -- confirm the intended
            # "read unmapped vs. mate-aligned" consistency check.
            if (a1[1] & 0x0004) == mate_aligned(a2[1]) or mate_aligned(a1[1]) == (a2[1] & 0x0004):
               continue
            if not ((not (a1[1] & 0x0004)) and (not (a2[1] & 0x0004))):
               break
            # Accept a2 when each record's own position matches the other's
            # recorded mate position (and likewise for chromosomes).
            if get_iv_info(a1)[0] == get_mate_info(a2)[0] and get_iv_info(a1)[1] == get_mate_info(a2)[1] and \
                  get_iv_info(a2)[0] == get_mate_info(a1)[0] and get_iv_info(a2)[1] == get_mate_info(a1)[1]:
               break
         else:
            # No candidate mate found in the bundle; warn only once.
            if mate_aligned(a1[1]):
               mate_missing_count[0] += 1
               if mate_missing_count[0] == 1:
                  warnings.warn( "Read " + a1[0] + " claims to have an aligned mate " +
                     "which could not be found in an adjacent line." )
            a2 = None
         if a2 is not None:
            almnt_list.remove( a2 )
         # Always yield in (first_mate, second_mate) order.
         if pe_which(a1[1]) == "first":
            yield ( a1, a2 )
         else:
            assert pe_which(a1[1]) == "second"
            yield ( a2, a1 )

   almnt_list = []
   current_name = None
   for almnt in alignments:
      if not (almnt[1] & 0x0001):
         raise ValueError, "'pair_alignments' needs a sequence of paired-end alignments"
      if pe_which(almnt[1]) == "unknown":
         raise ValueError, "Paired-end read found with 'unknown' 'pe_which' status."
      if almnt[0] == current_name:
         almnt_list.append( almnt )
      else:
         # New read name: flush the accumulated bundle.
         if bundle:
            yield list( process_list( almnt_list ) )
         else:
            for p in process_list( almnt_list ):
               yield p
         current_name = almnt[0]
         almnt_list = [ almnt ]
   # Flush the final bundle.
   if bundle:
      yield list( process_list( almnt_list ) )
   else:
      for p in process_list( almnt_list ):
         yield p
   if mate_missing_count[0] > 1:
      warnings.warn( "%d reads with missing mate encountered." % mate_missing_count[0] )
                       
def count_SAMformat(filename, format, features, teIdx,gene_counts, stranded, te_mode, sortByPos, prj_name,numItr):
    """Count one alignment file's reads against gene and TE annotations.

    Records are assumed grouped by read name; consecutive records with the
    same name form one (possibly multi-mapping) read, which is then assigned
    to annotations via read_assignment().

    filename, format, sortByPos, prj_name : see readInAlignment().
    features, teIdx : gene lookup and TE index used during assignment.
    gene_counts : dict of per-gene counts, updated in place.
    stranded : forwarded to read_assignment().
    te_mode : 'uniq', 'multi' or 'sameEle' counting policy.
    numItr : number of EM iterations for multi-mapper optimization (0 = off).

    Returns (num_reads, te_counts): num_reads is the grand total of gene and
    TE counts, te_counts holds one count per TE instance index.
    """
    (alignments,pe_mode,first_read,iter_obj) = readInAlignment(filename, format, sortByPos,prj_name)
    

    sys.stderr.write("Parsing sample file %s\n" % filename)
    num_reads = sum(gene_counts.values())

    sys.stderr.write("Total mapped reads = %s\n" % ( str(num_reads) ) )
    
    empty = 0         # reads that end up with no usable annotation
    nonunique = 0     # reads with more than one alignment
    i = 0
    pre_read_name =''
    num_reads = 0
    multi_reads = []
    alignments_per_read = []
    leftOver_gene = []
    leftOver_te = []
    avgReadLength =  0
    tmp = []          # sample of read lengths (first 10000) for the EM step
    te_counts = [0.0] * teIdx.numTEs()
    te_multi_counts = [0.0]*len(te_counts)

    try:
        if pe_mode:
            alignments_pe_file = alignments

            # Consume the mate of the already-read first record, then switch
            # to iterating mate pairs.
            second_read = iter_obj.next()

            algn_first_pair = (first_read,second_read)
            (curr_read_name,curr_read_len) = getCurReadName(algn_first_pair,pe_mode)       
            pre_read_name = curr_read_name
            alignments_per_read.append(algn_first_pair)                    
            alignments = pair_SAM_alignments(alignments)

        for a in alignments:
            i += 1
            (curr_read_name,curr_read_len) = getCurReadName(a,pe_mode)       
            
            if len(tmp) < 10000 :
                tmp.append(curr_read_len)
                
            #first read
            if pre_read_name == '' :
                    pre_read_name = curr_read_name
            #mult-reads        
            if pre_read_name == curr_read_name and curr_read_name != '':
                    alignments_per_read.append(a)                    
            #new read
            else :
                
                if len(alignments_per_read) > 1 :
                    nonunique += 1          
                    if te_mode == 'uniq' :
                        # Unique-mode: drop multi-mapping reads entirely.
                        empty += 1
                        alignments_per_read = []
                        pre_read_name = curr_read_name
                        alignments_per_read.append(a)
                        continue
                                                                
                #weighting
                try:#read assignment
                    (annot_gene,annot_TE) = read_assignment(pe_mode,alignments_per_read, features, teIdx,stranded)

                    if len(alignments_per_read) > 1 : #multi read, prior to TE
                        no_annot_te = True
                        if len(annot_TE) > 0 :
                            no_annot_te = parse_annotations_TE(multi_reads,annot_TE, te_counts, te_multi_counts, teIdx,te_mode,leftOver_te)
                        if no_annot_te and len(annot_gene) > 0 :
                            no_annot_gene = parse_annotations_gene(annot_gene,gene_counts,leftOver_gene,te_mode)
                            if no_annot_gene :
                               empty += 1

                    else : #uniq read , prior to gene
                        no_annot_gene = True
                        if te_mode == "uniq" and (len(annot_gene) + len(annot_TE) >1) :
                           # Unique-mode: discard reads annotated ambiguously
                           # across genes and TEs.
                           empty += 1
                        else :    
                          if len(annot_gene) > 0 :
                            no_annot_gene = parse_annotations_gene(annot_gene,gene_counts,leftOver_gene,te_mode )
                            if no_annot_gene :
                                no_annot_te = parse_annotations_TE(multi_reads,annot_TE, te_counts, te_multi_counts, teIdx,te_mode,leftOver_te)
                                if no_annot_te :
                                    empty += 1
                          else :
                                no_annot_te = parse_annotations_TE(multi_reads,annot_TE, te_counts, te_multi_counts, teIdx,te_mode,leftOver_te)
                                if no_annot_te :
                                    empty += 1                            
                   
                except:
                    sys.stderr.write("Error occurred when processing annotations of %s in file %s.\n" % (pre_read_name, filename))
                    raise
                
                if i % 1000000 == 0 :
                    sys.stderr.write("%d %s processed.\n" % (i, "alignments " ))

                # Start collecting alignments for the next read.
                alignments_per_read = []
                pre_read_name = curr_read_name
                alignments_per_read.append(a)                                   

        # the last read
        try:
            #ignore the last read
            # NOTE(review): the final read is counted as unannotated rather
            # than assigned; the commented-out block below held the previous
            # assignment logic.
            if pre_read_name != '' :
                #num_reads += 1
                empty += 1 
        #        if (len(alignments_per_read) > 1 and 'uniq' not in te_mode) or len(alignments_per_read) == 1:
        #            nonunique += 1
        #            (annot_gene,annot_TE) = read_assignment(pe_mode,alignments_per_read, features, teIdx,stranded)
        #            if len(annot_TE) < len(annot_gene) :
        #                no_annot1 = parse_annotations_gene(annot_gene,gene_counts,leftOver_gene)
        #                if no_annot1 :
        #                    no_annot2 = parse_annotations_TE(multi_reads,annot_TE, te_counts, te_multi_counts, teIdx,te_mode,leftOver_te)
        #                    if no_annot2 :
        #                        empty += 1
        #            else :
        #                no_annot1 = parse_annotations_TE(multi_reads,annot_TE, te_counts, te_multi_counts, teIdx,te_mode,leftOver_te)
        #                if no_annot1  :
        #                    no_annot2 = parse_annotations_gene(annot_gene,gene_counts,leftOver_gene)
        #                    if no_annot2 :
        #                        empty += 1

            #caused by overlapping genes or TEs, resolve ambiguity
            if len(leftOver_gene) > 0 and te_mode != 'uniq':
                resolve_annotation_ambiguity(gene_counts,leftOver_gene)           
            if len(leftOver_te) > 0 and te_mode != 'uniq':
                resolve_annotation_ambiguity(te_counts,leftOver_te)

            if te_mode == 'uniq' :
                  empty += len(leftOver_gene) + len(leftOver_te)
            ss = sum(te_counts)
            sys.stderr.write("uniq te counts = %s \n" % (str(ss)))
            te_tmp_counts = [0]*len(te_counts)

            if numItr > 0 :
              try :
                ''' iterative optimization on TE reads '''
                sys.stderr.write(".......start iterative optimization ..........\n")
                avgReadLength = int(sum(tmp)/len(tmp))
                sys.stderr.write("...average read length %d \n" % (avgReadLength)) 
                new_te_multi_counts = [0] *len(te_counts)  
                if len(multi_reads) > 0 :
                    new_te_multi_counts = EMestimate(teIdx,multi_reads,te_tmp_counts,te_multi_counts,numItr,avgReadLength)
                    
              except :
                sys.stderr.write("Error in optimization\n")
                raise
              te_counts = map(operator.add,te_counts,new_te_multi_counts)
            else :
              # No EM: just add the fractional multi-mapper counts.
              te_counts = map(operator.add,te_counts,te_multi_counts)
            
            
        except:
            sys.stderr.write("Error occurred when assigning read gene/TE annotations in file %s.\n" % (filename))
            raise
        ss = sum(te_counts)
        num_reads = ss + sum(gene_counts.values())

        sys.stderr.write("te_counts total %s\n" % (ss))
        sys.stderr.write("\nIn library %s:\n" % (filename))
        sys.stderr.write("Total mapped reads = %s\n" % ( str(num_reads) ) )
        sys.stderr.write("Total non-uniquely mapped reads = %s\n" % (str(int(nonunique))))
        sys.stderr.write("Total unannotated reads = %s\n\n" %(str(int(empty))))

    except:
        if not pe_mode:
            sys.stderr.write("Error occured in %s.\n" % alignments.get_line_number_string())
        else:
            sys.stderr.write("Error occured in %s.\n" % alignments_pe_file.get_line_number_string())
        raise

    return (num_reads,te_counts)
				

# Invert strand of alignment
#def invert_strand(iv):
#    iv2 = iv.copy()
#    if iv2.strand == "+":
#           iv2.strand = "-"
#    elif iv2.strand == "-":
#           iv2.strand = "+"
#    else:
#           raise ValueError, "Illegal strand"
#    return iv2

    
def parse_annotations_gene(annot_gene, gene_counts, leftOver_gene, te_mode):
    """Assign one read's gene annotations to the per-gene count table.

    annot_gene : one de-duplicated gene list per annotated alignment.
    gene_counts : per-gene counts, updated in place.
    leftOver_gene : ambiguous reads deferred as (annot_gene, weight) pairs.
    te_mode : counting policy; "uniq" treats ambiguous reads as unassigned.

    Returns True when the read could not be assigned, False otherwise.
    """
    if not annot_gene:
        return True
    if len(annot_gene) > 1:
        # Ambiguous across alignments: defer to the leftover list.
        leftOver_gene.append((annot_gene, 1.0))
        if te_mode == "uniq":
            return True
        return False
    genes = annot_gene[0]
    if len(genes) == 1:
        gene_counts[genes[0]] += 1
    elif genes[0] == genes[1]:
        gene_counts[genes[0]] += 1
    else:
        # Split the read between the first two overlapping genes.
        gene_counts[genes[0]] += 0.5
        gene_counts[genes[1]] += 0.5
    return False

# Assign ambiguous genic reads mapped to multiple locations
def resolve_annotation_ambiguity(counts, leftOvers):
    """Distribute deferred ambiguous reads across their candidate features.

    counts : dict (or list) mapping feature -> current count; updated in place.
    leftOvers : list of (annotation_lists, weight) pairs, where each
        annotation_lists is a list of candidate-feature lists for one read.

    Each deferred read's weight is shared among its candidate features in
    proportion to their current counts; when every candidate currently has
    zero count, the weight is split evenly instead.

    Bug fix: the original mixed tabs and spaces so that (under Python 2's
    tab expansion) the redistribution ran once per annotation list instead
    of once per deferred read, mutating counts mid-collection; modern
    Python rejects the old indentation outright (TabError).
    """
    for (annlist, w) in leftOvers:
        readslist = {}
        total = 0.0
        # Collect every candidate feature and the mass currently assigned to it.
        for ann in annlist:
            for a in ann:
                if a not in readslist:
                    readslist[a] = counts[a]
                    total += counts[a]
        if total > 0.0:
            # Share proportionally to existing evidence.
            for a in readslist:
                counts[a] += w * readslist[a] / total
        else:
            # No prior evidence: split the weight evenly.
            for a in readslist:
                counts[a] = w / len(readslist)

               
def parse_annotations_TE(multi_reads, annot_TE, uniq_counts, multi_counts, te_features, te_mode, leftOver_list):
    """Assign one read's TE annotations to the TE count tables.

    multi_reads : list collecting the TE-index lists of multi-annotated
        reads (input to the EM optimization); appended to in 'multi' mode.
    annot_TE : one TE-index list per annotated alignment of the read.
    uniq_counts : per-instance counts for unambiguous hits (updated in place).
    multi_counts : per-instance fractional counts for ambiguous hits.
    te_features : accepted but not used in this routine.
    te_mode : 'uniq', 'multi' or 'sameEle' counting policy.
    leftOver_list : deferred (annot_TE, weight) pairs in 'uniq' mode.

    Returns True when the read carries no TE annotation, False otherwise.

    Fix: re-indented with spaces only -- the original's tab/space mix is a
    TabError under Python 3 and `python -tt` (behavior unchanged).
    """
    if len(annot_TE) == 0:
        return True

    if len(annot_TE) == 1 and len(annot_TE[0]) == 1:
        # Unambiguous: one alignment overlapping exactly one TE instance.
        uniq_counts[annot_TE[0][0]] += 1
        return False

    if te_mode == 'multi':
        # Spread one read's weight evenly over alignments, then over the
        # TE instances each alignment overlaps.
        multi_algn = []
        for te_group in annot_TE:
            share = 1.0 / (len(annot_TE) * len(te_group))
            for te in te_group:
                multi_counts[te] += share
                multi_algn.append(te)
        multi_reads.append(multi_algn)

    if te_mode == 'uniq':
        leftOver_list.append((annot_TE, 1.0))
    return False

def my_showwarning(message, category, filename, lineno=None, line=None):
    """warnings.showwarning replacement: print only the message to stderr."""
    sys.stderr.write("Warning: %s\n" % message)


def output_res(res, ann, smps, prj):
    """Plot results as a heatmap saved to "<prj>.png".

    NOTE(review): plotHeatmap is not defined in this file; presumably it is
    provided by one of the TEToolkit star-imports above -- verify.
    """
    fname = prj+".png"
   
    plotHeatmap(res, ann, smps, fname)
    return 

def output_count_tbl(t_tbl, c_tbl, fname):
    """Write the combined gene/TE count table to `fname`.

    t_tbl, c_tbl : dict mapping sample filename -> {feature: count} for the
        treatment and control groups respectively.
    fname : output path; one header row ("gene/TE" plus one column per
        sample, treatment columns suffixed ".T", control columns ".C"),
        then one tab-separated row per feature.
    Counts are truncated to integers (DESeq expects integer counts).
    """
    try:
        f = open(fname, 'w')
    except IOError:
        # Bug fix: the original called an undefined error() helper here,
        # which would raise NameError instead of reporting the failure.
        sys.stderr.write("Cannot create report file %s !\n" % (fname))
        sys.exit(1)
    else:
        cnt_tbl = {}
        header = "gene/TE"
        keys = set([])
        # Union of all annotated features across every sample; build header.
        for tsmp in t_tbl.keys():
            keys = keys.union(t_tbl[tsmp].keys())
            header += "\t" + tsmp + ".T"
        for csmp in c_tbl.keys():
            keys = keys.union(c_tbl[csmp].keys())
            header += "\t" + csmp + ".C"

        # Per-feature count vectors, treatment samples first.
        for tsmp in t_tbl.keys():
            cnts = t_tbl[tsmp]
            for k in keys:
                val = cnts.get(k, 0)
                if k in cnt_tbl:
                    cnt_tbl[k].append(int(val))
                else:
                    cnt_tbl[k] = [int(val)]

        # Control samples appended after all treatment columns.
        for csmp in c_tbl.keys():
            cnts = c_tbl[csmp]
            for k in keys:
                val = cnts.get(k, 0)
                if k in cnt_tbl:
                    cnt_tbl[k].append(int(val))
                else:
                    # Bug fix: the original wrote cnt_tbl[gene] here -- a
                    # NameError for features first seen in control samples.
                    cnt_tbl[k] = [int(val)]

        # Emit header then one row per feature.
        f.write(header + "\n")
        for gene in cnt_tbl.keys():
            vals = cnt_tbl[gene]
            vals_str = gene
            for v in vals:
                vals_str += "\t" + str(v)
            f.write(vals_str + "\n")

        f.close()

    return
        
def output_norm(sf, name, error):
    """Write normalization factors to "<name>.norm".

    sf : sequence of per-sample factor lists; row i is labelled "treat<i+1>"
        and each value is rounded to two decimals.
    name : output file basename (".norm" is appended).
    error : callback used to report a failure to create the file.
    """
    fname = name + ".norm"
    try:
        f = open(fname, 'w')
    except IOError:
        error("Cannot create report file %s !\n" % (fname))
        sys.exit(1)
    else:
        for row_num, factors in enumerate(sf, 1):
            fields = ["treat" + str(row_num)]
            for value in factors:
                fields.append(str(round(value, 2)))
            f.write("\t".join(fields) + "\n")
        f.close()


def write_R_code(f_cnt_tbl, tfiles, cfiles, prj_name, norm, pval, fc, rpm_val, min_read):
    """Assemble the R script that runs the DESeq differential analysis.

    f_cnt_tbl : path of the count table written by output_count_tbl().
    tfiles, cfiles : treatment and control sample lists (only lengths used).
    prj_name : prefix for the R output files.
    norm : 'DESeq_default', 'TC' (total counts) or 'quant' (quantile).
    pval : adjusted-p-value (FDR) cutoff for the significant-results table.
    fc : absolute fold-change cutoff (log2-transformed below).
    rpm_val : per-sample mapped-read totals, used for TC size factors.
    min_read : minimum per-feature max count to keep a feature.

    Returns the complete R code as a single string.
    """
    rscript = ''
    rscript += '\n'
    rscript += 'data <- read.table("%s",header=T,row.names=1)\n' % (f_cnt_tbl) # load counts table
    rscript += 'groups <- factor(c(rep("TGroup",%s),rep("CGroup",%s)))\n' % (len(tfiles),len(cfiles)) # generate groups for pairwise comparison
    rscript += 'min_read <- %s\n' % (min_read)
    # Drop features whose maximum count across samples is <= min_read.
    rscript += 'data <- data[apply(data,1,function(x){max(x)}) > min_read,]\n'

    # Quantile normalization to calculate fold change
    if norm == 'quant':
        rscript += 'colnum <- length(data)\n'
        rscript += 'rownum <- length(data[,1])\n'
        rscript += 'ordMatrix <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'ordIdx <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'for (i in 1:colnum){\n'
        rscript += '  a.sort <- sort(data[,i],index.return=T)\n'
        rscript += '  ordMatrix[,i] <- a.sort$x\n'
        rscript += '  ordIdx[,i] <- a.sort$ix\n'
        rscript += '}\n'
        rscript += 'rowAvg <- rowMeans(ordMatrix)\n'
        rscript += 'data.q.norm <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'for (i in 1:colnum){\n'
        rscript += '  data.q.norm[,i] <- rowAvg[order(ordIdx[,i])]\n'
        rscript += '}\n'
        rscript += 'colnames(data.q.norm) <- colnames(data)\n'
        rscript += 'rownames(data.q.norm) <- rownames(data)\n'
        if len(tfiles) > 1:
            rscript += 'sample1Mean <- rowMeans(data.q.norm[,1:%s],na.rm=T)\n' % (len(tfiles))
        else:
            rscript += 'sample1Mean <- data.q.norm[,1]\n'
        group2_start = len(tfiles) + 1
        # Bug fix: the last control column is tfiles+cfiles; the old code
        # computed group2_start + len(cfiles), over-shooting by one column.
        group2_stop = len(tfiles) + len(cfiles)
        if len(cfiles) > 1:
            # Bug fix: the old format string dropped the closing "]" of the
            # matrix subscript, producing syntactically invalid R.
            rscript += 'sample2Mean <- rowMeans(data.q.norm[,%s:%s],na.rm=T)\n' % (group2_start, group2_stop)
        else:
            rscript += 'sample2Mean <- data.q.norm[,%s]\n' % (group2_start)
        rscript += 'FoldChange <- (sample2Mean/sample1Mean)\n'
        rscript += 'log2FoldChange <- log2(FoldChange)\n'

    # Normalize by RPM (reads per million mapped)
    if norm == 'TC'  :
        # Size factors relative to the smallest library.
        min_libSize = min(rpm_val)
        rpm_vec = ','.join(str(x/min_libSize) for x in rpm_val)
        rscript += 'tc <- c(%s)\n' % (rpm_vec)

        
    # Performing differential analysis using DESeq
    rscript += 'library(DESeq, quietly=T)\n'
    rscript += 'cds <- newCountDataSet(data,groups)\n'
    if norm == 'TC':
        rscript += 'cds$sizeFactor = tc\n'
    else:
        rscript += 'cds <- estimateSizeFactors(cds)\n'
    # Dispersion estimation strategy depends on available replication.
    if(len(tfiles)==1 and len(cfiles) ==1):
        rscript += 'cds <- estimateDispersions(cds,method="blind",sharingMode="fit-only",fitType="local")\n'
    elif(len(tfiles) > 1 and len(cfiles) > 1):
        rscript += 'cds <- estimateDispersions(cds,method="per-condition")\n'
    else :
        rscript += 'cds <- estimateDispersions(cds,method="pooled")\n'
    
    rscript += 'res <- nbinomTest(cds,"CGroup","TGroup")\n'

    # Generating output table
    if norm == 'quant':
        rscript += 'res_fc <- cbind(res$id,sample1Mean,sample2Mean,FoldChange,log2FoldChange,res$pval,res$padj)\n'
        rscript += 'colnames(res_fc) = c("id","sample1Mean","sample2Mean","FoldChange","log2FoldChange","pval", "padj")\n'
    else:
        rscript += 'res_fc <- res\n'
    rscript += 'write.table(res_fc, file="%s_gene_TE_analysis.txt", sep="\\t",quote=F,row.names=F)\n' % (prj_name)

    # Generating table of "significant" results

    l2fc = math.log(fc,2)
    if norm == 'quant':
        rscript += 'resSig <- res_fc[(!is.na(res_fc[,7]) & (res_fc[,7] < %f) & (abs(as.numeric(res_fc[,5])) > %f)), ]\n' % (pval, l2fc)
    else:
        rscript += 'resSig <- res_fc[(!is.na(res_fc$padj) & (res_fc$padj < %f) & (abs(res_fc$log2FoldChange)> %f)), ]\n' % (pval, l2fc)
    rscript += 'write.table(resSig, file="%s_sigdiff_gene_TE.txt",sep="\\t", quote=F, row.names=F)\n' % (prj_name)

    return rscript

# Main function of script
def main():
    """Run the TEtranscripts pipeline: parse options, build the annotation
    index, count reads per sample, and launch the DESeq analysis.

    Bug fix: the original function body was written at module level, so the
    whole pipeline executed on *import* while main() itself was a no-op.
    """
    args = read_opts2(prepare_parser())

    info = args.info
    error = args.error

    # Output arguments used for the program.
    info("\n" + args.argtxt + "\n")

    info("Processing GTF files ...\n")

    (features, counts) = read_features(args.gtffile, args.stranded, "exon", "gene_id", args.te_mode)

    # Build the TE index from a position-sorted temporary copy of the TE GTF.
    try:
        teIdx = TEfeatures()
        # Bug fix: the module does "from time import time", so the callable
        # is time(), not time.time() -- the latter raised AttributeError,
        # which the bare except below silently turned into an index failure.
        cur_time = time()
        te_tmpfile = '.' + str(cur_time) + '.te.gtf'
        subprocess.call(['sort -k 1,1 -k 4,4g ' + args.tefile + ' >' + te_tmpfile], shell=True)
        info("\nBuilding TE index .......\n")
        teIdx.build(te_tmpfile, args.te_mode)
        subprocess.call(['rm -f ' + te_tmpfile], shell=True)
    except:
        sys.stderr.write("Error in building TE index \n")
        sys.exit(1)

    # Read sample files and build the per-sample count tables.
    info("\nReading sample files ...\n")
    # NOTE(review): args.parser is passed as the "format" argument of
    # count_reads; confirm read_opts2 sets it from args.format.
    (tsamples_tbl, tsamples_rpm) = count_reads(args.tfiles, args.parser, features, teIdx, counts, args.stranded, args.te_mode, args.sortByPos, args.prj_name, args.numItr)

    (csamples_tbl, csamples_rpm) = count_reads(args.cfiles, args.parser, features, teIdx, counts, args.stranded, args.te_mode, args.sortByPos, args.prj_name, args.numItr)

    info("Finished processing samples files")
    info("Generating counts table")

    f_cnt_tbl = args.prj_name + ".cntTable"
    output_count_tbl(tsamples_tbl, csamples_tbl, f_cnt_tbl)
    rpm_val = tsamples_rpm + csamples_rpm

    # Obtain and save the R code for differential analysis.
    rscript = write_R_code(f_cnt_tbl, args.tfiles, args.cfiles, args.prj_name, args.norm, args.pval, args.fc, rpm_val, args.min_read)
    f_rscript = args.prj_name + '_DESeq.R'
    rcode = open(f_rscript, 'w')
    rcode.write(rscript)
    rcode.close()

    # Running R-code for differential analysis.
    try:
        sts = subprocess.call(['Rscript', f_rscript])
    except:
        error("Error in running differential analysis!\n")
        error("Error: %s\n" % str(sys.exc_info()[1]))
        error("[Exception type: %s, raised in %s:%d]\n" %
              (sys.exc_info()[1].__class__.__name__,
               os.path.basename(traceback.extract_tb(sys.exc_info()[2])[-1][0]),
               traceback.extract_tb(sys.exc_info()[2])[-1][1]))
        sys.exit(1)
    info("Done \n")


# Script entry point.
# Fix: re-indented with spaces only -- the original mixed a tab-indented
# `except` body with an 8-space `sys.exit(0)`, which Python 3 and
# `python -tt` reject as inconsistent indentation.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt !\n")
        sys.exit(0)
