#!/usr/bin/env python
"""
Copyright (C) 2011 by Michael Sarahan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import utils.fread
import utils.msa_rpy as msa
import utils.db_model as m
from utils.db_model import *
from glob import glob
import os
import numpy as np
import sys
import Image

# Uncomment for debugging, together with the `ipshell()` lines further below.
# (IPython.Shell is the pre-0.11 IPython API; newer IPython uses `from IPython import embed`.)
# from IPython.Shell import IPShellEmbed
# ipshell = IPShellEmbed()

def init(clear=True):
    """Prepare the working directory and initialize the database.

    When *clear* is true, the output of earlier runs (``ali/*``, ``spots*``,
    ``listparticles*``) is removed through the shell before the database is
    set up.  *clear* is also forwarded to ``m.initializeDB``.
    """
    if clear:
        for command in ('rm ali/*', 'rm spots*', 'rm listparticles*'):
            os.system(command)
    m.initializeDB(clear=clear, debug=False)

def run_crops(infiles):
    """Load already-cropped particle images into the database and analyze them.

    Each path in *infiles* becomes one ``m.SubImage`` row.  There is no real
    parent micrograph, so the first file doubles as a stand-in
    ``m.ParentImage``.  A single-parent ``m.ParentSet`` and its (empty)
    PCA/ICA factor sets are created, then ``process`` is run with the number
    of factors auto-determined.

    Parameters
    ----------
    infiles : list of str
        Paths of cropped images readable by PIL ``Image.open``; all images
        are assumed to share the same size (TODO confirm with callers).

    Raises
    ------
    ValueError
        If *infiles* is empty.
    """
    if not infiles:
        raise ValueError("run_crops: no input images given")
    cwd = os.getcwd()
    # Fake parent image - just the first cropped image.
    p = m.ParentImage(location=cwd, filename=infiles[0],
                      data=np.asarray(Image.open(infiles[0])))
    # Create the tables that process() fills in later; process() looks the
    # factor sets up again by ``pset``, so no local references are needed.
    pset = m.ParentSet(parents=[p], nparents=1)
    m.PCA_Factors(pset=pset)
    m.ICA_Factors(pset=pset)
    # One database entry per cropped image.  No crop coordinates are known
    # here, so x, y and number are all just the list index.  Keeping our own
    # list (instead of querying every SubImage in the DB) restricts the
    # analysis to these files, in the order given.
    subimages = []
    for i, infile in enumerate(infiles):
        subimages.append(m.SubImage(x=i, y=i, number=i, parent=p,
                                    data=np.asarray(Image.open(infile), dtype=np.float32)))
    # A single commit for the whole batch instead of one per image.
    m.session.commit()
    process(subimages, pset)
    
def run(infiles,template=None,crop=True,nparticles=500,nfactors=15,maskRadius=None,neighborhood=None):
    """Build (or reuse) a parent set, gather its sub-images, and analyze them.

    Input modes:

    * ``infiles == 'all'``: every ``m.ParentImage`` in the database forms one
      parent set.  An existing set with the same first filename and parent
      count is reused.  All ``m.SubImage`` rows are analyzed.
    * ``crop=False``: *infiles* are filenames of parent images already in the
      database.  A new single-parent set is created per file, and that
      file's existing sub-images are used.
    * ``crop=True`` (default): *infiles* are micrographs.  Each one is stored
      as a parent image if it is new, and its single-parent set is reused or
      created.  Particles are then picked with
      ``utils.SpiderControl.spiderCrop`` using *template*, *nparticles*
      (``numpeaks``), *maskRadius* and *neighborhood*, and the crops are
      added to the DB.  *infiles* is sorted in place.

    The mean of the selected sub-images is stored as ``pset.avgImage``, and
    ``process`` is then run with *nfactors* factors.

    NOTE(review): in both list modes a parent set is prepared for every file,
    but only the last file's set and sub-images reach ``process``.  Confirm
    whether per-file processing was intended.

    Raises
    ------
    ValueError
        If 'all' mode finds no parent images, or there are no sub-images to
        process (e.g. *infiles* is empty).
    """
    cwd = os.getcwd()
    pset = None
    subimages = []
    # Equality, not identity: ``is`` on strings only works by interning luck.
    if infiles == 'all':
        parents = m.ParentImage.query.all()
        if not parents:
            raise ValueError("run: no parent images in the database")
        nparents = len(parents)
        existing = m.ParentSet.query.filter(m.ParentSet.parents.any(
            m.ParentImage.filename == parents[0].filename
            )).filter_by(nparents=nparents).all()
        if existing:
            pset = existing[0]
        else:
            pset = m.ParentSet(parents=parents, nparents=nparents)
            m.PCA_Factors(pset=pset)
            m.ICA_Factors(pset=pset)
        subimages = m.SubImage.query.all()
    elif not crop:
        for infile in infiles:
            p = m.ParentImage.query.filter_by(filename=infile).one()
            pset = m.ParentSet(parents=[p], nparents=1)
            m.PCA_Factors(pset=pset)
            m.ICA_Factors(pset=pset)
            subimages = m.SubImage.query.filter_by(parent=p).all()
    else:
        # Imported lazily: only the cropping path needs SPIDER.
        import utils.SpiderControl as sc
        infiles.sort()
        nparents = 1
        for infile in infiles:
            if not m.ParentImage.query.filter_by(filename=infile).all():
                p = m.ParentImage(location=cwd, filename=infile,
                                  data=np.asarray(Image.open(infile)))
            else:
                p = m.ParentImage.query.filter_by(filename=infile).one()
            existing = m.ParentSet.query.filter(
                m.ParentSet.parents.contains(p)).filter_by(nparents=nparents).all()
            if existing:
                # Use the stored set itself.  Previously the query's *list*
                # stayed in ``pset``, so ``pset.avgImage`` below failed.
                pset = existing[0]
            else:
                pset = m.ParentSet(parents=[p], nparents=nparents)
                m.PCA_Factors(pset=pset)
                m.ICA_Factors(pset=pset)
            sc.spiderCrop(template, infile, numpeaks=nparticles,
                          maskRadius=maskRadius, neighborhood=neighborhood)
            sc.addSubImagesToDB(infile)
            print("Spider CA done for %s - results loaded into database" % infile)
            subimages = m.SubImage.query.filter_by(parent=p).all()
    if not subimages:
        raise ValueError("run: no sub-images to process")
    pset.avgImage = np.mean(np.array([subimage.data for subimage in subimages]), axis=0)
    process(subimages, pset, nfactors)
    
def process(subimages,pset,nFactors=None):
    """Run PCA and ICA on a set of sub-images and store factors and scores.

    Each sub-image is flattened into one column of a data matrix, which is
    passed to ``msa.pca`` and ``msa.ica``.  The first *nFactors* factors of
    each method are stored as ``m.PCA_Factor`` / ``m.ICA_Factor`` rows,
    numbered from 1, and reshaped back to the sub-image shape.  One
    ``m.PCA_Score`` / ``m.ICA_Score`` row is created per (sub-image, factor)
    pair.

    Parameters
    ----------
    subimages : sequence of m.SubImage
        Their ``.data`` arrays are assumed to share one shape (TODO confirm).
    pset : m.ParentSet
        Exactly one ``m.PCA_Factors`` and one ``m.ICA_Factors`` row must
        already reference it, because both are fetched with ``.one()``.
    nFactors : int, optional
        Number of factors to keep.  ``None``/0 means auto-determine from the
        PCA eigenvalues via ``msa.linfit``.  An explicit value is capped at
        ``len(subimages) - 1`` before that step, and the estimated value is
        capped at the same limit afterwards.

    Raises
    ------
    ValueError
        If *subimages* is empty.
    """
    if not subimages:
        raise ValueError("process: no sub-images to analyze")
    # One column per sub-image.
    arraysize = subimages[0].data.size
    data = np.zeros((arraysize, len(subimages)))
    for column, subimage in enumerate(subimages):
        data[:, column] = subimage.data.reshape((-1,))
    max_factors = data.shape[1] - 1
    # Explicit ``None`` check: ordering None against an int is Py2-only.
    if nFactors is not None and nFactors > max_factors:
        nFactors = max_factors
    # NOTE: per the original author this unpacking is not technically correct;
    # scores & factors returned by msa.pca should be swapped.
    pca_scores, pca_factors, pca_eigvalues = msa.pca(data)
    if not nFactors:
        n_evals = len(pca_eigvalues)
        nFactors = msa.linfit(pca_eigvalues, n_evals - 0.1 * n_evals)
        # Bound the estimate too; previously the cap ran only before this step.
        if nFactors > max_factors:
            nFactors = max_factors
    ica_factors, ica_scores = msa.ica(data, nFactors)
    # Factor images take the sub-images' own shape.  This is identical to the
    # old (dim, dim) reshape for square images and also works for non-square ones.
    factor_shape = subimages[0].data.shape

    pcaset = m.PCA_Factors.query.filter_by(pset=pset).one()
    icaset = m.ICA_Factors.query.filter_by(pset=pset).one()
    # ipshell()
    pcaevs = []
    icaevs = []
    for factor in xrange(nFactors):
        number = factor + 1  # factors are numbered from 1 in the database
        pcaevs.append(m.PCA_Factor(factor=pca_factors[:, factor].reshape(factor_shape),
                                   vset=pcaset, number=number))
        icaevs.append(m.ICA_Factor(factor=ica_factors[:, factor].reshape(factor_shape),
                                   vset=icaset, number=number))
    m.session.commit()

    total = len(subimages)
    for subindex, subimage in enumerate(subimages):
        for number, pcaev in enumerate(pcaevs):
            m.PCA_Score(subimage=subimage, score=pca_scores[number, subindex], factor=pcaev)
        for number, icaev in enumerate(icaevs):
            m.ICA_Score(subimage=subimage, score=ica_scores[number, subindex], factor=icaev)
        # Commit (and report progress) every 50 sub-images and after the last one.
        if subindex % 50 == 0 or subindex == total - 1:
            m.session.commit()
            print("%i of %i sub-image scores loaded into database" % (subindex + 1, total))
            
if __name__=='__main__':
    # Sorted so runs are reproducible: glob returns files in arbitrary order,
    # and run_crops() uses the first file as the stand-in parent image.
    # Collected before init() so an empty folder does not wipe earlier results.
    infiles=sorted(glob('test_images/*'))
    if not infiles:
        print("No images found in test_images/ - nothing to do.")
        sys.exit(1)
    init(clear=True)
    #For Hao - just a simple run on images that are already cropped.
    # Auto-determine number of factors.
    run_crops(infiles)
    """
    try:
        nparticles=int(sys.argv[1])
    except:
        print "Number of particles (argument 1) not specified or invalid.  Defaulting to 500."
        nparticles=500
    try:
        maskRadius=float(sys.argv[2])
    except:
        print "Mask radius (argument 2) not defined or invalid.  Defaulting to 0.8."
        maskRadius=0.8
    try:
        neighborhood=int(sys.argv[3])
    except:
        print "Neighbor exclusion distance (argument 3) not defined or invalid.  Defaulting to 1/10 of template size."
        neighborhood=None
    template=glob('template*.png')
    if len(template) > 1:
        print "More than one template found:"
        for tmp in template:
            print tmp
        print "Only one file starting with 'template' can be in this folder."
        sys.exit()
    else:
        template=template[0]
    print nparticles,maskRadius
    run(infiles,template,crop=True,nparticles=nparticles,maskRadius=maskRadius,neighborhood=neighborhood)
    """
    #infiles=glob('022*tif')
    #template='template2.png'
    #run(infiles,template)
    #run('all')
    
    
