#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Plots for displaying multivariate data
# Dec 2009, MCS
# Updated May, June 2010, MCS: added 2 classes for plot interactivity, circ
# and com (center of mass).  Changed naming convention from eigenvectors and
# eigenvalues to factors and scores.  Added scoreHistogram function for
# examining the histograms of scores.

"""
Copyright (C) 2011 by Michael Sarahan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np
from matplotlib import mpl
import matplotlib.pyplot as plt
import EMMSA.utils.db_model as m
from sqlalchemy.ext.sqlsoup import SqlSoup
import sqlalchemy as sa
from sqlalchemy import and_
import Image
import sys
from xplotlayout import xplotlayout
from scipy.ndimage.measurements import center_of_mass as CofM

# Initialize the ORM layer from EMMSA.utils.db_model as a side effect of
# importing this module; the m.<Entity>.query / .table lookups below
# presumably depend on it (NOTE(review): confirm in db_model).
m.initializeDB()

class circ:
  '''
  Interactive circle drawing on plots.  For highlighting interesting areas.
  To draw these circles on any other plot, use the coordinates contained in
  this class' xdatalist and ydatalist variables.

  To use:
  Instantiate this class like so:
  c=circ(radius=2)
  Now, create a plot you want to put circles on.  Note that you should set the
  radius to whatever size you want your circle in terms of your plot's scale.
  plt.imshow(mypicture)
  Now, connect the circle object to the click event:
  plt.connect('button_press_event', c.mycall)
  OK, now click on your plot.  A circle (white by default) should appear on it.
  You can get the coordinates of all your clicks like this:
  c.xdatalist
  c.ydatalist

  Clicks that land outside the data area of an axes are ignored.
  '''
  def __init__(self,color='white',radius=2):
    # last matplotlib event received; kept for interactive inspection
    self.event=None
    # data coordinates of every accepted click, in click order
    self.xdatalist = []
    self.ydatalist = []
    self.color=color
    self.radius=radius

  def mycall(self, event):
    """Mouse-click callback: record the click and outline it with a circle.

    event is a matplotlib mouse event; its xdata/ydata (and inaxes, when
    present) are used.
    """
    self.event = event
    # Clicks outside an axes carry xdata/ydata of None; a circle cannot be
    # centred there, so ignore them instead of raising inside the callback.
    if event.xdata is None or event.ydata is None:
      return
    from matplotlib.patches import Circle
    self.xdatalist.append(event.xdata)
    self.ydatalist.append(event.ydata)
    print('x = %s and y = %s'% (event.xdata,event.ydata))
    # Draw on the axes that was clicked; fall back to the current axes.
    ax = getattr(event, 'inaxes', None) or plt.gca()
    if hasattr(ax, 'hold'):
      ax.hold(True)  # overlay plots (hold() no longer exists in newer matplotlib)
    cir=Circle((event.xdata,event.ydata),radius=self.radius,fill=False,
               linewidth=2,edgecolor=self.color)
    ax.add_patch(cir)
    plt.draw()

class com:
  '''
  Interactively determine centers of mass from regions on plots.  For
  measuring atomic column positions.

  To use:
  c=com(image, boxsize=10)
  plt.imshow(image)
  plt.connect('button_press_event', c.mycall)
  Each click measures the intensity-weighted center of mass of a
  boxsize x boxsize window of image centred on the clicked pixel (clipped at
  the image border), outlines that window, prints the result and appends the
  (x, y) center to c.centers.

  The plot is assumed to show data in pixel coordinates (plain imshow, no
  extent), so x is the column index and y the row index of data.  If scale
  is given, each stored center is multiplied by it (e.g. units per pixel).
  '''
  def __init__(self,data,color='white',boxsize=10,scale=None):
    self.event=None
    self.color=color
    self.boxsize=boxsize
    self.data=data
    self.scale=scale
    # measured (x, y) centers in click order (scaled when scale is set)
    self.centers=[]

  def _window(self, xdata, ydata):
    """Return (x0, x1, y0, y1), the pixel window centred on a click.

    The window is clipped to the bounds of self.data, so clicks near an edge
    never produce negative (wrap-around) slice indices.
    """
    half=self.boxsize//2
    nrows,ncols=np.shape(self.data)[:2]
    x0=max(int(round(xdata))-half,0)
    y0=max(int(round(ydata))-half,0)
    x1=min(x0+self.boxsize,ncols)
    y1=min(y0+self.boxsize,nrows)
    return x0,x1,y0,y1

  def _locate(self, xdata, ydata):
    """Return the unscaled (x, y) center of mass around a click, in pixels."""
    x0,x1,y0,y1=self._window(xdata,ydata)
    # Arrays are indexed [row, column] = [y, x], as imshow displays them, and
    # center_of_mass returns (row, column) offsets within the window.
    offsets=CofM(np.asarray(self.data)[y0:y1,x0:x1])
    return x0+offsets[1],y0+offsets[0]

  def mycall(self, event):
    """Mouse-click callback: measure, outline and record a center of mass."""
    self.event = event
    if event.xdata is None or event.ydata is None:
      return  # click landed outside the axes
    from matplotlib.patches import Rectangle
    center=self._locate(event.xdata,event.ydata)
    x0,x1,y0,y1=self._window(event.xdata,event.ydata)
    ax = getattr(event, 'inaxes', None) or plt.gca()
    if hasattr(ax, 'hold'):
      ax.hold(True)  # overlay plots (hold() no longer exists in newer matplotlib)
    # Outline exactly the region that was measured.
    area=Rectangle((x0,y0),width=x1-x0,height=y1-y0,fill=False,linewidth=2,
                   edgecolor=self.color)
    ax.add_patch(area)
    plt.draw()
    if self.scale:
      center=center[0]*self.scale,center[1]*self.scale
    print(center)
    self.centers.append(center)
    
def initializeDB(cstring):
  """Bind the module-global ``db`` to a SqlSoup instance built from cstring.

  cstring is handed to SqlSoup unchanged (presumably a SQLAlchemy database
  URL -- confirm with callers).  Nothing is returned.
  """
  global db
  soup=SqlSoup(cstring)
  db=soup

def reflectedTable(name,metadata=None):
  """Return the SQLAlchemy Table called name, reflected from the database.

  metadata is the SQLAlchemy MetaData to reflect into; it must be bound to
  an engine (its .bind is used for autoloading).  Nothing in this file
  defines a module-global ``metadata``, so when the argument is omitted the
  db_model module's ``metadata`` attribute is used instead.

  Raises ValueError when no MetaData can be found.
  """
  if metadata is None:
    # NOTE(review): assumes EMMSA.utils.db_model exposes a bound `metadata`
    # (as Elixir-based model modules usually do) -- confirm.
    metadata=getattr(m,'metadata',None)
  if metadata is None:
    raise ValueError('reflectedTable needs a bound MetaData; pass metadata=')
  return sa.Table(name, metadata, autoload=True, autoload_with=metadata.bind)

def avgImage(setID,scale=None,units=None):
  """Plot the average image of ParentSet setID in a new figure.

  setID: database id of the ParentSet whose avgImage is shown.
  scale: optional number of units per pixel; when given, the axes span the
    image size times scale and are labelled with units instead of 'Pixels'.
  Returns the matplotlib Figure.
  """
  pset=m.ParentSet.query.filter_by(id=setID).one()
  avg=pset.avgImage
  fig=plt.figure()
  if not scale:
    plt.imshow(avg)
    plt.xlabel('Pixels')
    plt.ylabel('Pixels')
  else:
    # imshow puts columns (shape[1]) on x and rows (shape[0]) on y.
    plt.imshow(avg,
      extent=[0,scale*avg.shape[1],0,scale*avg.shape[0]])
    plt.xlabel(units)
    plt.ylabel(units)
  # imshow returns an AxesImage, which has no set_title; title the axes.
  plt.title('Average Image')
  plt.colorbar()
  plt.hot()
  return fig

def getPCAfactors(factorIDs):
  """Return the .factor data of each PCA_Factor id in factorIDs, in order.

  Each id is fetched with .one(), so an id that does not match exactly one
  row raises.
  """
  return [m.PCA_Factor.query.filter_by(id=fid).one().factor
          for fid in factorIDs]

def getPCAscores(factorID,threshold=None,greater=True):
  """Return the scores of one PCA factor as a float array.

  factorID: database id of the PCA_Factor.
  threshold: if truthy, keep only scores above it (greater=True) or below it
    (greater=False).  A falsy threshold (None, 0, False) applies no filter.
  Returns an (N, 1) float array with one row per score (shape (0,) when
  nothing matches).
  """
  factors=m.PCA_Factor.table
  scoreTable=m.PCA_Score.table
  score=scoreTable.c.score
  criteria=[factors.c.id==factorID]
  if threshold:
    # Bound comparisons keep fractional thresholds intact; "%i" string
    # formatting used to truncate them to integers.
    criteria.append(score>threshold if greater else score<threshold)
  sresult=sa.select([score],and_(*criteria),
                    from_obj=sa.join(factors,scoreTable))
  return np.array(sresult.execute().fetchall()).astype('float')

def getICAfactors(factorIDs):
  """Return the .factor data of each ICA_Factor id in factorIDs, in order.

  Each id is fetched with .one(), so an id that does not match exactly one
  row raises.
  """
  return [m.ICA_Factor.query.filter_by(id=fid).one().factor
          for fid in factorIDs]

def getFactorIDs(parentSetID,vectorIDs):
  """Translate 1-based PCA factor positions into PCA_Factor database ids.

  Vector IDs are only the relative order: 1 is the first PCA factor of the
  ParentSet parentSetID, 2 the second, and so on.  Factors are ordered by
  database id so the mapping is deterministic.

  Raises ValueError for a position below 1 (which would otherwise wrap
  around to the end of the list) and IndexError for one past the end.
  """
  pset=m.ParentSet.query.filter_by(id=parentSetID).one()
  vectors=(m.PCA_Factor.query.filter_by(vset=pset)
           .order_by(m.PCA_Factor.table.c.id).all())
  factorIDs=[]
  for vectorID in vectorIDs:
    if vectorID<1:
      raise ValueError('vector positions start at 1, got %r'%(vectorID,))
    factorIDs.append(vectors[vectorID-1].id)
  return factorIDs

def getICAscores(factorID,threshold=None,greater=True):
  """Return the scores of one ICA factor as a float array.

  factorID: database id of the ICA_Factor.
  threshold: if truthy, keep only scores above it (greater=True) or below it
    (greater=False).  A falsy threshold (None, 0, False) applies no filter.
  Returns an (N, 1) float array with one row per score (shape (0,) when
  nothing matches).
  """
  factors=m.ICA_Factor.table
  scoreTable=m.ICA_Score.table
  score=scoreTable.c.score
  criteria=[factors.c.id==factorID]
  if threshold:
    # Bound comparisons keep fractional thresholds intact; "%i" string
    # formatting used to truncate them to integers.
    criteria.append(score>threshold if greater else score<threshold)
  sresult=sa.select([score],and_(*criteria),
                    from_obj=sa.join(factors,scoreTable))
  return np.array(sresult.execute().fetchall()).astype('float')

def getCoordsAndRats(parentfname=None,threshold=None,greater=True):
  """Return x, y, rat and ratstd of sub-images as an (N, 4) float array.

  parentfname: if given, only sub-images of the ParentImage with this
    filename are returned; otherwise sub-images of every parent.
  threshold: if truthy, keep only rows whose rat is above it (greater=True)
    or below it (greater=False).  A falsy threshold applies no filter.
  """
  sub=m.SubImage.table
  criteria=[]
  if parentfname:
    criteria.append(m.ParentImage.filename==parentfname)
  if threshold:
    # Both directions are honoured for every case (the per-file branch used
    # to apply '>' even when greater=False).
    criteria.append(sub.c.rat>threshold if greater else sub.c.rat<threshold)
  whereclause=and_(*criteria) if criteria else None
  sresult=sa.select([sub.c.x,sub.c.y,sub.c.rat,sub.c.ratstd],whereclause,
                    from_obj=sa.join(sub,m.ParentImage.table))
  return np.array(sresult.execute().fetchall()).astype("float")

def getCoordsAndScores(MDAtype,factorID,parentfname=None,threshold=None,greater=True,rat=False):
  """Return (x, y, score) for every sub-image scored by one factor.

  MDAtype: 'PCA' or 'ICA'; anything else raises ValueError.
  factorID: database id of the factor whose scores are wanted.
  parentfname: if given, restrict to sub-images of the ParentImage with this
    filename.
  threshold: if truthy, keep only scores above it (greater=True) or below it
    (greater=False).  A falsy threshold applies no filter.
  rat: accepted for call compatibility; currently unused.
  Returns an (N, 3) float array of x, y, score rows (shape (0,) when nothing
  matches).
  """
  # == rather than `is`: string identity is an interning accident.
  if MDAtype=='PCA':
    factors,scoreTable=m.PCA_Factor.table,m.PCA_Score.table
  elif MDAtype=='ICA':
    factors,scoreTable=m.ICA_Factor.table,m.ICA_Score.table
  else:
    raise ValueError("MDAtype must be 'PCA' or 'ICA', got %r"%(MDAtype,))
  sub=m.SubImage.table
  score=scoreTable.c.score
  jtable=sa.join(factors,scoreTable).join(sub).join(m.ParentImage.table)
  criteria=[factors.c.id==factorID]
  if parentfname:
    criteria.insert(0,m.ParentImage.filename==parentfname)
  if threshold:
    # Bound comparison: "%i" formatting used to truncate fractional values.
    criteria.append(score>threshold if greater else score<threshold)
  sresult=sa.select([sub.c.x,sub.c.y,score],and_(*criteria),from_obj=jtable)
  return np.array(sresult.execute().fetchall()).astype("float")

#def drawcirc(fig,color='white'):
  #connect(

def scoreHistogram(MDAtype,factorID,bins=30):
  """Plot a histogram of the scores of one PCA or ICA factor.

  MDAtype: 'PCA' or 'ICA'.  factorID: database id of the factor.
  bins: number of histogram bins, passed to plt.hist.
  Returns the new Figure, or None when MDAtype is not recognized (a message
  is printed) or the factor has no scores.
  """
  if MDAtype=='PCA':
    scores=getPCAscores(factorID)
  elif MDAtype=='ICA':
    scores=getICAscores(factorID)
  else:
    print('Not recognized MDAtype. Only ICA and PCA are currently implemented.')
    return None
  # `if scores:` is ambiguous (ValueError) for a multi-element numpy array.
  if scores.size==0:
    return None
  f=plt.figure()
  plt.hist(scores.ravel(),bins)
  plt.xlabel('Score')
  plt.ylabel('Counts')
  plt.title('Factor %i'%factorID)
  return f

def factorImagePlot(mdaType,factors,nperRow=3,scale=None,units=None):
  """Plot a grid of PCA or ICA factor images, nperRow images per row.

  mdaType: 'PCA' or 'ICA'; anything else raises ValueError.
  factors: database ids of the factors to show, in display order.
  scale/units: optional units per pixel and axis label for real-unit axes.
  Returns the Figure.
  """
  if mdaType=='PCA':
    getFactors=getPCAfactors
  elif mdaType=='ICA':
    getFactors=getICAfactors
  else:
    raise ValueError("mdaType must be 'PCA' or 'ICA', got %r"%(mdaType,))
  numFacs=len(factors)
  print(numFacs)
  # add_subplot needs an integer number of rows.
  rows=int(np.ceil(numFacs/float(nperRow)))
  data=getFactors(factors)
  fig=plt.figure(figsize=(4*nperRow,3.75*rows))
  plt.subplots_adjust(left=0.1,right=0.9,hspace=0.46,wspace=0.3)
  plt.suptitle('%s Factors'%mdaType)
  for idx,(factorID,factor) in enumerate(zip(factors,data)):
    a=fig.add_subplot(rows,nperRow,idx+1)
    if not scale:
      plt.imshow(factor)
      plt.xlabel('Pixels')
      plt.ylabel('Pixels')
    else:
      # imshow puts columns (shape[1]) on x and rows (shape[0]) on y.
      plt.imshow(factor,
        extent=[0,scale*factor.shape[1],0,scale*factor.shape[0]])
      plt.xlabel(units)
      plt.ylabel(units)
    a.set_title('Factor %s'%factorID)
    plt.colorbar()
    plt.hot()
  return fig

def ratioPlot(MDAType,feature1,features):
  """
  Plots the ratio of one factor image to one or more other factor images.
  If plotting more than one factor image, it gives the average ratio
  and the standard deviation.  Not presently functional.

  TODO: implement.  The intended inputs are presumably
  getCoordsAndScores(MDAType, feature1) and getCoordsAndScores(MDAType, f)
  for each f in features.  Until then every call raises NotImplementedError,
  so callers get an explicit signal rather than an incidental NameError.
  """
  raise NotImplementedError('ratioPlot is not implemented (MDAType=%r, '
                            'feature1=%r, features=%r)'
                            %(MDAType,feature1,features))

def importanceImage(MDAtype,factorID,scale=None,units=None):
  """
  Plots 3 images: the factor image, and then the average image plus the
  factor image multiplied by the lowest and by the highest score of that
  factor.  Often useful for trying to see what part of your structure is
  changing with a particular factor.

  Usage:
  MDAtype is either 'PCA' or 'ICA'; anything else raises ValueError.
  factorID is the database id of the factor that you want to plot.
  scale and units allow you to set the axes in real units, rather than pixels.
    Give the scale as the number of units per pixel, and give the units as
    a string.  If you input scale, you must input units.
  Returns the Figure.
  """
  if MDAtype=='PCA':
    vec=m.PCA_Factor.query.filter_by(id=factorID).one()
  elif MDAtype=='ICA':
    vec=m.ICA_Factor.query.filter_by(id=factorID).one()
  else:
    raise ValueError("MDAtype must be 'PCA' or 'ICA', got %r"%(MDAtype,))
  # Named `average` so it does not shadow the module-level avgImage().
  average=m.ParentSet.query.filter_by(id=vec.parentset_id).one().avgImage
  # getCoordsAndScores is this module's function, not a db_model member.
  scores=getCoordsAndScores(MDAtype,factorID)
  print(scores.shape)
  valmin=np.min(scores[:,2]);valmax=np.max(scores[:,2])
  vdata=vec.factor

  def show(fig,position,image,title):
    # Draw one of the three panels in pixel or real units, hot colormap.
    a=fig.add_subplot(1,3,position)
    if not scale:
      plt.imshow(image)
      plt.xlabel('Pixels')
      plt.ylabel('Pixels')
    else:
      # imshow puts columns (shape[1]) on x and rows (shape[0]) on y.
      plt.imshow(image,
        extent=[0,scale*vdata.shape[1],0,scale*vdata.shape[0]])
      plt.xlabel(units)
      plt.ylabel(units)
    plt.hot()
    a.set_title(title)
    return a

  f=plt.figure(figsize=(18,6))
  plt.subplots_adjust(left=0.1,right=0.9)
  show(f,1,vdata,'Factor image number %i'%factorID)
  plt.colorbar()
  show(f,2,average+valmin*vdata,'Lowest score: %f'%valmin)
  show(f,3,average+valmax*vdata,'Highest score: %f'%valmax)
  return f

def spatialRatPlot(parentfnames=None,threshold=False,
                greater=True,scale=None,units=None):
  """Scatter the converted rat value of each sub-image over its parent image.

  One panel is drawn per parent image named in parentfnames (all parents,
  sorted by filename, when omitted).  A shared horizontal colorbar labelled
  '6d site occupation' spans the values from all sub-images.
  threshold/greater: optional rat filter, passed to getCoordsAndRats.
  scale/units: optional units per pixel and axis label.
  Returns the Figure.  Raises ValueError when there is nothing to plot.
  """
  parents=m.ParentImage.query.all()

  if not parentfnames:
    parentfnames=sorted(p.filename for p in parents)
  if not parentfnames:
    raise ValueError('no parent images to plot')

  # Named `fig` so the Python 2 comprehension variables below cannot
  # overwrite it.
  fig=plt.figure(figsize=(4*len(parentfnames)+4,6))

  imaxes=xplotlayout(fig,len(parentfnames),bottom=0.23,xstart=0.1,xstop=0.9,spacing=0.4/(len(parentfnames)))
  cbardim=0.9
  cbaraxes=xplotlayout(fig,1,xstart=0.04,xstop=0.975,bottom=0.1,top=0.125,spacing=0.1,plotw=[cbardim])

  def occupation(rats):
    # Empirical rat -> '6d site occupation' conversion (see colorbar label);
    # NOTE(review): origin of these constants is not recorded here.
    return .5185788*np.log((rats+.27102934)/.30967874)

  scores=getCoordsAndRats(None,threshold,greater)
  scores[:,2]=occupation(scores[:,2])
  valmin=np.min(scores[:,2]);valmax=np.max(scores[:,2])
  print('Average occupation:  %s'%np.average(scores[:,2]))
  valnorm=mpl.colors.Normalize(vmin=valmin,vmax=valmax)
  cb2 = mpl.colorbar.ColorbarBase(cbaraxes[0], cmap=mpl.cm.hot,
                                   norm=valnorm,
                                   orientation='horizontal')
  cb2.set_label('6d site occupation')

  # Parents in the order given by parentfnames; unknown names are skipped.
  ordered=[parent for fname in parentfnames
           for parent in parents if parent.filename==fname]
  for rect,parent in zip(imaxes,ordered):
    a=plt.axes(rect)
    scores=getCoordsAndRats(parent.filename,threshold,greater)
    scores[:,2]=occupation(scores[:,2])
    if not scale:
      plt.imshow(parent.data)
    else:
      # NOTE(review): shape[0] sets the x extent to stay aligned with the
      # SubImage x/y scatter below; confirm the x/y axis convention first.
      w,h=parent.data.shape
      print('%s %s'%(w,h))
      plt.imshow(parent.data,extent=[0,w*scale,0,h*scale])
    plt.gray()
    pname=parent.filename
    if len(pname)>24:
      pname=pname[:3]+'...'+pname[21:30]+'...'
    a.set_title('%s'%pname)
    if scale:
      scores[:,:2]*=scale
      label=units
    else:
      label='Pixels'
    scplot=plt.scatter(scores[:,0],scores[:,1],c=scores[:,2])
    plt.xlabel(label)
    plt.ylabel(label)
    plt.hot()
    scplot.set_clim(valmin,valmax)
  return fig

def spatialPlot(MDAtype,factorID,parentfnames=None,threshold=False,
                greater=True,scale=None,units=None,rat=False):
  """Show a factor image next to its scores scattered over each parent image.

  MDAtype: 'PCA' or 'ICA'; anything else raises ValueError.
  factorID: database id of the factor.
  parentfnames: parent image filenames to show (default: every parent of
    the factor's ParentSet, sorted by filename).
  threshold/greater: optional score filter, passed to the query helpers.
  scale/units: optional units per pixel and axis label.
  rat: when true, colour each parent panel by getCoordsAndRats' rat column
    instead of the factor scores (the colorbar still uses score limits).
  Returns the Figure.
  """
  if MDAtype=='PCA':
    vec=m.PCA_Factor.query.filter_by(id=factorID).one()
  elif MDAtype=='ICA':
    vec=m.ICA_Factor.query.filter_by(id=factorID).one()
  else:
    raise ValueError("MDAtype must be 'PCA' or 'ICA', got %r"%(MDAtype,))
  # Sorted by filename (the bare name ParentImage is undefined here).
  parents=sorted(m.ParentSet.query.filter_by(id=vec.parentset_id).one().parents,
                 key=lambda parent: parent.filename)

  if not parentfnames:
    parentfnames=sorted(parent.filename for parent in parents)

  nplots=len(parentfnames)+1
  # Named `fig` so the Python 2 comprehension variables below cannot
  # overwrite it.
  fig=plt.figure(figsize=(4*len(parentfnames)+4,6))

  imaxes=xplotlayout(fig,nplots,bottom=0.2,xstart=0.1,xstop=0.9,spacing=0.4/nplots)
  cbardim=2./nplots
  cbaraxes=xplotlayout(fig,2,xstart=0.04,xstop=0.975,bottom=0.1,top=0.15,spacing=0.1,plotw=[cbardim,2-cbardim])

  cmap=mpl.cm.hot
  a=plt.axes(imaxes[0])
  vdata=vec.factor
  if not scale:
    plt.imshow(vdata)
    plt.xlabel('Pixels')
    plt.ylabel('Pixels')
  else:
    # imshow puts columns (shape[1]) on x and rows (shape[0]) on y.
    plt.imshow(vdata,
          extent=[0,scale*vdata.shape[1],0,scale*vdata.shape[0]])
    plt.xlabel(units)
    plt.ylabel(units)
  plt.hot()
  a.set_title('Factor %i'%factorID)

  scores=getCoordsAndScores(MDAtype,factorID,None,threshold,greater)
  print(scores.shape[0])
  valmin=np.min(scores[:,2]);valmax=np.max(scores[:,2])

  vecnorm=mpl.colors.Normalize(vmin=np.min(vdata),vmax=np.max(vdata))
  cb1 = mpl.colorbar.ColorbarBase(cbaraxes[0], cmap=cmap,
                                   norm=vecnorm,
                                   orientation='horizontal')
  cb1.set_label('Factor intensity')

  valnorm=mpl.colors.Normalize(vmin=valmin,vmax=valmax)
  cb2 = mpl.colorbar.ColorbarBase(cbaraxes[1], cmap=cmap,
                                   norm=valnorm,
                                   orientation='horizontal')
  cb2.set_label('Scores')

  # Parents in the order given by parentfnames; unknown names are skipped.
  ordered=[parent for fname in parentfnames
           for parent in parents if parent.filename==fname]
  for rect,parent in zip(imaxes[1:],ordered):
    a=plt.axes(rect)
    if rat:
      # getCoordsAndRats takes (parentfname, threshold, greater) -- no MDA
      # type argument.
      scores=getCoordsAndRats(parent.filename,threshold,greater)
    else:
      scores=getCoordsAndScores(MDAtype,factorID,parent.filename,threshold,greater)
    if not scale:
      plt.imshow(parent.data)
    else:
      # NOTE(review): shape[0] sets the x extent to stay aligned with the
      # SubImage x/y scatter below; confirm the x/y axis convention first.
      w,h=parent.data.shape
      print('%s %s'%(w,h))
      plt.imshow(parent.data,extent=[0,w*scale,0,h*scale])
    plt.gray()
    pname=parent.filename
    if len(pname)>24:
      pname=pname[:3]+'...'+pname[21:30]+'...'
    a.set_title('%s'%pname)
    if scale:
      scores[:,:2]*=scale
      label=units
    else:
      label='Pixels'
    scplot=plt.scatter(scores[:,0],scores[:,1],c=scores[:,2])
    plt.xlabel(label)
    plt.ylabel(label)
    plt.hot()
    scplot.set_clim(valmin,valmax)
  return fig
  
# NOTE: the module-level string below only holds disabled code
# (scoreThresholdPlot, factorPlot, scoreScorePlot); it is never executed.
# It predates the current helpers -- e.g. it reads score.subimage from the
# numpy array that getPCAscores/getICAscores return -- so review before
# re-enabling any of it.
'''
def scoreThresholdPlot(MDAtype,factorID,threshold,greater=True):
  """
  This function figures out which subimages match your threshold criterion, and 
  pastes them onto a blank window that is the same size as the original image 
  from which the subimages were cropped.
  """
  if MDAtype is 'PCA':
    scores=getPCAscores(factorID,threshold,greater)
  elif MDAtype is 'ICA':
    scores=getICAscores(factorID,threshold,greater)
  pasties={}  ###########
  c=0
  for score in scores:
    pimagename=score.subimage.parent.filename
    if pimagename not in pasties.keys():
      #duplicate array into RGB for image pasting
      pimagedata=score.subimage.parent.data
      pasties['%s'%pimagename]=np.zeros(pimagedata.shape,dtype="uint8")
    simage=score.subimage.data
    sx=score.subimage.x
    sy=score.subimage.y
    print c
    c=c+1    
    pasties['%s'%pimagename][sx:sx+simage.shape[0],sy:sy+simage.shape[1]]=simage
  for fname,data in pasties.items(): #enumerate dictionary
    image=Image.fromarray(data)
    if greater:
      image.save('%s_scoreID%03i_gt%04i.png'%(fname,factorID,threshold))
    else:
      image.save('%s_scoreID%03i_lt%04i.png'%(fname,factorID,threshold))
    
def factorPlot(rotationNode,factors):
  """rotation node should be a rotated node instance (from the RotationNode
  module of MDP).  Can be either an OrthoRotationNode or ObliqueRotationNode.
  factors should be a list of the factors that you want to show.
  This plots the original factor on top of the rotated one, with a column
  for each factor.
  """
  #w,h=matplotlib.figure.figaspect(len(factors)/2)
  #fig=plt.figure(figsize=(w,h))
  fig=plt.figure()
  
  
  numEigs = len(factors)
  for eig in range(numEigs):
    upperPlotPosition=(2,numEigs,eig+1)
    a=fig.add_subplot(*upperPlotPosition)
    a.set_axis_off()
    # cut out desired factor (column)
    data=rotationNode.unrotatedArray[:,factors[eig]]
    # reshape it to its original square shape
    data=data.reshape((np.sqrt(len(data)),-1))
    # plot it
    plt.imshow(data)
    a.set_title('Eigenimg %s'%factors[eig])
    plt.colorbar()
    
    lowerPlotPosition=(2,numEigs,eig+numEigs+1)
    b=fig.add_subplot(*lowerPlotPosition)
    b.set_axis_off()
    # cut out desired factor (column)
    data=rotationNode.v[:,factors[eig]]
    # reshape it to its original square shape
    data=data.reshape((np.sqrt(len(data)),-1))
    # plot it
    plt.imshow(data)
    b.set_title('Eigenimg %s'%factors[eig])
    plt.colorbar()
  return fig

def scoreScorePlot(data,pcanode,rotnode,factors):
  """rotation node should be a rotated node instance (from the RotationNode
  module of MDP).  Can be either an OrthoRotationNode or ObliqueRotationNode.
  factors should be a tuple with two of the factors that you want to compare.
  This plots the scores for each of the original images, in terms of the
  original PCA-derived space, and in the rotated space.
  """
  pcaScores=pcanode(data)
  pcaScores=pcaScores-pcaScores.min(axis=0)
  pcaScores=pcaScores/pcaScores.max(axis=0)
  rotScores=rotnode(data)
  rotScores=rotScores-rotScores.min(axis=0)
  rotScores=rotScores/rotScores.max(axis=0)
  fig=plt.figure()
  unrotated=plt.plot(pcaScores[:,0],pcaScores[:,1],'ro')
  rotated=plt.plot(rotScores[:,0],rotScores[:,1],'bs')
  plt.figlegend((unrotated,rotated),('Original PCA scores','Rotated scores'),'upper right')
  return fig
  
'''
