'''
 Functions to run sequenza (the R package) from the command line, without needing an R console.
'''

import os, gc
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
# Import the "sequenza" R package once, at module import time; the functions below call it through this object.
sequenza = importr("sequenza")


def RPy2sqeezeABfreq(abfreq, loop, tag, out):
   """
   Process an abfreq file and store the relevant (small and easy to re-access) information.

   Everything is written into the directory ``out + '/' + tag`` (created if it
   does not exist):

     * ``<tag>_raw_GC.txt`` and ``<tag>_adj_GC.txt``: the 'raw' and 'adj'
       tables of the GC statistics, tab separated;
     * ``<tag>_windows_Bf.Rdata``, ``<tag>_windows_ratio.Rdata``,
       ``<tag>_mutation_list.Rdata`` and ``<tag>_segments_list.Rdata``: R lists
       named by chromosome, each saved under the R object name
       ``<tag>_<suffix>``.

   Arguments:
     abfreq -- path of the abfreq file; it is read as gzip-compressed when its
               file name ends with "gz".
     loop   -- when true, the GC statistics are computed first with
               gc_sample_stats and the file is then read one chromosome at a
               time, using the line ranges found in its 'file.metrics' table.
               When false, the whole file is read in one go and the results
               are split by chromosome afterwards.
     tag    -- prefix for the output file names and for the saved R objects;
               also the name of the output sub-directory.
     out    -- parent directory of the output sub-directory.

   Returns nothing; the results are the files listed above.
   """
   is_gz = os.path.split(abfreq)[1][-2:] == "gz"
   as_character = robjects.r('as.character')
   if loop:
      print("lading GC-content into memory")
      gc_stats = sequenza.gc_sample_stats(abfreq)
      chr_vect = as_character(gc_stats.rx2('file.metrics').rx2('chr'))
      # Median ('50%') raw depth ratio for each GC value, named by GC value so
      # that it can be looked up by name further down.
      gc_vect  = robjects.r.setNames(gc_stats.rx2('raw').rx(True, '50%'), gc_stats.rx2('gc.values'))
   else:
      # Single placeholder element: the loop below must run exactly once, and
      # the real chromosome names are only known after the file is read.
      chr_vect = [1]
   windows_baf   = robjects.ListVector({})
   windows_ratio = robjects.ListVector({})
   mutation_list = robjects.ListVector({})
   segments_list = robjects.ListVector({})
   for chrom in chr_vect:
      if loop:
         # Read only the lines of the file that belong to this chromosome.
         file_lines = gc_stats.rx2('file.metrics').rx(chrom, True)
         abf_data = sequenza.read_abfreq(abfreq, gz = is_gz, n_lines = robjects.IntVector((file_lines.rx('start')[0][0],  file_lines.rx('end')[0][0])))
      else:
         print("lading all the file in memory")
         abf_data = sequenza.read_abfreq(abfreq, fast = True, gz = is_gz)
         gc_stats = sequenza.gc_norm(ratio = abf_data.rx(True, 'depth.ratio'), gc = abf_data.rx(True, 'GC.percent'))
         gc_vect  = robjects.r.setNames(gc_stats.rx2('raw').rx(True, '50%'), gc_stats.rx2('gc.values'))
         # Rebinding chr_vect here does not affect the running loop, which
         # still iterates over the one-element placeholder list.
         chr_vect = robjects.r.unique(abf_data.rx(True, 'chromosome'))
         chr_vect = as_character(chr_vect)
      # cbind() is given Python-friendly column names; the last name is then
      # rewritten in place to the dotted form used by the calls below.
      abf_data = robjects.r.cbind(abf_data, good_s_reads   = abf_data.rx(True, 'depth.sample').ro * abf_data.rx(True, 'sample.reads.above.quality'))
      abf_data.names[-1] = 'good.s.reads'
      # Depth ratio divided by the median ratio observed at the same GC
      # percentage (looked up by name in gc_vect), rounded to 3 decimals.
      abf_data = robjects.r.cbind(abf_data, adjusted_ratio = robjects.r.round(abf_data.rx(True, 'depth.ratio').ro / gc_vect.rx(as_character(abf_data.rx(True, 'GC.percent'))), 3))
      abf_data.names[-1] = 'adjusted.ratio'
      # Keep the rows whose reference zygosity is not 'hom'.
      abf_hom  = abf_data.rx(True, 'ref.zygosity').ro == 'hom'
      abf_het  = abf_data.rx(abf_hom.ro != True, True)
      abf_r_win = sequenza.windowValues(x = abf_data.rx(True, 'adjusted.ratio'), positions = abf_data.rx(True, 'n.base'), chromosomes = abf_data.rx(True, 'chromosome'), window = 1e6, overlap = 1, weight = abf_data.rx(True, 'depth.normal'))
      abf_b_win = sequenza.windowValues(x = abf_het.rx(True, 'Bf'), positions = abf_het.rx(True, 'n.base'), chromosomes = abf_het.rx(True, 'chromosome'), window = 1e6, overlap = 1, weight = robjects.r.round(abf_het.rx(True, 'good.s.reads'), 0))
      breaks = sequenza.find_breaks(abf_het, gamma = 80, kmin = 10, baf_thres = robjects.FloatVector((0, 0.5)))
      seg_s1 = sequenza.segment_breaks(abf_data, breaks = breaks)
      mut_tab = sequenza.mutation_table(abf_data, mufreq_treshold = 0.10, min_reads = 40, max_mut_types = 1, min_type_freq = 0.9, segments = seg_s1)
      gc.collect()
      if loop:
         # One chromosome was processed: store its results under its name.
         windows_baf.rx2[chrom]   = abf_b_win.rx2(1)
         windows_ratio.rx2[chrom] = abf_r_win.rx2(1)
         mutation_list.rx2[chrom] = mut_tab
         segments_list.rx2[chrom] = seg_s1
      else:
         # The whole file was processed at once: the window lists are used as
         # returned, the two tables are split by chromosome.
         windows_baf   = abf_b_win
         windows_ratio = abf_r_win
         for chrom_name in chr_vect:
            mutation_list.rx2[chrom_name] = mut_tab.rx(mut_tab.rx(True, 'chromosome').ro == chrom_name, True)
            segments_list.rx2[chrom_name] = seg_s1.rx(seg_s1.rx(True, 'chromosome').ro == chrom_name, True)
   windows_baf.names   = chr_vect
   windows_ratio.names = chr_vect
   mutation_list.names = chr_vect
   segments_list.names = chr_vect
   subdir = out +'/' + tag
   if not os.path.exists(subdir):
      os.makedirs(subdir)
   file_prefix = subdir + '/' + tag
   write_table = robjects.r('write.table')
   write_table(x = gc_stats.rx2('raw'), file = file_prefix + '_raw_GC.txt', col_names = True, row_names = False, sep = "\t")
   write_table(x = gc_stats.rx2('adj'), file = file_prefix + '_adj_GC.txt', col_names = True, row_names = False, sep = "\t")
   # save() takes object names, so each list is first assigned to an R
   # variable called <tag>_<suffix> and then saved under that name.
   r_assign = robjects.r('assign')
   r_save   = robjects.r('save')
   r_assign(x = tag + '_windows_Bf', value = windows_baf)
   r_save(list = tag + '_windows_Bf', file = file_prefix + '_windows_Bf.Rdata')
   r_assign(x = tag + '_windows_ratio', value = windows_ratio)
   r_save(list = tag + '_windows_ratio', file = file_prefix + '_windows_ratio.Rdata')
   r_assign(x = tag + '_mutation_list', value = mutation_list)
   r_save(list = tag + '_mutation_list', file = file_prefix + '_mutation_list.Rdata')
   r_assign(x = tag + '_segments_list', value = segments_list)
   r_save(list = tag + '_segments_list', file = file_prefix + '_segments_list.Rdata')

