#!/usr/bin/env python
'''
Created on Jan 27, 2014

@author: Ying Jin
@contact: yjin@cshl.edu
@author: Oliver Tam
@contact: tam@cshl.edu
@status: 
@version: 1.0.0
'''
# python module
import sys
import logging
import re
import math
import argparse
import optparse, itertools, warnings, traceback, os.path
import HTSeq
import subprocess

### Define parameters for program ###
def prepare_parser():
    """Build and return the argparse parser for all TEtranscripts options."""
    description = "Identifying differential transcription of gene and transposable elements."
    example = "Example: TEtranscripts --format BAM --mode uniq -t RNAseq1.bam RNAseq2.bam -c CtlRNAseq1.bam CtlRNAseq.bam  "

    parser = argparse.ArgumentParser(description=description, epilog=example)

    # Required inputs: the two sample groups and both annotation files.
    parser.add_argument('-t', '--treatment', dest='tfiles', metavar='treatment sample',
                        nargs='+', required=True,
                        help='Sample files in group 1 (e.g. treatment/mutant)')
    parser.add_argument('-c', '--control', dest='cfiles', metavar='control sample',
                        nargs='+', required=True,
                        help='Sample files in group 2 (e.g. control/wildtype)')
    parser.add_argument('--GTF', dest='gtffile', metavar='genic-GTF-file',
                        type=str, required=True,
                        help='GTF file for gene annotations')
    parser.add_argument('--TE', dest='tefile', metavar='TE-GTF-file',
                        type=str, required=True,
                        help='GTF file for transposable element annotations')

    # Optional settings with defaults.
    parser.add_argument('--format', dest='format', metavar='input file format',
                        type=str, nargs='?', default='BAM', choices=['BAM','SAM'],
                        help='Input file format: BAM or SAM. DEFAULT: BAM')
    parser.add_argument('--stranded', dest='stranded', metavar='option',
                        type=str, nargs='?', default="yes", choices=['yes','no','reverse'],
                        help='Is this a stranded library? (yes, no, or reverse). DEFAULT: yes.')
    parser.add_argument('--mode', dest='te_mode', metavar='TE counting mode',
                        type=str, nargs='?', const="multi", default='uniq',
                        choices=['uniq','sameFam','sameInst','multi'],
                        help='How to count TE: uniq (unique mappers only), sameFam (group TE by family), sameInst (assign to dominant TE), or multi (distribute among all alignments).                        DEFAULT: uniq')
    parser.add_argument('--project', dest='prj_name', metavar='name', nargs='?',
                        default='TEtranscripts_out',
                        help='Name of this project. DEFAULT: TEtranscripts_out')
    parser.add_argument('-p', '--padj', dest='pval', metavar='pvalue',
                        type=float, nargs='?', const=0.1, default=0.05,
                        help='FDR cutoff for significance. DEFAULT: 0.05')
    parser.add_argument('-f', '--foldchange', dest='fc', metavar='foldchange',
                        type=float, nargs='?', const=2.0, default=1.0,
                        help='Fold-change ratio (absolute) cutoff for differential expression. DEFAULT: 1')
    parser.add_argument('-n', '--norm', dest='norm', metavar='normalization',
                        nargs='?', default='rpm', choices=['rpm','quant'],
                        help='Normalization method : rpm (reads per million mapped), quant (quantile normalization). DEFAULT: rpm')
    parser.add_argument('--no-sort', dest='nosort', action="store_true",
                        help='Input file is not sorted by chromosome position.')
    parser.add_argument('--verbose', dest='verbose', metavar='verbose',
                        type=int, nargs='?', default=2,
                        help='Set verbose level. 0: only show critical message, 1: show additional warning message, 2: show process information, 3: show debug messages. DEFAULT:2')

    return parser

