#!/usr/bin/env python
# -*- coding: iso-8859-1 -*-
"""
Define peak positions (manually for now)
locate matches
    use cross correlation maxima, not peak maxima
compute characteristics for each peak:
    1. compute average peak position
    2. deviation from average
    3. image moments - orientation
    4. image moments - eccentricity
    5. fit Gaussian for peak height?
MSA of characteristics:
    1. one column per image; each row is one peak characteristic

Copyright (C) 2011 by Michael Sarahan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import cv
import numpy as np
from scipy import optimize
from scipy.signal import medfilt
try:
    import mda_rpy as msa
except:
    import mdapyper as msa

def gaussian(height, center_x, center_y, width_x, width_y):
    """Build an axis-aligned 2D Gaussian evaluator.

    The returned callable ``f(x, y)`` evaluates
    ``height * exp(-(((center_x - x) / width_x)**2 + ((center_y - y) / width_y)**2) / 2)``
    and works element-wise on scalars or numpy arrays.
    """
    # Force float widths so integer arguments never trigger integer division.
    wx = float(width_x)
    wy = float(width_y)

    def _evaluate(x, y):
        dx = (center_x - x) / wx
        dy = (center_y - y) / wy
        return height * np.exp(-(dx ** 2 + dy ** 2) / 2)

    return _evaluate

def moments(data):
    """Returns (height, x, y, width_x, width_y)
    the gaussian parameters of a 2D distribution by calculating its
    moments.

    ``x``/``width_x`` refer to axis 0 (rows) and ``y``/``width_y`` to axis 1
    (columns), the same order as ``np.indices(data.shape)``.  ``x`` and ``y``
    are the intensity-weighted centroid.  Each width is the weighted standard
    deviation of the 1D profile through the centroid, measured about the
    centroid coordinate of that same axis.  ``height`` is the maximum value.
    """
    # Float copy: keeps the centroid divisions true divisions for integer images.
    data = np.asarray(data, dtype=np.float64)
    total = data.sum()
    rows, cols = np.indices(data.shape)
    x = (rows * data).sum() / total
    y = (cols * data).sum() / total
    # Profile along axis 0 (through column y): its spread is about the row centroid x.
    col = data[:, int(y)]
    width_x = np.sqrt(np.abs((np.arange(col.size) - x) ** 2 * col).sum() / col.sum())
    # Profile along axis 1 (through row x): its spread is about the column centroid y.
    row = data[int(x), :]
    width_y = np.sqrt(np.abs((np.arange(row.size) - y) ** 2 * row).sum() / row.sum())
    height = data.max()
    return height, x, y, width_x, width_y

def fitgaussian(data):
    """Returns (height, x, y, width_x, width_y)
    the gaussian parameters of a 2D distribution found by a fit.

    ``moments(data)`` supplies the starting guess.  ``scipy.optimize.leastsq``
    then minimises the pixel-wise residual between ``gaussian(*p)`` sampled on
    the ``np.indices(data.shape)`` grid and ``data``.  The fitted parameter
    array is returned.  The solver's status code is not checked.
    """
    params = moments(data)
    # The sampling grid is the same for every residual evaluation, so build it once.
    grid = np.indices(data.shape)

    def errorfunction(p):
        return np.ravel(gaussian(*p)(*grid) - data)

    p, _ier = optimize.leastsq(errorfunction, params)
    return p

def _subpixel_peak(x, y, j, n, peakgroup):
    """Refine the peak detected at index j by a Gaussian fit to nearby samples.

    Up to ``peakgroup`` samples are taken, starting at index ``j - n + 1``.
    Indices outside ``y`` are dropped.  A parabola is fitted to ln|y| in
    centred, scaled x coordinates.  Returns (location, height, width).
    For peakgroup < 7, location and height come from the largest sample
    instead of the fit.  The width uses O'Haver's 2.35703 FWHM constant and
    is NaN when the fitted parabola opens upwards.  With fewer than three
    usable samples no fit is possible, and the raw sample (x[j], y[j], 0)
    is returned.
    """
    start = j - n + 1
    idx = np.arange(start, start + peakgroup)
    idx = idx[(idx >= 0) & (idx < y.shape[0])]
    if idx.size < 3:
        return x[j], y[j], 0
    xx = x[idx].astype(np.float64)
    yy = y[idx].astype(np.float64)
    avg = np.average(xx)
    stdev = np.std(xx)
    xxf = (xx - avg) / stdev
    # Natural log: both the exp() reconstruction of the height and the
    # FWHM constant below assume ln, not log10.
    c3, c2, c1 = np.polyfit(xxf, np.log(np.abs(yy)), 2)
    width = abs(stdev * 2.35703 / (np.sqrt(2) * np.sqrt(-1 * c3)))
    if peakgroup < 7:
        # Too few points for a reliable fit: use the largest sample.
        height = np.max(yy)
        location = xx[np.argmin(np.abs(yy - height))]
    else:
        location = -((stdev * c2 / (2 * c3)) - avg)
        height = np.exp(c1 - c3 * (c2 / (2 * c3)) ** 2)
    return location, height, width


def one_dim_findpeaks(y, x=None, SlopeThreshold=0.5, AmpThreshold=None,
              smoothwidth=5, peakgroup=10, subpix=True):
    """
    Find peaks along a 1D line.

    Locates the positive peaks in noisy x-y data by looking for downward
    zero-crossings in the first derivative that exceed SlopeThreshold.
    Higher SlopeThreshold, AmpThreshold and smoothwidth values ignore
    smaller features.

    Parameters
    ----------
    y : array-like
        Signal values.
    x : array-like, optional
        Positions of the samples.  Defaults to ``arange(len(y))``.
    SlopeThreshold : float
        Minimum drop of the derivative across a zero-crossing.
    AmpThreshold : float, optional
        Minimum signal value at the crossing.  Defaults to ``0.1 * max(y)``.
    smoothwidth : int or None
        Median-filter kernel size applied before differentiating.  A falsy
        value disables smoothing.
    peakgroup : int
        Number of points around the top of the peak used for sub-pixel fitting.
    subpix : bool
        If True, refine each peak with a Gaussian fit.  Otherwise report
        the raw sample.

    Returns
    -------
    ndarray, shape (npeaks, 3)
        One row per peak: position (x-value), height, and width.  Width is 0
        when subpix is False.  Only peaks with 0 < position < x[-1] are kept.

    T. C. O'Haver, 1995.  Version 2  Last revised Oct 27, 2006
    Converted to Python by Michael Sarahan, Feb 2011.
    """
    y = np.asarray(y)
    # Explicit None checks: `not x` fails on arrays, and `not AmpThreshold`
    # would discard an explicit threshold of 0.
    if x is None:
        x = np.arange(len(y), dtype=np.int64)
    else:
        x = np.asarray(x)
    if AmpThreshold is None:
        AmpThreshold = 0.1 * y.max()
    peakgroup = int(np.round(peakgroup))
    if smoothwidth:
        d = np.gradient(medfilt(y, smoothwidth))
    else:
        d = np.gradient(y)
    # Offset from the detected index to the start of the fitting window.
    n = peakgroup // 2 + 1
    peaks = []
    for j in range(len(y) - 4):
        sj, sj1 = np.sign(d[j]), np.sign(d[j + 1])
        # Downward zero-crossing of the derivative that is steep and tall
        # enough.  A crossing through an exact zero at j+1 is detected on the
        # next iteration instead.
        if not (sj > sj1 and sj1 != 0 and
                d[j] - d[j + 1] > SlopeThreshold and y[j] > AmpThreshold):
            continue
        if subpix:
            location, height, width = _subpixel_peak(x, y, j, n, peakgroup)
        else:
            location, height, width = x[j], y[j], 0
        if location > 0 and not np.isnan(location) and location < x[-1]:
            peaks.append((location, height, width))
    return np.array(peaks, dtype=np.float64).reshape(-1, 3)

def two_dim_findpeaks(arr,subpixel=False,boxsize=10,smoothwidth=5):
    """Locate peaks in a 2D image.

    Code based on Dan Masiel's matlab functions.  The image is median
    filtered with kernel ``smoothwidth``.  A pixel counts as a peak when it
    is a 1D peak both along its row and along its column, as reported by
    ``one_dim_findpeaks`` with ``peakgroup=boxsize``.

    Parameters
    ----------
    arr : 2D ndarray
        Image.  Does not need to be square.
    subpixel : bool
        Refine positions with ``subpix_locate``.
    boxsize : int
        Peak window size.  Also sets the border exclusion margin.
    smoothwidth : int
        Median-filter kernel size.

    Returns
    -------
    ndarray, shape (npeaks, 2)
        Peak coordinates as (column, row) pairs -- note the order.  Points
        that are non-finite, or closer than ``boxsize // 2 + 1`` to the
        border, are dropped.
    """
    arr = medfilt(arr, smoothwidth)
    nrows, ncols = arr.shape
    # Mark peaks found along each row and along each column separately.
    row_map = np.zeros(arr.shape, dtype=bool)
    col_map = np.zeros(arr.shape, dtype=bool)
    for r in range(nrows):
        found = one_dim_findpeaks(arr[r], smoothwidth=None,
                                  peakgroup=boxsize,
                                  subpix=False)[:, 0]
        row_map[r, found.astype(int)] = True
    for c in range(ncols):
        found = one_dim_findpeaks(arr[:, c], smoothwidth=None,
                                  peakgroup=boxsize,
                                  subpix=False)[:, 0]
        col_map[found.astype(int), c] = True
    rows, cols = np.nonzero(row_map & col_map)
    coords = np.vstack((cols, rows)).T
    if subpixel:
        # subpix_locate indexes data[row, col]; hand it (row, col) pairs and
        # flip the refined result back to (col, row).
        coords = subpix_locate(arr, coords[:, ::-1], boxsize)[:, ::-1]
    half = boxsize // 2
    lower = half + 1
    # Column 0 holds column positions (bounded by ncols), column 1 row positions.
    upper = np.array([ncols, nrows]) - half - 1
    keep = (np.all(np.isfinite(coords), axis=1) &
            np.all(coords >= lower, axis=1) &
            np.all(coords <= upper, axis=1))
    return coords[keep]

def subpix_locate(data,points,boxsize,scale=None):
    """Refine peak positions to sub-pixel precision by centre of mass.

    For each (row, col) point, a window ``data[row-b:row+b, col-b:col+b]``
    is cut, where ``b = boxsize // 2``.  The point is then moved to the
    intensity centre of mass of that window.  Windows that fall off the
    image may be empty and yield NaN.

    Parameters
    ----------
    data : 2D ndarray
        Image.
    points : array-like, shape (npoints, 2)
        Starting positions in (row, col) order.
    boxsize : int
        Window size.
    scale : float, optional
        If given, the refined coordinates are multiplied by it.

    Returns
    -------
    float32 ndarray, shape (npoints, 2)
        Refined (row, col) positions.
    """
    from scipy.ndimage import center_of_mass
    points = np.asarray(points)
    half = boxsize // 2
    centers = np.array(points, dtype=np.float32)
    for i in range(points.shape[0]):
        pt = points[i]
        # Slice bounds must be integers even when the points are floats.
        window = data[int(pt[0] - half):int(pt[0] + half),
                      int(pt[1] - half):int(pt[1] + half)]
        com = np.array(center_of_mass(window))
        # Window index `half` corresponds to the starting point.
        offset = com[0] - half, com[1] - half
        centers[i] = np.array([pt[0] + offset[0], pt[1] + offset[1]])
    if scale:
        centers = centers * scale
    return centers
    
def stack_coords(stack,peakwidth,subpixel=False):
    """
    A rough location of all peaks in the image stack.  This can be fed into the
    best_match function with a list of specific peak locations to find the best
    matching peak location in each image.

    ``stack`` holds one flattened, square image per column.  The result has
    shape (npeaks_max, 2, nimages).  Slot [k, :, i] holds the k-th peak
    ``two_dim_findpeaks`` found in image i.  Unused slots are filled with
    10000.  At least 500 slots are allocated; more are added when an image
    has more peaks.
    """
    depth = stack.shape[1]
    # reshape needs an integer side length.
    imsize = int(round(np.sqrt(stack.shape[0])))
    per_image = [two_dim_findpeaks(stack[:, i].reshape((imsize, -1)),
                                   subpixel=subpixel, boxsize=peakwidth)
                 for i in range(depth)]
    nslots = max([500] + [c.shape[0] for c in per_image])
    coords = np.ones((nslots, 2, depth)) * 10000
    for i, ctmp in enumerate(per_image):
        coords[:ctmp.shape[0], :, i] = ctmp
    return coords
    
def best_match(arr,target,neighborhood=None):
    """
    Attempts to find the best match for target in array arr.  Assumes a 3D array,
    consisting of peak coordinates from each image.  returns an array with the
    best matching coordinates for each image.

    ``arr`` has shape (npeaks, 2, nimages) and ``target`` is a 2-vector.  For
    each image, the candidate with the smallest Euclidean distance to
    ``target`` is chosen.  When ``neighborhood`` is given, any offset
    component outside +/-neighborhood is replaced by 100 before the distances
    are computed, which penalises far candidates.  A warning is printed when
    the chosen peak still lies outside the neighborhood.

    Usage:
        best_match(arr, target)

    Returns
    -------
    ndarray, shape (nimages, 2)
    """
    target = np.asarray(target, dtype=np.float64)
    depth = arr.shape[2]
    # Offset of every candidate from the target: (npeaks, 2, nimages).
    arr_sub = arr - target.reshape(1, -1, 1)
    if neighborhood:
        arr_sub = np.ma.filled(
            np.ma.masked_outside(arr_sub, -neighborhood, neighborhood), 100)
    dist = np.sqrt(np.sum(np.power(arr_sub, 2), axis=1))
    matches = np.argmin(dist, axis=0)
    rlt = np.zeros((depth, 2))
    for i in range(depth):
        rlt[i] = arr[matches[i], :, i]
        # Judge the offset of the chosen peak, not its absolute coordinates.
        if neighborhood and np.any(np.abs(rlt[i] - target) > neighborhood):
            print("Warning! Didn't find a peak within your neighborhood! Watch for fishy peaks.")
    return rlt
  
def peakAttribs2(image,locations,peakwidth):
    """Measure peak attributes at each location in a single 2D image.

    For every (row, col) location, a ``peakwidth``-sized window is cut around
    it and summarised.

    Returns
    -------
    ndarray, shape (npeaks, 5)
        Columns: location (2 columns), fitted Gaussian height, orientation,
        and eccentricity.
    """
    rlt = np.zeros((locations.shape[0], 5))
    r = peakwidth // 2
    for loc in range(locations.shape[0]):
        c = locations[loc]
        # Slice bounds must be integers; locations may be sub-pixel floats.
        bxmin, bxmax = int(c[0] - r), int(c[0] + r)
        bymin, bymax = int(c[1] - r), int(c[1] + r)
        roi = image[bxmin:bxmax, bymin:bymax]
        ms = cv.Moments(cv.fromarray(roi))
        rlt[loc, :2] = c
        rlt[loc, 2] = fitgaussian(roi)[0]
        rlt[loc, 3] = orientation(ms)
        rlt[loc, 4] = eccentricity(ms)
    return rlt
        
def peakAttribs(stack,locations,peakwidth,imcoords=None):
    """
    Given a stack of images and a list of locations and window sizes (defined by
    the peak width), measure the peak attributes of the peaks of interest in
    each image.  These attributes currently include the height, location of the
    peak and the relative difference in position of the peak from the average,
    peak orientation angle and eccentricity.

    Parameters
    ----------
    stack : ndarray, shape (npixels, nimages)
        One flattened square image per column.
    locations : ndarray, shape (npeaks, nimages, 2)
        Position of every peak in every image.
    peakwidth : int
        Size of the square window cut around each peak.
    imcoords : array-like, optional
        Extra per-image rows stacked on top of the result.  It is transposed
        first if its second dimension does not match the number of images.

    Returns
    -------
    ndarray, shape (7 * npeaks [+ rows of imcoords], nimages)
        For peak k, rows 7k .. 7k+6 hold: position (2 rows), deviation of
        that position from the mean over all peaks in the same image
        (2 rows), fitted Gaussian height, orientation, and eccentricity.
    """
    npeaks = locations.shape[0]
    nimages = stack.shape[1]
    rlt = np.zeros((7 * npeaks, nimages))
    r = peakwidth // 2
    # Images are stored flattened and assumed square.
    imsize = int(round(np.sqrt(stack.shape[0])))
    # NOTE(review): axis=0 averages over peaks within each image, not over
    # images for one peak -- confirm this is the intended reference position.
    diffc = locations - np.average(locations, axis=0)
    for i in range(nimages):
        image = stack[:, i].reshape((imsize, -1))
        for loc in range(npeaks):
            c = locations[loc, i, :]
            # Slice bounds must be integers; locations may be sub-pixel floats.
            roi = image[int(c[0] - r):int(c[0] + r),
                        int(c[1] - r):int(c[1] + r)]
            ms = cv.Moments(cv.fromarray(roi))
            base = 7 * loc
            rlt[base:base + 2, i] = c
            rlt[base + 2:base + 4, i] = diffc[loc, i, :]
            rlt[base + 4, i] = fitgaussian(roi)[0]
            rlt[base + 5, i] = orientation(ms)
            rlt[base + 6, i] = eccentricity(ms)
    if imcoords is not None:
        imcoords = np.asarray(imcoords)
        if imcoords.shape[1] != rlt.shape[1]:
            imcoords = imcoords.T
        rlt = np.vstack((imcoords, rlt))
    return rlt
            
def orientation(moments):
    """Orientation angle (radians) of an intensity distribution.

    ``moments`` is a legacy OpenCV ``cv.Moments`` object.  The normalised
    second central moments are taken from it, and
    ``0.5 * arctan2(2*mu11, mu20 - mu02)`` is returned.  Here ``mu20`` is
    ``GetCentralMoment(m, 0, 2)``, following this module's naming.
    """
    m00 = cv.GetCentralMoment(moments, 0, 0)
    mu11p = cv.GetCentralMoment(moments, 1, 1) / m00
    mu02p = cv.GetCentralMoment(moments, 2, 0) / m00
    mu20p = cv.GetCentralMoment(moments, 0, 2) / m00
    # arctan2 keeps the quadrant (major-axis direction) and does not divide
    # by zero when mu20p == mu02p.
    return 0.5 * np.arctan2(2 * mu11p, mu20p - mu02p)

def eccentricity(moments):
    """Eccentricity measure of an intensity distribution.

    ``moments`` is a legacy OpenCV ``cv.Moments`` object.  Returns
    ``((mu20 - mu02)**2 - 4*mu11**2) / (mu20 + mu02)**2`` using the
    normalised second central moments.  ``mu20`` here is
    ``GetCentralMoment(m, 0, 2)`` and ``mu02`` is ``GetCentralMoment(m, 2, 0)``.
    """
    # NOTE(review): the eigenvalue-based, rotation-invariant measure
    # ((l1 - l2) / (l1 + l2))**2 equals ((mu20-mu02)**2 + 4*mu11**2)/(mu20+mu02)**2;
    # the minus sign used here is not rotation-invariant -- confirm intent.
    norm = cv.GetCentralMoment(moments, 0, 0)
    cross = cv.GetCentralMoment(moments, 1, 1) / norm
    second_b = cv.GetCentralMoment(moments, 2, 0) / norm
    second_a = cv.GetCentralMoment(moments, 0, 2) / norm
    numerator = (second_a - second_b) ** 2 - 4 * cross ** 2
    return numerator / (second_a + second_b) ** 2
    
def PCAfilter(data):
    """Run PCA on ``data`` through the msa backend and return its output.

    Returns the three values produced by ``msa.pca(data)``, unpacked here as
    (vec, sc, vals).  These are presumably the vectors, scores and values of
    the decomposition -- confirm against the backend.
    """
    vec, sc, vals = msa.pca(data)
    return vec, sc, vals
        
def attribMSA(attribs,nFacs):
    """
    Run ICA (``msa.ica``) on the peak attribute array.

    ``nFacs`` is passed through as the number of factors.  Returns the
    ``(factors, scores)`` pair produced by the backend.
    """
    fac, sco = msa.ica(attribs, nFacs)
    return (fac, sco)

def main(stack,locations,peakwidth,neighborhood=4,subpixel=True):
    """
    Given a stack of images and a list of locations and window sizes (defined by
    the peak width), measure the peak attributes of the peaks of interest in
    each image.  These attributes currently include the height, location of the peak and the relative
    difference in position of the peak from the average, peak orientation angle
    and eccentricity.

    Pipeline: locate all peaks in every image (``stack_coords``), pick the
    best match to each target location within ``neighborhood``
    (``best_match``), then measure attributes (``peakAttribs``).  The
    defaults mirror the script at the bottom of this module.
    """
    pkcs = stack_coords(stack, peakwidth, subpixel)
    best = np.array([best_match(pkcs, loc, neighborhood) for loc in locations])
    return peakAttribs(stack, best, peakwidth)

if __name__=='__main__':
    import fread
    from glob import glob
    # glob returns files in arbitrary, filesystem-dependent order; sort so
    # the image (column) order of the saved attributes is reproducible.
    flist = sorted(glob('*.png'))
    peakWidth = 8
    neighborhood = 4
    d = fread.read(flist)
    # Target peak positions to track through the stack.
    #locs = np.array([[15,13],[25,14],[35,14],[15,24],[25,24],[35,24],[15,34],[25,34],[35,34]])
    locs = np.array([[5, 5], [25, 5], [5, 25], [25, 25], [16, 16]])
    pkcs = stack_coords(d, peakWidth, True)
    best = np.array([best_match(pkcs, loc, neighborhood) for loc in locs])
    attribs = peakAttribs(d, best, peakWidth)
    np.save('attribs.npy', attribs)
    # Optional multivariate analysis of the attributes:
    #factors,scores=attribMSA(attribs,10)
    