def RPy2doAllSequenza(data_dir, is_male = True, tag = None, mc_cores = 16):
   '''
   Load the information stored by RPy2sqeezeABfreq, infer cellularity/ploidy
   and save plots and tables of the copy-number calls.

   Arguments:
     data_dir -- directory holding the <tag>_adj_GC.txt file and the four
                 <tag>_*.Rdata files; the results are written there as well.
     is_male  -- when true, segments on chromosomes "X" and "Y" are left out of
                 the model fit and are called separately with CNr = 1; when
                 false only "Y" is left out of the fit and every segment is
                 called with CNr = 2.
     tag      -- prefix of the input and output file names; defaults to the
                 last component of data_dir.
     mc_cores -- value passed as mc.cores to baf.model.fit (default 16).

   Files written in data_dir: <tag>_CP_ci.pdf, <tag>_segments.txt,
   <tag>_chromosome_view.pdf, <tag>_CN_bars.pdf and <tag>_confints_CP.txt.
   Returns nothing.
   '''
   if tag is None:
      # normpath drops a trailing separator, which would otherwise make the
      # last path component (and therefore the tag) an empty string.
      tag = os.path.basename(os.path.normpath(data_dir))
   obj_list = robjects.StrVector(('adj_GC.txt', 'raw_GC.txt', 'windows_Bf.Rdata', 'mutation_list.Rdata', 'windows_ratio.Rdata', 'segments_list.Rdata'))
   obj_list = robjects.r.paste(tag, obj_list, sep = "_")
   # Average depth ratio: mean of the second column of the adjusted GC table.
   gc_tab   = robjects.r('read.table')(data_dir +'/' + obj_list[0], header = True, sep = '\t')
   avg_depth_ratio = robjects.r.mean(gc_tab.rx(True, 2))
   # Elements 2..5 are the four .Rdata files; load() restores the R objects
   # under the names they were saved with (<tag>_<suffix>).
   for i in range(2,6):
      robjects.r.load(data_dir +'/' + obj_list[i])
   as_name = robjects.r('as.name')
   windows_Bf    = robjects.r.eval(as_name(tag + '_windows_Bf'))
   windows_ratio = robjects.r.eval(as_name(tag + '_windows_ratio'))
   mutation_list = robjects.r.eval(as_name(tag + '_mutation_list'))
   segments_list = robjects.r.eval(as_name(tag + '_segments_list'))
   chr_vect      = windows_Bf.names
   segs_all      = robjects.r('do.call')('rbind', segments_list)
   segs_len      = segs_all.rx(True, 'end.pos').ro - segs_all.rx(True, 'start.pos')
   segs_filt     = segs_len.ro >= 10e6
   segs_chrom    = segs_all.rx(True, 'chromosome')
   segs_is_y     = segs_chrom.ro == "Y"
   if is_male:
      # The two tests must be combined with the element-wise R operator:
      # Python's "or" would just hand back the first (non-empty) vector and
      # the "Y" segments would never be flagged.
      segs_is_xy = (segs_chrom.ro == "X").ro | segs_is_y
   else:
      segs_is_xy = segs_is_y
   segs_not_xy = segs_is_xy.ro == False
   # The model is fitted only on segments of at least 10e6 bases that are not
   # flagged above.
   filt_test  = segs_filt.ro & segs_not_xy
   seg_test   = segs_all.rx(filt_test, True)
   weights_seg = robjects.r.round(segs_len.ro / 1e6, 0).ro + 150
   # Thin R wrapper exposing baf.model.fit with underscore-only argument names.
   robjects.r('''
   wrapBafBayes <- function (Bf, depth_ratio , weight_ratio, weight_Bf, 
                             avg_depth_ratio,  cellularity, priors_label, priors_value,
                             dna_content, mc_cores, ...) {
                    baf.model.fit(Bf = Bf, depth.ratio = depth_ratio,
                    weight.ratio = weight_ratio, weight.Bf = weight_Bf,
                    avg.depth.ratio = avg_depth_ratio, cellularity = cellularity, 
                    priors.label = priors_label, priors.value = priors_value ,
                    dna.content = dna_content, mc.cores = mc_cores)
   }
   ''')
   # ratio_priority falls into the wrapper's "..." and is not forwarded.
   CP  = robjects.r('wrapBafBayes')(Bf = seg_test.rx(True, 'Bf'), depth_ratio = seg_test.rx(True, 'depth.ratio'),
                    weight_ratio = weights_seg.rx(filt_test).ro * 2,
                    weight_Bf = weights_seg.rx(filt_test), avg_depth_ratio = avg_depth_ratio,
                    cellularity = robjects.r.seq(0.1, 1, 0.01) , priors_label = 2, priors_value = 2,
                    dna_content = robjects.r.seq(0.5, 3, 0.05), mc_cores = mc_cores, ratio_priority = False)

   dna_c_cint = sequenza.get_ci(CP.rx(True, robjects.StrVector(('dna.content', 'L'))))
   cellu_cint = sequenza.get_ci(CP.rx(True, robjects.StrVector(('cellularity', 'L'))))

   # Plot of the fit together with the two likelihood profiles.
   robjects.r.pdf(data_dir +'/'+ tag + "_CP_ci.pdf")
   robjects.r.par(mfrow = robjects.IntVector((2, 2)))
   sequenza.cp_plot(CP)
   robjects.r.plot(cellu_cint.rx2('values').rx(True, robjects.IntVector((2, 1))), ylab = "Cellularity", xlab = "Likelihood", type = "l")
   robjects.r.abline(h = cellu_cint.rx2('confint'), lty = 2, lwd = 0.5, col = "red")
   robjects.r.plot(dna_c_cint.rx2('values').rx(True, robjects.IntVector((1, 2))), ylab = "DNA-content", xlab = "Likelihood", type = "l")
   robjects.r.abline(v = dna_c_cint.rx2('confint'), lty = 2, lwd = 0.5, col = "red")
   robjects.r('dev.off()')

   if is_male:
      # Unflagged segments are called with CNr = 2, the X/Y ones with CNr = 1.
      seg_res  = sequenza.baf_bayes(Bf = segs_all.rx(True, 'Bf').rx(segs_not_xy),
                         depth_ratio = segs_all.rx(True, 'depth.ratio').rx(segs_not_xy),
                         avg_depth_ratio = avg_depth_ratio,
                         weight_ratio = 2*200,
                         weight_Bf = 200,
                         cellularity = cellu_cint.rx2('max.l'),
                         dna_content = dna_c_cint.rx2('max.l'), CNr = 2)
      seg_res_xy  = sequenza.baf_bayes(Bf = segs_all.rx(True, 'Bf').rx(segs_is_xy),
                         depth_ratio = segs_all.rx(True, 'depth.ratio').rx(segs_is_xy),
                         avg_depth_ratio = avg_depth_ratio,
                         weight_ratio = 2*200,
                         weight_Bf = 200,
                         cellularity = cellu_cint.rx2('max.l'),
                         dna_content = dna_c_cint.rx2('max.l'), CNr = 1)
      # Segments and calls are stacked in the same order (unflagged rows
      # first, then X/Y) so that the rows line up in the cbind.
      seg_res = robjects.r.cbind(robjects.r.rbind(segs_all.rx(segs_not_xy, True), segs_all.rx(segs_is_xy, True)), robjects.r.rbind(seg_res, seg_res_xy))
   else:
      seg_res  = sequenza.baf_bayes(Bf = segs_all.rx(True, 'Bf'),
                         depth_ratio = segs_all.rx(True, 'depth.ratio'),
                         avg_depth_ratio = avg_depth_ratio,
                         weight_ratio = 2*200,
                         weight_Bf = 200,
                         cellularity = cellu_cint.rx2('max.l'),
                         dna_content = dna_c_cint.rx2('max.l'), CNr = 2)
      seg_res = robjects.r.cbind(segs_all, seg_res)
   robjects.r('write.table')(seg_res, data_dir +'/'+ tag + "_segments.txt", col_names = True, row_names = False, sep = "\t")

   # One chromosome_view call per chromosome, all in a single PDF.
   robjects.r.pdf(data_dir +'/'+ tag + "_chromosome_view.pdf")
   for chrom in chr_vect:
      if is_male and (chrom == "X" or chrom == "Y"):
         CNr = 1
      else:
         CNr = 2
      sequenza.chromosome_view(baf_windows = windows_Bf.rx2(chrom),
                      ratio_windows = windows_ratio.rx2(chrom), min_N_ratio = 1,
                      cellularity = cellu_cint.rx('max.l')[0][0], dna_content = dna_c_cint.rx('max.l')[0][0],
                      segments = seg_res.rx(seg_res.rx(True, 'chromosome').ro == chrom, True), mut_tab = mutation_list.rx2(chrom),
                      main = chrom, avg_depth_ratio = avg_depth_ratio, CNr = CNr)
   robjects.r('dev.off()')

   # Total segment size per copy number (CNt); segments on "Y" are left out.
   res_not_y = (seg_res.rx(True, 'chromosome').ro == 'Y').ro == False
   barscn = robjects.DataFrame({'size' : (seg_res.rx(True, 'end.pos').rx(res_not_y).ro - seg_res.rx(True, 'start.pos').rx(res_not_y)),
                                'CNt' : seg_res.rx(True, 'CNt').rx(res_not_y)})
   cn_sizes = robjects.r.split(barscn.rx(True,'size'),barscn.rx(True,'CNt'))
   cn_sizes = robjects.r.sapply(cn_sizes, 'sum')
   robjects.r.pdf(data_dir +'/'+ tag + "_CN_bars.pdf")
   robjects.r.barplot(robjects.r.round((cn_sizes.ro/robjects.r.sum(cn_sizes)).ro * 100, 0), names = cn_sizes.names, las = 1,
           ylab = "Percentage (%)", xlab = "copy number")
   robjects.r('dev.off()')
   # Rows: lower confidence bound, maximum-likelihood value, upper bound.
   # 'ploidy' is the mean copy number weighted by the summed segment sizes.
   res_ci_tab = robjects.DataFrame({'cellularity' : robjects.FloatVector((cellu_cint.rx2('confint')[0], cellu_cint.rx2('max.l')[0], cellu_cint.rx2('confint')[1])),
                                     'dna.content' : robjects.FloatVector((dna_c_cint.rx2('confint')[0], dna_c_cint.rx2('max.l')[0], dna_c_cint.rx2('confint')[1])),
                                     'ploidy'      : robjects.r('weighted.mean')(x=robjects.r('as.integer')(cn_sizes.names), w = cn_sizes)})
   robjects.r('write.table')(res_ci_tab, data_dir +'/'+ tag + "_confints_CP.txt", col_names = True, row_names = False, sep = "\t")