### Read options from command line ###
def read_opts(parser):
    """Parse command-line options, validate them, and configure logging.

    Returns the argparse namespace augmented with:
      parser -- input-format tag ("BAM" or "SAM") used downstream
      error/warn/debug/info -- logging function aliases
      argtxt -- human-readable summary of all settings
    Exits with status 1 on any invalid option or missing input file.
    """
    args = parser.parse_args()

    # Every sample file (both groups) must exist before any work starts.
    for sample in args.tfiles + args.cfiles:
        if not os.path.isfile(sample):
            logging.error("No such file: %s !\n" % (sample))
            sys.exit(1)

    # argparse's choices= already guarantees --format is BAM or SAM (and
    # likewise constrains --stranded, --mode and --norm), so the original's
    # defensive re-validation branches were unreachable and are removed.
    args.parser = args.format

    # Cutoff for adjusted p-value must be a probability.
    if args.pval < 0 or args.pval > 1:
        logging.error("p-value should be a value in [0,1]\n")
        sys.exit(1)

    # Fold-change cutoff is normalized toward a ratio >= 1.
    if args.fc == 0:
        logging.error("absolute fold change ratio cannot be zero\n")
        sys.exit(1)
    elif args.fc < 0:
        # NOTE(review): a negative value is only negated; e.g. -0.5 becomes
        # 0.5 and is NOT further inverted to 2.0 (the elif chain stops here).
        # Preserved exactly as the original behaved.
        args.fc = -1.0 * args.fc
    elif args.fc < 1:
        args.fc = 1.0 / args.fc
    else:
        args.fc = 1.0 * args.fc

    # --no-sort is a store_true flag, so args.nosort is already a bool;
    # the original's if/else re-assignment was a no-op and is removed.

    # Map --verbose 0..3 onto logging levels CRITICAL(40)..DEBUG(10).
    logging.basicConfig(level=(4 - args.verbose) * 10,
        format='%(levelname)-5s @ %(asctime)s: %(message)s ',
        datefmt='%a, %d %b %Y %H:%M:%S',
        stream=sys.stderr,
        filemode="w"
        )

    args.error = logging.critical        # function alias
    args.warn = logging.warning
    args.debug = logging.debug
    args.info = logging.info

    # Summary of every setting, echoed into the log by main().
    args.argtxt = "\n".join((
        "# ARGUMENTS LIST:",
                "# name = %s" % (args.prj_name),
                "# treatment files = %s" % (args.tfiles),
                "# control files = %s" % (args.cfiles),
                "# GTF file = %s " % (args.gtffile),
                "# TE file = %s " % (args.tefile),
                "# multi-mapper mode = %s " % (args.te_mode),
                "# stranded = %s " % (args.stranded),
                "# normalization = %s (rpm: Reads Per Million mapped; quant: Quantile normalization)" % (args.norm),
                "# FDR cutoff = %.2e" % (args.pval),
                "# fold-change cutoff = %5.2f" % (args.fc),
                "# Alignments grouped by read ID = %s\n" % (args.nosort)
        ))

    return args

class UnknownChrom(Exception):
    """Alignment references a chromosome absent from the annotation arrays.

    NOTE(review): never raised anywhere in this file; presumably retained
    from HTSeq-style counting scripts -- confirm before removing.
    """
    pass

# Reading & processing annotation files
def read_features(gff_filename, te_filename, stranded, feature_type, id_attribute, te_mode) :
    """Load genic and TE annotations from two GTF files.

    Parameters
    ----------
    gff_filename : GTF file of gene annotations
    te_filename : GTF file of transposable-element annotations
    stranded : "yes", "no" or "reverse"; anything but "no" builds
        strand-aware genomic arrays
    feature_type : GTF feature type to keep (e.g. "exon")
    id_attribute : GTF attribute used as the feature id (e.g. "gene_id")
    te_mode : counting mode; in "sameFam" the count table is keyed by TE
        family instead of TE instance

    Returns (features, te_features, counts, te_category): two
    HTSeq.GenomicArrayOfSets (genes and TEs), a dict mapping every feature
    id (or family in sameFam mode) to an initial count of 0, and a dict
    mapping each TE instance id to the string "instance_id;family".
    Exits via sys.exit() on malformed annotation records.
    """

    # Strand-aware interval arrays unless the library is unstranded.
    features = HTSeq.GenomicArrayOfSets("auto", stranded != "no")
    te_features = HTSeq.GenomicArrayOfSets("auto", stranded != "no")
    counts = {}
    te_category = {}

    # read count of features in GTF file
    gff = HTSeq.GFF_Reader(gff_filename)
    i = 0
    try:
        for f in gff:
            if f.type == feature_type:
                try:
                    feature_id = f.attr[ id_attribute ]

                except KeyError:
                    sys.exit("Feature %s does not contain a '%s' attribute" % (f.name, id_attribute))
                # Stranded counting is meaningless for strand-less features.
                if stranded != "no" and f.iv.strand == "." :
                    sys.exit("Feature %s at %s does not have strand information but you are running in stranded mode. Use '--stranded=no'." % (f.name, f.iv))
                features[ f.iv ] += feature_id
                counts[ f.attr[ id_attribute ] ] = 0
            i += 1
            if i % 100000 == 0 :
                sys.stderr.write("%d GTF lines processed.\n" % i)
    except:
        sys.stderr.write("Error occured in %s.\n" % gff.get_line_number_string())
        raise

    if len(counts) == 0 :
        sys.stderr.write("Warning: No features of type '%s' found in gene GTF file.\n" % feature_type)


    #read counts of TE

    te_gff = HTSeq.GFF_Reader(te_filename)
    i = 0
    try:
        for f in te_gff:
            if f.type == feature_type:
                try:
                    te_feature_id = f.attr[ id_attribute ]
                    te_family = f.attr[ "family_id" ]
                    # Strip the "?" uncertainty marker some repeat
                    # annotations carry in family names.
                    te_family = re.sub(r'\?',r'',te_family)
                except KeyError:
                    sys.exit("One feature in %s does not contain a '%s' attribute. \n" % (te_gff.get_line_number_string(), id_attribute))

                te_features[ f.iv ] += te_feature_id
                # In sameFam mode counts are aggregated per family;
                # otherwise per TE instance.
                if te_mode == 'sameFam':
                    counts[ te_family ] = 0
                else:
                    counts[ f.attr[ id_attribute ] ] = 0
                te_category[ f.attr[ id_attribute ] ] = te_feature_id + ";" + te_family
            i += 1
            if i % 100000 == 0 :
                   sys.stderr.write("%d TE GTF lines processed.\n" % i)
    except:
        sys.stderr.write("Error occured in %s.\n" % te_gff.get_line_number_string())
        raise


    if len(counts) == 0 :
        sys.stderr.write("Warning: No features of type '%s' found in TE GTF file.\n" % id_attribute)

    return (features, te_features, counts, te_category)


