# -*- coding: utf-8 -*-
# Gene-set p-value attributes read from every kiwiC.Metabolite, in the column
# order of the p-value matrix built in plot() (the colour and node-size
# formulas there index this matrix by column position).
_P_ATTRS = ('pNonDirectional', 'pMixDirUp', 'pDistDirUp',
            'pMixDirDn', 'pDistDirDn', 'pValue')

# Suffixes of the p-value columns of a full piano gene-set result file,
# i.e. the headers 'p (<suffix>)' and, when adjusted, 'p adj (<suffix>)'.
_PIANO_COLUMNS = ('non-dir.', 'mix.dir.up', 'mix.dir.dn',
                  'dist.dir.up', 'dist.dir.dn')


def _is_nan(value):
    """Return True if *value* is NaN, None, or otherwise not numeric."""
    import math
    try:
        return math.isnan(float(value))
    except (TypeError, ValueError):
        return True


def _column_index(header, name):
    """Return the index of the first entry of *header* equal to *name*.

    Raises ValueError when no entry matches.
    """
    return list(header).index(name)


def plot(gsn, gss, gsc='', gls='', pc=0.01, adj=True, prm=10000, gsz=10**6,
         deg=10**6, spl=2, sae=False, nsz=300, eds=5, lbs=9, hmt='binary',
         cft='', ctt='', csp='', nwf='', hmf='', gml=''):
    """Kiwi: a tool to combine gene-set analyses with biological networks.

    Produces a network plot and a heatmap given a gene-set interaction network,
    gene-set statistics, gene-set collection and gene-level statistics.
    All input files are read as tab-separated text.

    Parameters
    ----------

    INPUT

    gsn : The geneset-geneset interaction file path. It is a 2 column table where
          each row contains two interacting genesets, one per column, with
          no header. [required]

    gss : The geneset statistics result file path. It is, in its minimum
          implementation, a 2 or more column table where each row must contain
          the geneset as first column and, in any column, a statistic value for
          its significance. The header must identify the first column as Name
          and the geneset statistics column as p-value. If adj is set to True,
          adjusted p-values are expected to be found in a column identified by
          the header as p-adj. Alternatively, a full geneset result file
          generated using piano can be supplied (columns 'p (non-dir.)',
          'p (mix.dir.up)', ... and, if adj is True, 'p adj (non-dir.)', ...).
          If generated using the function writeFilesForKiwi.R in piano, the
          file is parsed automatically. [required]

    gsc : The geneset collection (geneset-gene) file path. It is a two column
          table with genesets - genes associations in each row and no header.
          [default = '']

    gls : The gene-level statistics file path. It is a three column table with
          gene - p-value - fold-change in each row. The genes must be in the
          first column, while p-values and fold-changes must be identified in
          the header as p and FC. [default = '']

    NETWORK ANALYSIS PARAMETERS

    pc  : The maximum significance threshold of a geneset, above which it is
          discarded from the results. Must be lower than 1. [default = 0.01]

    adj : Whether p-values adjusted for multiple testing should be used.
          [default = True]

    prm : The number of permutations used if geneset p-values were calculated
          using permutation tests rather than null hypothesis significance
          tests. It sets the p-value resolution (0.1/prm), which allows for
          better estimation of the node size of genesets whose p-values are
          exactly zero. Must be positive. [default = 10000]

    gsz : The maximum number of genes of a geneset, above which it is
          discarded from the results. It can be used to exclude high gene
          count genesets, whose interpretability would be difficult and possibly
          messy to display in the heatmap. Only applied when gsc is given.
          [default = 10**6]

    deg : The maximum degree of a geneset in the interaction network, above
          which it is discarded from the results. It can be used to exclude
          very connected genesets (like ATP in a metabolic network) from the
          plots. [default = 10**6]

    spl : The maximum shortest path length between two genesets in the
          interaction network above which no edge connects them. This can be
          interpreted as the threshold after which two genesets are considered
          unrelated. Negative values are replaced by 0 (with a warning).
          [default = 2]

    sae : Show all edges. If set to True, all edges passing the spl cutoff are
          drawn. If set to False, only the best edges of each node (those with
          the shortest path length) are kept. [default = False]


    PLOTTING PARAMETERS

    nsz : The node size of genesets with the highest p-value to be plotted
          (in the extreme scenario, this p-value equals pc). [default = 300]

    eds : The scaling of edge widths: an edge's width is eds divided by the
          shortest path length between the two connected genesets.
          [default = 5]

    lbs : The label font size in the plots. [default = 9]

    hmt : The heatmap colouring: black or white entries when set to 'binary',
          or a blue-to-red colormap according to the fold-change when set to
          'values'. [default = 'binary']

    cft : The annotation source for the genes in the GSA run (e.g. Ensembl).
          It should match the annotation sources listed in mygene.info
          (e.g. ensembl.gene). [default = '']

    ctt : The annotation desired to plot the gene names in the heatmap
          (e.g. HUGO). It should match the annotation sources listed in
          mygene.info (e.g. symbol). [default = '']

    csp : The species for the gene annotation (e.g. Homo sapiens). It should
          match the species listed in mygene.info (e.g. human). [default = '']

    EXPORT OPTIONS

    nwf : The name of the file where the network plot is saved (as PDF). If
          empty, plt.show() is called instead; note that this function selects
          the non-interactive PDF matplotlib backend, so nothing may actually
          be displayed. [default = '']

    hmf : The name of the file where the heatmap is saved (as PDF); forwarded
          to functions.drawHeatmap. The heatmap is only produced when both gsc
          and gls are given. [default = '']

    gml : The name of the file where the graph shown in the network plot is
          saved (as GraphML). If empty, no GraphML file is written.
          [default = '']

    Returns
    -------
    None. Results are produced as plots/files (side effects).

    Raises
    ------
    IOError   : if gsn or gss (or gsc/gls, when given) cannot be read.
    NameError : for invalid parameter values, malformed file headers, or when
                no geneset survives one of the filtering steps.
    """
    # Imports are function-local. The PDF backend must be selected before
    # pyplot is imported.
    import matplotlib
    matplotlib.use('PDF')
    import matplotlib.pyplot as plt
    import numpy as np
    import networkx as nx
    import os
    import classes as kiwiC
    import functions as kiwiF
    import warnings

    MNfile      = gsn
    GSfile      = gss
    GMfile      = gsc
    Gfile       = gls
    pcutoff     = float(pc)
    nPerm       = float(prm)
    splcutoff   = int(spl)
    dcutoff     = int(deg)
    minNodeSize = float(nsz)
    eScaleFac   = float(eds)
    labSize     = float(lbs)
    nwPlotFile  = nwf
    hmType      = hmt
    hmPlotFile  = hmf
    graphMLFile = gml
    # Infix of the piano adjusted-p column names ('p adj (...)'); also passed
    # to Metabolite.addGeneSetStats.
    adjTag      = 'adj ' if adj else ''
    maxGSsize   = float(gsz)
    cgnFromType = cft
    cgnToType   = ctt
    cgnSpecies  = csp

    # Check for bad input:
    if not os.access(MNfile, os.R_OK):
        raise IOError("Specified gene-set network file cannot be accessed")
    if not os.access(GSfile, os.R_OK):
        raise IOError("Specified gene-set statistics file cannot be accessed")
    if not Gfile:
        warnings.warn('No gene-level statistics files is provided: no heatmap will be generated', RuntimeWarning)
    if not GMfile:
        warnings.warn('No gene-geneset association file is provided: no lumping of overlapping gene-sets will be performed\n nor heatmap will be generated', RuntimeWarning)
    if GMfile and not os.access(GMfile, os.R_OK):
        raise IOError("Specified gene-geneset association file cannot be accessed")
    if Gfile and not os.access(Gfile, os.R_OK):
        raise IOError("Specified gene-level statistics file cannot be accessed")
    if pcutoff >= 1:
        raise NameError("pc must be lower than 1")
    if nPerm <= 0:
        raise NameError("prm must be a positive value")
    pzero = 0.1/nPerm  # p-value resolution implied by prm permutations
    if pzero >= 1:
        raise NameError("prm is too low")
    if pcutoff <= pzero:
        warnings.warn('pc should be larger than the p-value resolution ('+str(pzero)+')', RuntimeWarning)
    if splcutoff < 0:
        warnings.warn('spl is negative and it has been set to 0', RuntimeWarning)
        splcutoff = 0
    if minNodeSize <= 0:
        raise NameError("nsz must be a positive value")
    if eScaleFac <= 0:
        raise NameError("eds must be a positive value")
    if labSize <= 0:
        raise NameError("lbs must be a positive value")
    if hmType not in ("binary", "values"):
        raise NameError("hmt must be either 'binary' or 'values'")
    if maxGSsize <= 0:
        raise NameError("gsz must be a positive value")

    # A general comment: note that in the code below it is assumed that the
    # gene-set interaction network is a metabolic network and that gene-sets
    # are metabolites, hence the nomenclature. Of course the code still works
    # for the general case!

    # Make metabolic network MN:
    MN = nx.read_edgelist(MNfile, nodetype=str, delimiter='\t')

    # Initialize metabolome M:
    M = kiwiC.Metabolome()

    # Read the metabolites from the stats file. The header row is part of the
    # data, so every column is parsed as text; atleast_2d keeps a file with a
    # single row two-dimensional.
    content = np.atleast_2d(np.genfromtxt(GSfile, dtype=None, delimiter='\t'))
    header = content[0]
    if header[0] != 'Name':
        raise NameError("Gene-set statistic file has invalid header: first column should be named 'Name'.")
    if all('p (%s)' % col in header for col in _PIANO_COLUMNS):
        # Full piano result file: each p-value class needs its adjusted column.
        if adjTag:
            for col in _PIANO_COLUMNS:
                if 'p adj (%s)' % col not in header:
                    raise NameError("Gene-set statistics file has invalid header: one column should be named 'p adj (%s)' or adj should be set to False" % col)
    elif 'p (non-dir.)' in header:
        # Piano file with only the non-directional class: map to generic names.
        header[header == 'p (non-dir.)'] = 'p-value'
        if adjTag and 'p adj (non-dir.)' not in header:
            raise NameError("Gene-set statistics file has invalid header: one column should be named 'p adj (non-dir.)' or adj should be set to False")
        if 'p adj (non-dir.)' in header:
            header[header == 'p adj (non-dir.)'] = 'p-adj'
    elif 'p-value' in header:
        if adjTag and 'p-adj' not in header:
            raise NameError("Gene-set statistics file has invalid header: one column should be named 'p-adj' or adj should be set to False")
    else:
        raise NameError("Gene-set statistics file has invalid header.")

    for stats in content[1:]:
        stats = tuple(stats)
        m = kiwiC.Metabolite(stats[0])
        m.addGeneSetStats(stats, header, adjTag)
        M.addMetabolite(m)
    # NaN never compares equal to anything (x != nan is always True), so an
    # explicit NaN test is required. Fail only when no metabolite carries any
    # numeric p-value at all.
    if all(all(_is_nan(getattr(m, attr, None)) for attr in _P_ATTRS)
           for m in M.metaboliteList):
        raise NameError("Invalid data type for gene-set p-value statistic: all values are NaN")

    # Remove non-significant metabolites, metabolites not in MN and
    # metabolites with high degree:
    M.removeNotSignificantMetabolites(pcutoff)
    if len(M.metaboliteList) == 0:
        raise NameError('No metabolites passed the pCutoff')
    M.removeMetabolitesNotInMetNet(MN)
    if len(M.metaboliteList) == 0:
        raise NameError('No more metabolites from the gene-set statistics file are present in the metabolite-metabolite network')
    # dict(...) works for both the networkx 1.x dict and the 2.x DegreeView.
    if dcutoff < max(dict(MN.degree()).values()):
        M.removeHighDegreeMetabolites(MN, dcutoff)
    if len(M.metaboliteList) == 0:
        raise NameError('No metabolites passed the dCutoff')

    # Import the gene-level statistics:
    G = kiwiC.Genome()
    if Gfile:
        content = np.atleast_2d(np.genfromtxt(Gfile, dtype=None, delimiter='\t'))
        header = content[0]
        if ('p' not in header) or ('FC' not in header):
            raise NameError("Gene-level statistics file should contain 'p' and 'FC' as column headers")
        pCol = _column_index(header, 'p')
        fcCol = _column_index(header, 'FC')
        for glstat in content[1:]:
            g = kiwiC.Gene(glstat[0])
            g.addGeneStats(float(glstat[pCol]), float(glstat[fcCol]))
            G.addGene(g)

    if GMfile:
        # Import the gene-metabolite information. (Lumping of gene-sets with
        # identical gene lists was removed from this tool.)
        content = np.atleast_2d(np.genfromtxt(GMfile, dtype=None, delimiter='\t'))
        # Iterate over a copy: metabolites may be removed inside the loop.
        for met in list(M.metaboliteList):
            genenames = np.unique(content[content[:, 0] == met.name, 1])
            if len(genenames) > maxGSsize:
                M.removeMetabolite(met)
                continue
            for genename in genenames:
                gene = G.getGene(genename)
                if not isinstance(gene, kiwiC.Gene):
                    # The gene has no gene-level statistics (e.g. excluded by
                    # the upstream GSA): add a placeholder with NaN stats.
                    gene = kiwiC.Gene(genename)
                    gene.addGeneStats(p=np.nan, FC=np.nan)
                    G.addGene(gene)
                met.addGene(gene)
        if len(M.metaboliteList) == 0:
            raise NameError('No metabolites passed the gsz cutoff')

    # Construct the plot graph: two metabolites are connected when their
    # shortest path length in MN is at most splcutoff. A BFS bounded by
    # splcutoff per metabolite finds exactly those pairs.
    PG = nx.Graph()
    PG.add_nodes_from(M.metaboliteList)
    nodes = list(PG.nodes())
    for i, source in enumerate(nodes):
        dist = dict(nx.single_source_shortest_path_length(MN, source.name, cutoff=splcutoff))
        for target in nodes[i+1:]:
            pathLength = dist.get(target.name)
            if pathLength is None:
                continue  # no path, or longer than spl: considered unrelated
            PG.add_edge(source, target)
            PG[source][target]['shortest_path_length'] = pathLength
            PG[source][target]['weight'] = eScaleFac/pathLength

    # Keep only the best edge/s for each node (unless sae=True):
    if not sae:
        edgesToKeep = set()
        for met in PG.nodes():
            incident = list(nx.edges(PG, met))
            if not incident:
                continue
            best = min(PG[u][v]['shortest_path_length'] for u, v in incident)
            edgesToKeep.update(frozenset((u, v)) for u, v in incident
                               if PG[u][v]['shortest_path_length'] == best)
        PG.remove_edges_from([e for e in PG.edges() if frozenset(e) not in edgesToKeep])

    # Get edge width for plotting:
    edge_width = kiwiF.getEdgeProperty(PG, 'weight')

    # Get node attributes for plotting (columns ordered as in _P_ATTRS);
    # missing p-values (NaN) are treated as 1.
    nodes = list(PG.nodes())
    p_stable = np.array([[getattr(node, attr) for attr in _P_ATTRS] for node in nodes],
                        dtype=float)
    p_stable[np.isnan(p_stable)] = 1
    nonzero = p_stable[p_stable != 0]
    # Add a small number that is at most as high as the smallest non-zero
    # p-value, so exact zeros stay finite under -log10.
    if nonzero.size:
        pzero = min(pzero, pcutoff, nonzero.min())
    else:
        pzero = min(pzero, pcutoff)
    p_stable = -np.log10(p_stable + pzero)
    color_code = (p_stable[:, 5] + ((p_stable[:, 1]*(p_stable[:, 0]+p_stable[:, 2]) -
                  p_stable[:, 3]*(p_stable[:, 0]+p_stable[:, 4]))) /
                  (2*p_stable.max()**2))
    node_size = minNodeSize*(p_stable[:, 0]+np.log10(pcutoff))+minNodeSize

    # Store plot attributes as node attributes. add_node on an existing node
    # only updates its attributes, in networkx 1.x and 2.x alike.
    for node, score, logp in zip(nodes, color_code, p_stable[:, 0]):
        PG.add_node(node, **{'directionalityScore': float(score),
                             '-log10p': float(logp)})

    # Plot network:
    fig_nw = plt.figure(figsize=(8, 8))
    pos = nx.spring_layout(PG, iterations=50, scale=5)
    colorLimit = abs(color_code).max()
    nx.draw(PG, pos, width=edge_width, node_size=node_size, node_color=color_code,
            cmap=plt.cm.RdBu_r, vmin=-colorLimit, vmax=colorLimit, with_labels=False)
    nx.draw_networkx_labels(PG, pos, {n: n.label for n in PG.nodes()}, font_size=labSize)
    if nwPlotFile:
        plt.savefig(nwPlotFile, bbox_inches='tight')
        plt.close(fig_nw)
    else:
        # NOTE: with the PDF backend selected above, show() displays nothing.
        plt.show()

    # Heatmap (needs both the gene-set collection and gene-level statistics):
    if GMfile and Gfile:
        kiwiF.drawHeatmap(PG, hmType, hmPlotFile, pzero, cgnFromType, cgnToType, cgnSpecies)

    # Export to graphML:
    if graphMLFile:
        nx.write_graphml(PG, graphMLFile)