# -*- coding: utf-8 -*-
"""
Copyright (C) 2011 by Michael Sarahan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import matplotlib.pyplot as plt
import numpy as np
import pickle

class Project(object):
    """
    The topmost data structure.  Contains instances for each of the later bits
    of analysis, and a record of what analysis has been done.  Implements
    methods for saving and loading data sets.
    """
    def __init__(self, name='New Project', datatype=None):
        """
        Store the project name and data type.  Both values are kept as given;
        this class does not interpret them.
        """
        self.name = name
        self.datatype = datatype

    def save(self, filename):
        """
        Pickle this project.

        filename is either a path, or an already-open binary file object
        (anything with a write method), which is written to directly and left
        open for the caller to close.
        """
        if hasattr(filename, 'write'):
            pickle.dump(self, filename)
            return
        # pickle.dump needs a file object, not a path, and the stream must be
        # binary.
        with open(filename, 'wb') as f:
            pickle.dump(self, f)

    def load(self, filename):
        """
        Unpickle a project and return it.

        filename is either a path, or an already-open binary file object
        (anything with a read method).  The instance this method is called on
        is not modified - use the return value.

        WARNING: unpickling can execute arbitrary code.  Only load files that
        come from a trusted source.
        """
        if hasattr(filename, 'read'):
            return pickle.load(filename)
        with open(filename, 'rb') as f:
            return pickle.load(f)
        

class Data(object):
    """
    Base class for any sort of data.  dims variable represents the dimensions
    relevant to the data.  This is handled differently by each subclass.  The
    methods of this generic class are similar to those for the spectra class.
    I'm assuming that most data is something like signal vs time - as long as
    that is true, this generic representation works well.

    Attributes:
        data - 2D array; specPlot uses column 0 as the x axis and column ID
               as the y values.
        dims - dimensions of the data (meaning is up to the subclass).
        cal  - calibration dictionary.  specPlot looks for an 'energy' entry,
               itself a dictionary with 'value' and 'units' keys.
    """
    def __init__(self, data=None, dims=None, cal=None):
        self.data = data
        self.dims = dims
        # Build a fresh dictionary per instance.  A literal {} default would
        # be shared by every Data object created without a cal argument.
        self.cal = {} if cal is None else cal

    def show(self, IDs, overlay=True):
        """
        Plot one or several columns of the data.

        IDs is a single column index, or a list of them.  With a list,
        overlay=True (the default) draws every spectrum on the current axes;
        overlay=False creates a new figure with one subplot per spectrum,
        three per row.
        """
        # If only given one ID, plot that one spectrum
        if not isinstance(IDs, list):
            self.specPlot(IDs, label=IDs)
            return

        if overlay:
            # Plot all spectra overlaid on one plot.
            for ID in IDs:
                self.specPlot(ID, label=ID)
            return

        # One subplot per spectrum, three per row.  The subplot grid needs an
        # integer row count, and at least one row so the figure has height.
        rows = max(1, int(np.ceil(len(IDs) / 3.)))
        fig = plt.figure(figsize=(8, 3 * rows))
        for i, ID in enumerate(IDs):
            sp = fig.add_subplot(rows, 3, i + 1)
            self.specPlot(ID, label=ID, subplot=sp)

    def read(self, f, offset):
        """
        Load the data array from text file f with numpy.loadtxt, skipping the
        first offset lines.
        """
        self.data = np.loadtxt(f, skiprows=offset)

    def related(self, ID=None, loc=None, d=5):
        """
        Find related factors, either by ID or by proximity to a user-defined
        point on the currently plotted factor.  d is the distance in points
        from the user-defined point to define the neighborhood to search in.

        NOT PRESENTLY WORKING!!!  Always returns None.
        """
        # Compare against None so that a location of 0 is still a location.
        if loc is not None:
            # Clamp the lower bounds at 0: a negative slice start would wrap
            # around to the far end of the array.
            if self.dims:
                # Is this an image?  If so, look in the 2D region nearby.
                nb = self.data[max(loc[0] - d, 0):loc[0] + d,
                               max(loc[1] - d, 0):loc[1] + d]
            else:
                nb = self.data[max(loc - d, 0):loc + d]
            # TODO: search the neighborhood nb; it is not used yet.
            alike = None
        else:
            # TODO: find relatedness by examining correlation coefficients
            alike = None  # top n most correlated factors
        return alike

    def specPlot(self, ID, label=None, fig=None, subplot=None):
        """
        Basic plots for spectra.

        Plots column ID of the data against column 0.  If the cal dictionary
        has an 'energy' entry, column 0 is multiplied by its 'value' and the
        x axis is labeled with its 'units'.

        fig     - figure to draw on; defaults to the current figure.
        subplot - subplot to draw in, passed on to fig.add_subplot.
        """
        if fig is not None:
            # Figure.number is an attribute, not a method.
            plt.figure(fig.number)
        else:
            fig = plt.gcf()
        if subplot is not None:
            fig.add_subplot(subplot)

        xdata = self.data[:, 0]
        try:
            cal = self.cal['energy']
            caldata = xdata * cal['value']
        except (KeyError, TypeError):
            # energy calibration is not defined (or not usable).
            cal = {'units': None}
            caldata = xdata

        plt.plot(caldata, self.data[:, ID], label=label)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel('%s' % cal.get('units'), fontsize=24)
        plt.ylabel('Arbitrary Units', fontsize=24)
        plt.legend()

class Images(Data):
    """
    A stack of images.  Each image is flattened into one column of the data
    array, so data has shape (pixels per image, number of images).
    """
    def read(self, flist):
        """
        Load every image file in flist into the data array, one flattened
        image per column.

        The legacy OpenCV 'cv' module is tried first (images are loaded as
        grayscale); if that does not work, PIL is used instead.  The first
        file determines the number of pixels per image.
        """
        try:
            import cv

            def cv2array(im):
                depth2dtype = {
                    cv.IPL_DEPTH_8U: 'uint8',
                    cv.IPL_DEPTH_8S: 'int8',
                    cv.IPL_DEPTH_16U: 'uint16',
                    cv.IPL_DEPTH_16S: 'int16',
                    cv.IPL_DEPTH_32S: 'int32',
                    cv.IPL_DEPTH_32F: 'float32',
                    cv.IPL_DEPTH_64F: 'float64',
                    }
                # frombuffer is the supported way to decode raw bytes.
                a = np.frombuffer(
                    im.tostring(),
                    dtype=depth2dtype[im.depth],
                    count=im.width * im.height * im.nChannels)
                return a.reshape((im.height, im.width, im.nChannels))

            imsample = cv.LoadImage(flist[0])
            colLength = imsample.width * imsample.height

            def load(path):
                # 0 -> load as grayscale
                return cv2array(cv.LoadImage(path, 0))

        except Exception:
            # Either cv is not installed, or it could not read the first
            # file.  cv's own exception types cannot be named when the import
            # failed, so fall back to PIL on any ordinary error; a problem
            # with the file itself will resurface from the PIL path below.
            try:
                from PIL import Image
            except ImportError:
                # very old PIL installs expose a top-level Image module
                import Image
            imsample = Image.open(flist[0])
            colLength = imsample.size[0] * imsample.size[1]

            def load(path):
                return np.array(Image.open(path))

        data_array = np.zeros((colLength, len(flist)))
        for i, image in enumerate(flist):
            # Whole-column assignment instead of a per-pixel Python loop.
            data_array[:, i] = np.array(load(image)).reshape(colLength)

        self.data = data_array

    def show(self, IDs, nPerRow=3):
        """
        Plot one image, or a list of images laid out nPerRow to a row on a
        new figure.
        """
        # Did the user pass in a list of IDs?
        if isinstance(IDs, list):
            # Show a small figure for each image in the list.  The subplot
            # grid needs an integer row count, and at least one row.
            rows = max(1, int(np.ceil(len(IDs) / float(nPerRow))))
            fig = plt.figure(figsize=(4 * nPerRow, 3 * rows))
            # Make a new subplot for each image
            for i, ID in enumerate(IDs):
                sp = fig.add_subplot(rows, nPerRow, i + 1)
                self.imgPlot(ID, label=ID, subplot=sp)
        else:
            # If only given one ID, plot that one image
            self.imgPlot(IDs, label=IDs)

    def _getExtent(self):
        """
        Returns the calibrated image coordinates, if calibration exists.
        Returns None when the 'space' calibration or dims are missing.
        """
        try:
            factor = self.cal['space']['value']
            return (0, self.dims[0] * factor, 0, self.dims[1] * factor)
        except (AttributeError, KeyError, TypeError, IndexError):
            return None

    def imgPlot(self, ID, label=None, fig=None, subplot=None):
        """
        Plots a single image.  Adds calibration data if the cal property of the
        object exists.  The cal property should be a dictionary with a 'space'
        key.  The value for that key should be another dictionary with two keys:

        'value', which is the number of units per pixel. (Int or float)
        'units', which is the units with which axes are labeled. (String)

        Column ID of the data is reshaped to dims for display.  Returns the
        figure that was drawn on.
        """
        # Allow user to specify a figure, if the front one is not the one to
        # plot to.
        if fig is not None:
            # Figure.number is an attribute, not a method.
            plt.figure(fig.number)
        # Or, just do the front one.  This creates a new figure if one doesn't
        # exist.
        else:
            fig = plt.gcf()

        # Allow user to specify a subplot on the current figure.  It can either
        # be a subplot instance, or it can be the 3-digit numeric specification
        # Note that although the numeric representations 333 and 3,3,3 are
        # equivalent, specifying one, then the other will clear that subplot
        # and start from a clean slate!
        if subplot is not None:
            fig.add_subplot(subplot)

        # Check if we have a spatial calibration for the image
        try:
            cal = self.cal['space']
        except (KeyError, TypeError):
            # spatial calibration is not defined - label the axes in pixels.
            cal = {'units': 'Pixels'}

        # If image label is not given, just title it by its ID number.
        if label is None:
            label = ID

        # Plot the data.  Changing interpolation to either 'bilinear' or
        # 'bicubic' can make it look smoother.
        plt.imshow(self.data[:, ID].reshape(self.dims),
                   interpolation='nearest', extent=self._getExtent())
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel(cal['units'], fontsize=18)
        plt.ylabel(cal['units'], fontsize=18)
        plt.title('%s' % label, fontsize=24)

        # return the figure, now with our image plot on it.
        return fig

class Spectra(Data):
    """
    A set of spectra sharing one energy axis.  Column 0 of data is the energy
    axis; every further column is one spectrum.
    """
    def read(self, flist, offset=None):
        """
        Load spectra from text files with numpy.loadtxt.

        flist  - a list (or tuple) of files with one spectrum each, or a
                 single file holding one spectrum per column.  With several
                 files, column 0 of the first file becomes the energy axis
                 and column 1 of every file becomes one spectrum.
        offset - number of header lines to skip in each file (None -> 0).

        Raises ValueError if flist is an empty list.
        """
        # loadtxt wants an integer number of rows to skip.
        skip = 0 if offset is None else offset

        # Only a real sequence of files counts as "several files"; calling
        # len() on a path string would count its characters.
        if not isinstance(flist, (list, tuple)):
            # read the file - it is a csv or equivalent with one spectrum per
            # column.
            self.data = np.loadtxt(flist, skiprows=skip)
            return

        if len(flist) == 0:
            raise ValueError('flist must contain at least one file')

        first = np.loadtxt(flist[0], skiprows=skip)
        if len(flist) == 1:
            self.data = first
            return

        # make array of shape (spectrum length, nSpectra + 1)
        res = np.zeros((first.shape[0], len(flist) + 1))
        # the first file supplies the energy axis and the first spectrum
        res[:, 0] = first[:, 0]
        res[:, 1] = first[:, 1]
        # read only the second column of each remaining spectrum - the first
        # is energy
        for s, fname in enumerate(flist[1:], 2):
            res[:, s] = np.loadtxt(fname, skiprows=skip)[:, 1]
        self.data = res
        
class SpectrumImage(Images, Spectra):
    """
    Spectrum image data, plotted both as spectra and as images.
    """
    def read(self, filename):
        """
        Load a spectrum image from a text file.

        The second line of the file holds the comma-separated dimensions; the
        numeric data starts after the first three lines.
        """
        # Read just the two header lines that are needed, and make sure the
        # file is closed again.
        with open(filename, 'r') as f:
            f.readline()
            dimline = f.readline()
        # Store the dimensions as integers: imgPlot reshapes the data with
        # them, and reshape does not accept floats.
        self.dims = [int(float(dim)) for dim in dimline.split(",")]
        self.data = np.loadtxt(filename, skiprows=3)

    def show(self, IDs):
        """
        Plot the spectrum and the image for one ID, or for a list of IDs, on a
        new figure.  With a list, all spectra share the top row and the images
        are laid out three per row below it.
        """
        # For SI's, calibration is a dictionary with both the energy
        # calibration for the spectra and the spatial calibration for the array
        # of spectra.
        fig = plt.figure(figsize=(10, 5))
        if isinstance(IDs, list):
            # One row for the spectra plus enough rows of three images.  The
            # subplot grid needs an integer row count.
            rows = 1 + int(np.ceil(len(IDs) / 3.))
            sp_top = fig.add_subplot(rows, 1, 1)
            for i, ID in enumerate(IDs):
                self.specPlot(ID, label=ID, subplot=sp_top)
                # Cells 1-3 are covered by the spectrum row, so the images
                # start at cell 4.
                sp = fig.add_subplot(rows, 3, 4 + i)
                self.imgPlot(ID, subplot=sp)
                plt.colorbar(orientation='horizontal')
        else:
            sp_top = fig.add_subplot(2, 1, 1)
            self.specPlot(IDs, label=IDs, subplot=sp_top)
            sp_bottom = fig.add_subplot(2, 1, 2)
            # IDs is a single ID here; there is no loop index to use.
            self.imgPlot(IDs, subplot=sp_bottom)
            plt.colorbar(orientation='horizontal')
            
'''
class FactorImages(Images):
    """
    Class representing factor data derived from image data (stacks of images).
    
    Factors are images themselves - each of the original images is made up by
    adding together score-weighted factor images.
    """
    
class FactorSpectra(Factors, Images, Spectra):
    """
    Class representing factor data derived from spectral data (stacks of
    spectra, either from linescans or simply several points).
    
    Factors are spectra themselves - each of the original spectra is made up by
    adding together score-weighted factor spectra.
    """    
        
class FactorSI(Factors, SpectrumImages):
    """
    Class representing factor data derived from spectrum image data (2D arrays
    of spectra).
    
    Factors are spectra - each of the pixels in the SI represents a spectrum
    that is made up by adding together score-weighted factor spectra.
    """
'''


        
        