# Reading files containing alignments (BAM of SAM)
def count_reads(samples, format, features, te_features, counts, te_category, stranded, te_mode, nosort, prj_name):
    """Count annotated reads for every alignment file in *samples*.

    Returns (cnt_tbl, rpm_val): cnt_tbl maps each filename to its own copy
    of the per-feature count dict, and rpm_val holds each library's mapped
    read total divided by one million (later used as RPM size factors).
    Exits with status 1 on any unreadable file or processing error.
    """
    cnt_tbl = {}
    rpm_val = []

    warnings.showwarning = my_showwarning

    # Fail early: make sure every input file can actually be opened.
    try:
        for filename in samples:
            if filename != "-":
                open(filename).close()
    except:
        sys.stderr.write("Error: Cannot open %s\n" % (filename))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)

    # Assign every file's alignments to the gene/TE annotations.
    try:
        for filename in samples:
            # For SAM input the sort flag is irrelevant downstream, so the
            # original passed the (truthy) string "False"; preserved here.
            sort_flag = nosort if format == "BAM" else "False"
            total_reads = count_SAMformat(filename, format, features, te_features, counts, te_category, stranded, te_mode, sort_flag, prj_name)

            # Library size in millions of mapped reads.
            rpm_val.append(float(total_reads) / 1000000)

            # Snapshot this library's counts, then zero the shared dict
            # for the next file.
            cnt_tbl[filename] = counts.copy()
            for key in counts:
                counts[key] = 0
    except:
        sys.stderr.write("Error: %s\n" % str(sys.exc_info()[1]))
        sys.stderr.write( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)

    return (cnt_tbl, rpm_val)

