#!/usr/bin/env python

from BNfinder.MDL import MDL
from BNfinder.MIT import MIT
from BNfinder.BDE import BDE
from BNfinder.data import dataset
from BNfinder import util
import copy,StringIO

def read_preds(str,dts):
    """Parse tab-separated classifier output into per-vertex predictions.

    The first line of ``str`` is treated as a header and skipped.  Every
    following line is split on tabs: the first field is a vertex name and
    the remaining fields are converted to floats.

    Lines whose name matches no vertex in ``dts.vertices`` are ignored; when
    several vertices share the name, the first one wins.

    Returns a dict mapping vertex name -> (predictions, classes), where
    ``classes`` holds ``p._data[0][index]`` for every point in ``dts.points``
    and ``index`` is the matched vertex's ``index`` attribute.
    """
    res={}
    rows=str.strip().split("\n")
    for row in rows[1:]:
        fields=row.strip().split("\t")
        cond=fields[0]
        for vertex in dts.vertices:
            if vertex.name!=cond:
                continue
            column=vertex.index
            cls=[point._data[0][column] for point in dts.points]
            pred=[float(field) for field in fields[1:]]
            res[cond]=pred,cls
            # only the first vertex carrying this name is used
            break
    return res
        
if __name__=="__main__":
    import sys
    #from dispatch import *
    #test(int(sys.argv[1]),int(sys.argv[2]),int(sys.argv[3]))
    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option("-s", "--score", dest="score",default="BDE",help="Scoring function: BDE (default) or MDL or MIT")
    parser.add_option("-x", "--xaxis", dest="xax",default="fpr",help="y axis of the roc curve. tpr is default. see Rocr manual for posiible options")
    parser.add_option("-y", "--yaxis", dest="yax",default="tpr",help="x axis of the roc curve. tpr is default. see Rocr manual for posiible options")
    parser.add_option("-a", "--diag", dest="diag",type=int,default=1,help="Set to 0 to remove diagonal lines from roc plots")
    parser.add_option("-d", "--data-factor", dest="df",default=1.0,type=float,help="Factor multiplying the data set")
    parser.add_option("-e", "--expr", dest="expr",help="Expression data file")
    parser.add_option("-t", "--txt", dest="txt",help="Output file with the suboptimal parents sets")
    parser.add_option("-n", "--net", dest="net",help="Output file with the network structure")
    parser.add_option("-c", "--cpd", dest="cpd",help="Output file with the conditional probability distributions")
    parser.add_option("-b", "--bif", dest="bif",help="Output file with the optimal network in BIF format")
    parser.add_option("-r", "--roc", dest="roc",help="Output file with the ROC curves (PDF)")
    parser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False, help="Print comments")
    parser.add_option("-p", "--prior-pseudocount", dest="prior",help="Total pseudocount number in Dirichlet priors of BDE score (default: 1 pseudocount for each possible value vector)")
    parser.add_option("-i", "--suboptimal", dest="n_min", default=1, type=int,help="Number of computed suboptimal parents sets")
    parser.add_option("-m", "--min-empty", dest="min_empty", default=None, type=float,help="Reported suboptimal parents sets' minimal probability ratio to the empty set")
    parser.add_option("-o", "--min-optimal", dest="min_optim", default=None, type=float,help="Reported suboptimal parents sets' minimal probability ratio to the optimal set")
    parser.add_option("-l", "--limit", dest="limit",help="Limit of the size of parents subsets")
    parser.add_option("-f", "--fraction", dest="frac", type=float,default=None,help="Minimum weight of a parent to be considered in a parent set.")
    parser.add_option("-k", "--cross-val-folds", dest="k", type=int,default=10,help="Number of Cross Validation folds to consider (1 for no cross-validation.)")
    parser.add_option("-A", "--chi", dest="chi", type=float, default=.9999, help="alpha value for the chi-square distribution used in MIT score (default=.9999)")

    
    (options, args) = parser.parse_args()
    
    
    if options.expr==None:
        print "You must provide the training file name. Run bnf_cv --help for more options."
    else:
        if options.verbose:
            print 'Loading data from:', options.expr, '...',
        try:
            d =  dataset(options.expr).fromNewFile(open(options.expr))
        except KeyError:
            if options.verbose:
                print "Not a new file format. The old format is deprecated!"
            d = dataset(options.expr).fromFile(open(options.expr))

        if options.verbose:
            print 'done\n'

        if options.verbose:
            print 'Making %d cross-validation folds.\n'%options.k

        if options.k==1:
            folds=[(range(len(d.points)),range(len(d.points)))]
        else:
            from BNfinder import CrossVal
            rnd_state,folds=CrossVal.cv_folds(range(len(d.points)),k=options.k)

        preds_all=[]
        

        
        

        if options.roc:
            from BNfinder import Rocr
            Rocr.device(options.roc)
        for i,(train,test) in enumerate(folds):
            #if options.verbose:
            print 'Fold %d: \nLearning network '%i, d.name,' with', options.score, 'scoring function'
            d_train=copy.deepcopy(d)
            d_train.points=[d.points[j] for j in train]
            score=eval(options.score)(data_factor=options.df,prior=options.prior,chi_alpha=options.chi,sloops=False)
            score,g,subpars = d_train.learn(score=score,\
                            data_factor=options.df,\
                            prior=options.prior,\
                            verbose=options.verbose,\
                            n_min=options.n_min,\
                            limit=options.limit,\
                            min_empty=options.min_empty,\
                            min_optim=options.min_optim)
            cpd=d_train.to_cpd(g)
            d_test=copy.deepcopy(d)
            d_test.points=[d.points[j] for j in test]
            d_test.vertice_names=d_test.regulators[0]
            cl=StringIO.StringIO()
            d_test.classify(cpd,cl,prob=1)
            preds=read_preds(cl.getvalue(),d_test)
            preds_all.append(preds)
            #print cl.getvalue()
            if options.roc and options.k==1:
                for k,(pre,cls) in preds.items():
                    try:
                        Rocr.simpleROC(pre,cls,k+"fold %d"%i,options.yax,options.xax,diag=options.diag)
                    except:
                        print pre,cls
            if options.verbose:
                print 'Total score of optimal network:',score
                print g
            print g.to_SIF()

            if options.txt:
                d.write_txt(subpars,options.txt+str(i))
                if options.verbose:
                    print "Suboptimal parents sets written to:", options.txt+str(i)

            if options.net:
                if options.n_min==1:
                    net_str=g.to_SIF()
                else:
                    net_str=g.to_SIF(subpars)
                f=open(options.net+str(i),"w")
                f.write(net_str)
                f.close()
                #g.toDot().write_dot(options.net)
                if options.verbose:
                    print "Network graph written to:", options.net+str(i)
                    print net_str

            if options.cpd:
                f_cpd=open(options.cpd+str(i),"w")
                #d.write_cpd(g,f_cpd,options.prior)
                #f_cpd.write("XXXXXXXXXXXXXXXXX\n")
                #print "XXXXX",repr(options.frac)
                if options.frac:
                    g1=g.weighted_edges(subpars,float(options.frac))
                    util.pprint(f_cpd,d.to_cpd(g1,options.prior))
                else: #outpu the best network
                    util.pprint(f_cpd,d.to_cpd(g,options.prior))
                f_cpd.close()
                if options.verbose:
                    print "Conditional probability distributions written to:", options.cpd+str(i)

            if options.bif:
                d.write_bif(g,options.bif+str(i),options.prior,\
                    ['Network graph learned with %s scoring function from expression data file \'%s\'' %(options.score,options.expr)])
                if options.verbose:
                    print "Bayesian network written to:", options.bif+str(i)

        
        if options.roc:
            if options.k>1:
                for k in preds.keys():
                    ps,cs=[],[]
                    for d in preds_all:
                        p,c=d[k]
                        ps.append(p)
                        cs.append(c)

                    try:
                        Rocr.cvROC(ps,cs,k+" all folds",options.yax,options.xax,diag=options.diag)
                    except:
                        print ps,cs        
            Rocr.close()