def count_SAMformat(filename, format, features, te_features, counts, te_category, stranded, te_mode, nosort, prj_name):
    """Assign alignments from one SAM/BAM file to gene and TE annotations.

    Streams the file (through samtools for BAM), groups consecutive
    alignments that share a read name, and hands each read's accumulated
    gene/TE annotation lists to parse_annotations(), which updates
    `counts` in place.  Returns the total number of reads (fragments)
    processed as an int.

    Parameters
    ----------
    filename : path of the SAM/BAM file
    format : "BAM" or "SAM"
    features, te_features : HTSeq.GenomicArrayOfSets for genes / TEs
    counts : dict feature-id -> running count (mutated in place)
    te_category : dict TE instance -> "name;family" (see read_features)
    stranded : "yes" / "no" / "reverse"
    te_mode : "uniq", "multi", "sameFam" or "sameInst"
    nosort : truthy when the BAM is already grouped by read name
    prj_name : project name used for samtools temporary files
    """
    try:
        if format == "BAM" :
            if nosort :
                samtools_in = subprocess.Popen(["samtools", "view", filename], stdout=subprocess.PIPE)
            else :
                # Sort by read name so mates and multi-hits are adjacent.
                # NOTE(review): this is the legacy samtools (<= 0.1.x) sort
                # interface where "-o" writes to stdout and the final
                # argument is a temp-file prefix -- confirm against the
                # deployed samtools version.
                psort = subprocess.Popen(["samtools", "sort", "-n", "-o", filename, "%s_%s_tmp" % (filename, prj_name)], stdout=subprocess.PIPE)
                samtools_in = subprocess.Popen(["samtools", "view", "-"], stdin=psort.stdout, stdout=subprocess.PIPE)
                psort.stdout.close()  # Allow psort to receive a SIGPIPE if the reader exits.

            alignments = HTSeq.SAM_Reader(samtools_in.stdout)
        else :
            alignments = HTSeq.SAM_Reader(filename)

        # Peek at the first alignment to detect paired-end data.
        # (next(...) replaces the Python-2-only .next() method call.)
        first_read = next(iter(alignments))
        pe_mode = first_read.paired_end

    except:
        sys.stderr.write("Error occured when reading first line of sample file %s.\n" % filename)
        raise

    sys.stderr.write("Parsing sample file %s\n" % filename)

    try:
        if pe_mode:
            # Keep a handle on the raw reader for error reporting below.
            alignments_pe_file = alignments
            alignments = HTSeq.pair_SAM_alignments(alignments)
        empty = 0          # weighted tally of reads with no annotation
        notaligned = 0
        nonunique = 0

        i = 0
        annot_gene = []    # gene annotations collected for the current read
        annot_TE = []      # TE annotations collected for the current read
        no_annot = 0       # alignments of the current read with no annotation
        prev_read_name = ''
        num_reads = 0
        prev_weight = 1

        for a in alignments:
            i += 1

            curr_read_name = ''

            # --- Single-end reads ---
            if not pe_mode:
                if not a.aligned:
                    notaligned += 1
                    continue
                iv_seq = [ a.iv ]
                # Bug fix: the NH lookup originally happened *before* the
                # try/except meant to guard it, so a read lacking the NH
                # tag crashed with KeyError instead of being handled;
                # default its weight to 1.0 when NH is absent.
                try:
                    nh = a.optional_field("NH")
                    weight = 1.0 / nh
                    if nh > 1:
                        nonunique += weight
                        if te_mode == 'uniq' :
                            continue
                except KeyError:
                    weight = 1.0
                curr_read_name = a.read.name

            # --- Paired-end reads ---
            else:
                # Collect the matched ("M") CIGAR intervals of both mates,
                # flipping strands so both mates agree with the library
                # orientation.
                if a[0] is not None and a[0].aligned:
                    if stranded != "reverse":
                        iv_seq = (co.ref_iv for co in a[0].cigar if co.type == "M" and co.size > 0)
                    else:
                        iv_seq = (invert_strand(co.ref_iv) for co in a[0].cigar if co.type == "M" and co.size > 0)
                else:
                    iv_seq = tuple()

                if a[1] is not None and a[1].aligned:
                    if stranded != "reverse":
                        iv_seq = itertools.chain(iv_seq, (invert_strand(co.ref_iv) for co in a[1].cigar if co.type == "M" and co.size > 0))
                    else:
                        iv_seq = itertools.chain(iv_seq, (co.ref_iv for co in a[1].cigar if co.type == "M" and co.size > 0))
                else:
                    if (a[0] is None) or not (a[0].aligned):
                        notaligned += 1
                        continue

                try:
                    if (a[0] is not None and a[0].optional_field("NH") > 1) or \
                       (a[1] is not None and a[1].optional_field("NH") > 1):
                        nonunique += 1
                        if te_mode == 'uniq' :
                            continue
                except KeyError:
                    pass

                # Bug fix: when only one mate aligned the other entry is
                # None, and the original dereferenced both unconditionally
                # (AttributeError). Use whichever mate is present; if the
                # NH tag is missing entirely, fall back to weight 1.0.
                mate0, mate1 = a[0], a[1]
                curr_read_name = (mate0 if mate0 is not None else mate1).read.name
                try:
                    if mate0 is None:
                        weight = 1.0 / mate1.optional_field("NH")
                    elif mate1 is None:
                        weight = 1.0 / mate0.optional_field("NH")
                    elif mate0.optional_field("NH") < mate1.optional_field("NH") :
                        weight = 1.0 / mate0.optional_field("NH")
                    else :
                        weight = 1.0 / mate1.optional_field("NH")
                except KeyError:
                    weight = 1.0

            try:
                if prev_read_name == '' :
                    prev_read_name = curr_read_name
                # A new read name means every alignment of the previous
                # read has been seen: commit its annotations.
                if prev_read_name != curr_read_name:
                    try:
                        num_reads += 1
                        empty = parse_annotations(annot_gene, annot_TE, no_annot, counts, prev_weight, te_mode, empty, te_category, prev_read_name, filename)
                        annot_gene = []
                        annot_TE = []
                        no_annot = 0
                    except:
                        sys.stderr.write("Error occurred when processing annotations of %s in file %s.\n" % (prev_read_name, filename))
                        raise

                prev_weight = weight
                prev_read_name = curr_read_name
                fs = None

                for iv in iv_seq:
                    # Genic features overlapping this interval; multiple
                    # overlapping steps are intersected to the common set.
                    try:
                        for iv2, fs2 in features[ iv ].steps():
                            if len(fs2) > 0:
                                if fs is None:
                                    fs = fs2.copy()
                                else:
                                    fs = fs.intersection(fs2)

                        if fs is not None and len(fs) > 0:
                            # NOTE(review): this appends the same (arbitrary
                            # first) member of the set len(fs) times rather
                            # than each member once -- preserved as-is;
                            # verify whether iterating over fs was intended.
                            counter = 0
                            while(counter < len(fs)):
                                annotation = str((next(iter(fs))))
                                annot_gene.append(annotation)
                                counter += 1

                    except:
                        sys.stderr.write("Error occurred when assigning read (%s) to genic annotations in file %s.\n" % (curr_read_name, filename))
                        raise

                    # TE features for the same interval; if nothing matches
                    # on this strand, retry with the strand inverted.
                    fs = None
                    for iv2, fs2 in te_features[ iv ].steps():
                        try:
                            if len(fs2) > 0 :
                                if fs is None or len(fs) == 0:
                                    fs = fs2.copy()
                                else:
                                    fs = fs.intersection(fs2)
                            if fs is not None and len(fs) > 0:
                                # Same repeated-first-member behaviour as the
                                # genic branch above (preserved).
                                counter = 0
                                while(counter < len(fs)):
                                    annotation = str((next(iter(fs))))
                                    annot_TE.append(annotation)
                                    counter += 1

                            if fs is None or len(fs) == 0:
                                iv = invert_strand(iv)
                                for iv2, fs2 in te_features[ iv ].steps():
                                    if len(fs2) > 0:
                                        if fs is None or len(fs) == 0:
                                            fs = fs2.copy()
                                        else:
                                            fs = fs.intersection(fs2)

                            if fs is not None and len(fs) > 0:
                                counter = 0
                                while(counter < len(fs)):
                                    annotation = str((next(iter(fs))))
                                    annot_TE.append(annotation)
                                    counter += 1
                        except:
                            sys.stderr.write("Error occurred when assigning read (%s) to TE annotations in file %s.\n" % (curr_read_name, filename))
                            raise

                    if fs is None or len(fs) == 0:
                        no_annot += 1

            except:
                sys.stderr.write("Error occurred during read assignments\n")
                raise

            if i % 1000000 == 0 :
                sys.stderr.write("%d %s processed.\n" % (i, "alignments " if not pe_mode else "alignment pairs"))

        # Flush the final read's accumulated annotations.
        try:
            if prev_read_name != '' :
                num_reads += 1
                empty = parse_annotations(annot_gene, annot_TE, no_annot, counts, prev_weight, te_mode, empty, te_category, prev_read_name, filename)
                annot_gene = []
                annot_TE = []
                no_annot = 0
        except:
            sys.stderr.write("Error occurred when assigning FINAL read (%s) to genic annotations in file %s.\n" % (prev_read_name, filename))
            raise

        sys.stderr.write("\nIn library %s:\n" % (filename))
        sys.stderr.write("Total mapped reads = %s\n" % ( str(num_reads) ) )
        sys.stderr.write("Total non-uniquely mapped reads = %s\n" % (str(int(nonunique))))
        sys.stderr.write("Total unannotated reads = %s\n\n" %(str(int(empty))))

    except:
        if not pe_mode:
            sys.stderr.write("Error occured in %s.\n" % alignments.get_line_number_string())
        else:
            # Bug fix: the original referenced the undefined name
            # read_seq_pe_file here (guaranteed NameError); use the raw
            # reader saved before pairing instead.
            sys.stderr.write("Error occured in %s.\n" % alignments_pe_file.get_line_number_string())
        raise

    return int(num_reads)

# Invert strand of alignment
def invert_strand(iv):
    """Return a copy of genomic interval *iv* with its strand flipped.

    Raises ValueError if the strand is neither "+" nor "-".
    """
    iv2 = iv.copy()
    if iv2.strand == "+":
        iv2.strand = "-"
    elif iv2.strand == "-":
        iv2.strand = "+"
    else:
        # Bug fix: the original used the Python-2-only statement form
        # "raise ValueError, ..." (a SyntaxError under Python 3).
        raise ValueError("Illegal strand")
    return iv2


def parse_annotations(annot_gene, annot_TE, no_annot, counts, weight, te_mode, empty, te_category, seqname, filename):
    """Distribute one read's weight over its collected gene/TE annotations.

    Uniquely mapped reads (weight == 1) prefer genic annotations over TEs;
    multi-mapped reads prefer TEs over genes and use a weight rescaled by
    the number of unannotated alignments.  `counts` is updated in place;
    the (possibly incremented) unannotated-read tally is returned.
    Exits with status 1 if the rescaled weight goes negative.
    """
    if weight == 1:
        # Unique mapper: genes take precedence over TEs.
        if len(annot_gene) == 1:
            counts[annot_gene[0]] += weight
        elif len(annot_gene) > 1:
            resolve_ambiguity(annot_gene, counts, weight)
        elif len(annot_TE) == 1:
            target = getRepFam(annot_TE[0], te_category) if te_mode == 'sameFam' else annot_TE[0]
            counts[target] += weight
        elif len(annot_TE) > 1:
            resolve_TE_annot(annot_TE, counts, weight, te_mode, te_category)
        else:
            empty += 1
        return empty

    # Multi-mapper: subtract the share of alignments with no annotation.
    new_weight = 1 - (no_annot * weight)
    if new_weight < 0:
        sys.stderr.write("Weight is less than zero. Weight: %s. No annotations: %s. New weight %s\n" % (weight, no_annot, new_weight))
        sys.stderr.write("Annotation parsing error for %s in file %s\n" % (seqname, filename))
        sys.exit(1)

    # TEs take precedence over genes for multi-mapped reads.
    if len(annot_TE) == 1:
        target = getRepFam(annot_TE[0], te_category) if te_mode == 'sameFam' else annot_TE[0]
        counts[target] += new_weight
    elif len(annot_TE) > 1:
        resolve_TE_annot(annot_TE, counts, new_weight, te_mode, te_category)
    elif len(annot_gene) == 1:
        counts[annot_gene[0]] += new_weight
    elif len(annot_gene) > 1:
        resolve_ambiguity(annot_gene, counts, new_weight)
    else:
        empty += (no_annot * new_weight)
    return empty

# Assign ambiguous genic reads mapped to multiple locations

def resolve_ambiguity(genes, counts, weight) :
    """Distribute *weight* among *genes*, updating `counts` in place.

    When the candidate genes already carry counts, the weight is split
    proportionally to those counts; otherwise it is split evenly.
    Returns the (mutated) counts dict.
    """
    existing = {}
    total = 0.0
    for g in genes :
        if g in counts :
            existing[g] = counts[g]
            total += counts[g]

    if total > 0.0 :
        # Proportional split based on the counts seen so far.
        for g in genes :
            counts[g] += weight * existing[g] / total
    else :
        # Bug fix: use true division so an integer weight (e.g. 1) does not
        # floor to 0 under Python 2's integer division.
        share = float(weight) / len(genes)
        for g in genes :
            counts[g] = share

    return counts


def getRepFam(te, te_category):
    """Return the repeat family of TE *te* from the "name;family" map."""
    return te_category[te].split(';')[1]

def resolve_TE_annot(annot_TE, counts, weight, te_mode, te_category):
    """Distribute *weight* among multiple candidate TE annotations.

    Modes: "multi" splits the weight evenly across all hits; "uniq" splits
    proportionally to existing counts (via resolve_ambiguity); "sameFam" /
    "sameInst" tally occurrences per family / instance and credit the
    dominant one (fully, if it is >2.5x the runner-up, otherwise split
    proportionally between the top two).  Returns the mutated counts dict.
    """
    if te_mode == 'multi':
        # Bug fix: force true division -- under Python 2 an integer weight
        # would floor to 0 for every TE.
        share = float(weight) / len(annot_TE)
        for te in annot_TE:
            counts[te] += share
        return counts

    if te_mode == 'uniq':
        resolve_ambiguity(annot_TE, counts, weight)
        return counts

    if te_mode not in ('sameFam', 'sameInst'):
        # Unknown mode: leave counts untouched (matches original fall-through).
        return counts

    # sameFam / sameInst: tally how often each family / instance occurs.
    te_counts = {}
    for te in annot_TE:
        key = getRepFam(te, te_category) if te_mode == 'sameFam' else te
        te_counts[key] = te_counts.get(key, 0) + 1

    (top1, top2, cnt1, cnt2) = majority(te_counts)
    # Bug fix: the original compared and split with integer operands, which
    # floors under Python 2 (e.g. cnt1/(cnt1+cnt2) == 0); use float ratios
    # so the 2.5x dominance test and the proportional split behave as
    # intended.
    if cnt2 == 0 or float(cnt1) / cnt2 > 2.5:
        counts[top1] += weight
    else:
        total = float(cnt1 + cnt2)
        counts[top1] += (weight * (cnt1 / total))
        counts[top2] += (weight * (cnt2 / total))

    return counts
            


def majority(te_counts) :
    """Return the two most frequent features and their counts.

    Returns (top_feature, runner_up, top_count, runner_up_count); empty
    strings and zeros fill in when fewer than two features exist.
    """
    top_feat1, top_cnt1 = "", 0
    top_feat2, top_cnt2 = "", 0

    for feature, feat_cnt in te_counts.items() :
        if feat_cnt > top_cnt1 :
            # Bug fix: the previous leader must be demoted to runner-up;
            # the original discarded it, so e.g. {a: 2, b: 3} reported no
            # runner-up at all.
            top_feat2, top_cnt2 = top_feat1, top_cnt1
            top_feat1, top_cnt1 = feature, feat_cnt
        elif feat_cnt > top_cnt2 :
            top_feat2, top_cnt2 = feature, feat_cnt

    return (top_feat1, top_feat2, top_cnt1, top_cnt2)

def my_showwarning(message, category, filename, lineno=None, line=None):
    """Replacement for warnings.showwarning: print a terse one-line warning."""
    text = "Warning: %s\n" % message
    sys.stderr.write(text)


def output_res(res, ann, smps, prj):
    """Plot a heatmap of the results to "<project>.png".

    NOTE(review): plotHeatmap is not defined anywhere in this file and this
    function is never called from main() here -- presumably provided by
    another module or dead code; verify before relying on it.
    """
    fname = prj+".png"

    plotHeatmap(res, ann, smps, fname)
    return 

def output_count_tbl(t_tbl, c_tbl, fname):
    """Write the combined gene/TE count table to *fname*.

    t_tbl / c_tbl map sample name -> {feature -> count}. One tab-separated
    column per sample is written (".T" suffix for treatment samples, ".C"
    for controls), counts truncated to int. Exits with status 1 if the
    output file cannot be created.
    """
    try:
        f = open(fname, 'w')
    except IOError:
        # Bug fix: the original called an undefined name `error` here, which
        # would raise NameError instead of reporting the real problem.
        logging.error("Cannot create report file %s !\n" % (fname))
        sys.exit(1)

    cnt_tbl = {}
    header = "gene/TE"
    # Treatment samples first, then controls, preserving dict order so the
    # header matches the per-gene value columns.
    for tsmp in t_tbl.keys():
        cnts = t_tbl[tsmp]
        header += "\t" + tsmp + ".T"
        for gene in sorted(cnts.keys()):
            # `in` replaces the Python-2-only dict.has_key().
            if gene in cnt_tbl:
                cnt_tbl[gene].append(int(cnts[gene]))
            else:
                cnt_tbl[gene] = [int(cnts[gene])]

    for csmp in c_tbl.keys():
        cnts = c_tbl[csmp]
        header += "\t" + csmp + ".C"
        for gene in sorted(cnts.keys()):
            if gene in cnt_tbl:
                cnt_tbl[gene].append(int(cnts[gene]))
            else:
                cnt_tbl[gene] = [int(cnts[gene])]

    # Emit the header then one row per feature.
    f.write(header + "\n")
    for gene in cnt_tbl.keys():
        vals_str = gene
        for v in cnt_tbl[gene]:
            vals_str += "\t" + str(v)
        f.write(vals_str + "\n")

    f.close()

    return
        
def output_norm(sf, name, error):
    """Write normalization factors to "<name>.norm", one row per library.

    Each row is labelled "treatN" and lists that library's factors rounded
    to two decimals; *error* is the logging callable used on I/O failure.
    """
    fname = name + ".norm"
    try:
        f = open(fname, 'w')
    except IOError:
        error("Cannot create report file %s !\n" % (fname))
        sys.exit(1)
    else:
        for idx, factors in enumerate(sf, start=1):
            fields = ["treat" + str(idx)]
            for val in factors:
                fields.append(str(round(val, 2)))
            f.write("\t".join(fields) + "\n")
        f.close()


def write_R_code(f_cnt_tbl, tfiles, cfiles, prj_name, norm, pval, fc, rpm_val):
    """Assemble the R/DESeq analysis script and return it as one string.

    Parameters
    ----------
    f_cnt_tbl : path of the count table written by output_count_tbl
    tfiles, cfiles : treatment / control sample file lists (column order)
    prj_name : project name used for the output file names
    norm : "rpm" or "quant"
    pval : FDR cutoff for the significant-results table
    fc : fold-change cutoff (compared on the log2 scale)
    rpm_val : per-library reads-per-million size factors (rpm mode only)
    """
    rscript = ''
    rscript += '\n'
    rscript += 'data <- read.table("%s",header=T,row.names=1)\n' % (f_cnt_tbl) # load counts table
    rscript += 'groups <- factor(c(rep("G1",%s),rep("G2",%s)))\n' % (len(tfiles),len(cfiles)) # generate groups for pairwise comparison

    # Quantile normalization to calculate fold change
    if norm == 'quant':
        rscript += 'colnum <- length(data)\n'
        rscript += 'rownum <- length(data[,1])\n'
        rscript += 'ordMatrix <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'ordIdx <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'for (i in 1:colnum){\n'
        rscript += '  a.sort <- sort(data[,i],index.return=T)\n'
        rscript += '  ordMatrix[,i] <- a.sort$x\n'
        rscript += '  ordIdx[,i] <- a.sort$ix\n'
        rscript += '}\n'
        rscript += 'rowAvg <- rowMeans(ordMatrix)\n'
        rscript += 'data.q.norm <- matrix(nrow=rownum,ncol=colnum)\n'
        rscript += 'for (i in 1:colnum){\n'
        rscript += '  data.q.norm[,i] <- rowAvg[order(ordIdx[,i])]\n'
        rscript += '}\n'
        rscript += 'colnames(data.q.norm) <- colnames(data)\n'
        rscript += 'rownames(data.q.norm) <- rownames(data)\n'
        if len(tfiles) > 1:
            rscript += 'sample1Mean <- rowMeans(data.q.norm[,1:%s],na.rm=T)\n' % (len(tfiles))
        else:
            rscript += 'sample1Mean <- data.q.norm[,1]\n'
        group2_start = len(tfiles) + 1
        # Bug fix: the last control column is len(tfiles) + len(cfiles); the
        # original added one too many and indexed past the matrix.
        group2_stop = len(tfiles) + len(cfiles)
        if len(cfiles) > 1:
            # Bug fix: the original emitted "data.q.norm[,a:b,na.rm=T)" with
            # a missing "]" -- an R syntax error.
            rscript += 'sample2Mean <- rowMeans(data.q.norm[,%s:%s],na.rm=T)\n' % (group2_start, group2_stop)
        else:
            rscript += 'sample2Mean <- data.q.norm[,%s]\n' % (group2_start)
        rscript += 'FoldChange <- (sample2Mean/sample1Mean)\n'
        rscript += 'log2FoldChange <- log2(FoldChange)\n'

    # Normalize by RPM (reads per million mapped)
    else :
        rpm_vec = ','.join(str(x) for x in rpm_val)
        rscript += 'rpm <- c(%s)\n' % (rpm_vec)

    # Performing differential analysis using DESeq
    rscript += 'library(DESeq, quietly=T)\n'
    rscript += 'cds <- newCountDataSet(data,groups)\n'
    if norm == 'rpm':
        # NOTE(review): DESeq's documented setter is sizeFactors(cds) <- ...;
        # confirm the $sizeFactor slot assignment works with the DESeq
        # version in use.
        rscript += 'cds$sizeFactor = rpm\n'
    else:
        rscript += 'cds <- estimateSizeFactors(cds)\n'
    # Without replicates DESeq needs the blind/fit-only/local fallback.
    if(len(tfiles)==1 and len(cfiles) ==1):
        rscript += 'cds <- estimateDispersions(cds,method="blind",sharingMode="fit-only",fitType="local")\n'
    else:
        rscript += 'cds <- estimateDispersions(cds)\n'
    rscript += 'res <- nbinomTest(cds,"G1","G2")\n'

    # Generating output table
    if norm == 'quant':
        rscript += 'res_fc <- cbind(res$id,sample1Mean,sample2Mean,FoldChange,log2FoldChange,res$pval,res$padj)\n'
        rscript += 'colnames(res_fc) = c("id","sample1Mean","sample2Mean","FoldChange","log2FoldChange","pval", "padj")\n'
    else:
        rscript += 'res_fc <- res\n'
    rscript += 'write.table(res_fc, file="%s_gene_TE_analysis.txt", sep="\\t",quote=F,row.names=F)\n' % (prj_name)

    # Generating table of "significant" results, filtered on padj and
    # absolute log2 fold change.
    l2fc = math.log(fc,2)
    if norm == 'quant':
        rscript += 'resSig <- res_fc[(!is.na(res_fc[,7]) & (res_fc[,7] < %f) & (abs(as.numeric(res_fc[,5])) > %f)), ]\n' % (pval, l2fc)
    else:
        rscript += 'resSig <- res_fc[(!is.na(res_fc$padj) & (res_fc$padj < %f) & (abs(res_fc$log2FoldChange)> %f)), ]\n' % (pval, l2fc)
    rscript += 'write.table(resSig, file="%s_sigdiff_gene_TE.txt",sep="\\t", quote=F, row.names=F)\n' % (prj_name)

    return rscript

# Main function of script
def main():
    """Run the TEtranscripts pipeline end to end.

    Parses options, loads gene/TE annotations, counts reads per sample
    library, writes the combined count table, then generates and executes
    a DESeq R script for the differential-expression analysis.
    """
    
    args=read_opts(prepare_parser())
    
    info = args.info
    #warn = args.warn
    #debug = args.debug
    error = args.error

    # Output arguments used for program
    info("\n" + args.argtxt + "\n")

    info("Processing GTF files ...\n")
    (features, te_features, counts, te_category) = read_features(args.gtffile, args.tefile, args.stranded, "exon", "gene_id", args.te_mode)

    # Read sample files make count table
    info("\nReading sample files ...\n")
    (tsamples_tbl, tsamples_rpm) = count_reads(args.tfiles, args.parser, features, te_features, counts, te_category, args.stranded, args.te_mode, args.nosort, args.prj_name)
    
    (csamples_tbl, csamples_rpm) = count_reads(args.cfiles, args.parser, features, te_features, counts, te_category, args.stranded, args.te_mode, args.nosort, args.prj_name)

    info("Finished processing samples files")
    info("Generating counts table")
    
    f_cnt_tbl = args.prj_name + ".cntTable"
    output_count_tbl(tsamples_tbl, csamples_tbl, f_cnt_tbl)
    # Per-library size factors (reads per million) for RPM normalization.
    rpm_val = tsamples_rpm + csamples_rpm

    info("Calculating differential expression ...\n")

    # Obtaining R-code for differential analysis
    rscript = write_R_code(f_cnt_tbl, args.tfiles, args.cfiles, args.prj_name, args.norm, args.pval, args.fc, rpm_val)
    f_rscript = args.prj_name + '_DESeq.R'
    rcode = open('%s' % (f_rscript) , 'w')
    rcode.write(rscript)
    rcode.close()

    # Running R-code for differential analysis
    # NOTE(review): subprocess.call only raises if Rscript cannot be
    # launched; a non-zero exit status from the R script itself is stored
    # in `sts` but never checked.
    try:
        sts = subprocess.call(['Rscript', f_rscript])
    except:
        error("Error in running differential analysis!\n")
        error("Error: %s\n" % str(sys.exc_info()[1]))
        error( "[Exception type: %s, raised in %s:%d]\n" % 
                          ( sys.exc_info()[1].__class__.__name__, 
                            os.path.basename(traceback.extract_tb( sys.exc_info()[2] )[-1][0]), 
                            traceback.extract_tb( sys.exc_info()[2] )[-1][1] ) )
        sys.exit(1)

    info("Done \n")


if __name__ == '__main__':
    # Script entry point: run the full pipeline, exiting quietly on Ctrl-C.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt !\n")
        sys.exit(0)
        
