import numpy

class Wavefront(object):
    """Read-only convenience wrapper around an SRW wavefront object.

    The wrapped object is expected to expose:

    * ``mesh`` with the attributes ``nx``, ``ny``, ``ne``, ``xStart``,
      ``xFin``, ``yStart``, ``yFin`` (and ``eStart`` for ``_plotTest``);
    * ``arEx`` / ``arEy``: flat sequences holding the horizontal and vertical
      field components as interleaved (real, imaginary) pairs, with the
      energy index varying fastest, then x, then y (see ``index``).
    """

    def __init__(self, srw_wavefront):
        """Store the SRW wavefront; no data is copied."""
        self._srw_wavefront = srw_wavefront

    def SRWWavefront(self):
        """Return the wrapped SRW wavefront object."""
        return self._srw_wavefront

    def numberEnergies(self):
        """Return the number of points along the energy axis (``mesh.ne``)."""
        return self._srw_wavefront.mesh.ne

    def dim_x(self):
        """Return the number of grid points along x (``mesh.nx``)."""
        return self._srw_wavefront.mesh.nx

    def dim_y(self):
        """Return the number of grid points along y (``mesh.ny``)."""
        return self._srw_wavefront.mesh.ny

    def dim_energy(self):
        """Return the number of energy points; alias of ``numberEnergies``."""
        return self.numberEnergies()

    @staticmethod
    def _grid_coordinate(start, stop, count, position):
        """Map a grid position onto the axis spanning ``start``..``stop``.

        The axis has ``count`` equally spaced points.  A single-point axis
        has no spacing, so every position maps to ``start`` instead of
        dividing by zero.
        """
        if count <= 1:
            return start
        # Keep this exact operation order: it reproduces the historical
        # floating-point values bit for bit.
        return start + position * (stop - start) / float(count - 1)

    def absolute(self, x, y):
        """Convert grid indices ``(x, y)`` to absolute coordinates.

        :param x: index along the horizontal axis.
        :param y: index along the vertical axis.
        :return: list ``[absolute_x, absolute_y]``.
        """
        mesh = self._srw_wavefront.mesh
        return [
            self._grid_coordinate(mesh.xStart, mesh.xFin, mesh.nx, x),
            self._grid_coordinate(mesh.yStart, mesh.yFin, mesh.ny, y),
        ]

    def absolute_x_coordinates(self):
        """Return the absolute x coordinate of every grid column as a list."""
        mesh = self._srw_wavefront.mesh
        return [self._grid_coordinate(mesh.xStart, mesh.xFin, mesh.nx, x)
                for x in range(self.dim_x())]

    def absolute_y_coordinates(self):
        """Return the absolute y coordinate of every grid row as a list."""
        mesh = self._srw_wavefront.mesh
        return [self._grid_coordinate(mesh.yStart, mesh.yFin, mesh.ny, y)
                for y in range(self.dim_y())]

    def minimal_x_coodinate(self):
        """Return the smallest absolute x coordinate of the grid."""
        return min(self.absolute_x_coordinates())

    def maximal_x_coodinate(self):
        """Return the largest absolute x coordinate of the grid."""
        return max(self.absolute_x_coordinates())

    def minimal_y_coodinate(self):
        """Return the smallest absolute y coordinate of the grid."""
        return min(self.absolute_y_coordinates())

    def maximal_y_coodinate(self):
        """Return the largest absolute y coordinate of the grid."""
        return max(self.absolute_y_coordinates())

    def index(self, x, y, index_energy):
        """Return the flat offset of the real part of a field sample.

        The imaginary part of the same sample is stored at the returned
        offset plus one.

        :param x: index along the horizontal axis.
        :param y: index along the vertical axis.
        :param index_energy: index along the energy axis.
        """
        samples_before = (y * self.dim_x() + x) * self.dim_energy() + index_energy
        # Two floats (real, imaginary) are stored per complex sample.
        return 2 * samples_before

    def efield(self, x, y, index_energy):
        """Return the complex field at one grid point.

        :return: list ``[e_horizontal, e_vertical]`` of complex numbers.
        """
        offset = self.index(x, y, index_energy)
        wavefront = self._srw_wavefront
        e_horizontal = wavefront.arEx[offset] + wavefront.arEx[offset + 1] * 1j
        e_vertical = wavefront.arEy[offset] + wavefront.arEy[offset + 1] * 1j

        return [e_horizontal, e_vertical]

    def _srw_array_to_numpy(self, srw_array):
        """Convert a flat interleaved (re, im) array to a complex ndarray.

        :return: complex array of shape ``(dim_y, dim_x, numberEnergies, 1)``.
        """
        # ``numpy.float`` was removed from NumPy; the builtin ``float`` is
        # the same dtype (float64).
        real_part = numpy.array(srw_array[::2], dtype=float)
        imaginary_part = numpy.array(srw_array[1::2], dtype=float)

        field = real_part + 1j * imaginary_part
        return field.reshape((self.dim_y(),
                              self.dim_x(),
                              self.numberEnergies(),
                              1))

    def E_field_as_numpy(self):
        """Return the full electric field as a complex ndarray.

        :return: array of shape ``(dim_y, dim_x, numberEnergies, 2)``; the
            last axis holds the horizontal (0) and vertical (1) components.
        """
        x_polarization = self._srw_array_to_numpy(self._srw_wavefront.arEx)
        y_polarization = self._srw_array_to_numpy(self._srw_wavefront.arEy)

        return numpy.concatenate((x_polarization, y_polarization), 3)

    def interpolate(self, index_energy):
        """Build a spline interpolator of the field at one energy index.

        The returned callable takes two coordinate sequences and returns a
        tuple ``(e_horizontal, e_vertical)`` of complex 2-D arrays.

        NOTE: the splines are built on ``(y, x)``, so the FIRST argument of
        the returned callable is the sequence of y coordinates and the
        SECOND the sequence of x coordinates, although the parameters are
        named ``x, y``.  The names are kept unchanged for backward
        compatibility.

        scipy's ``RectBivariateSpline`` is used with its default settings,
        so the grid coordinates must be strictly ascending and each axis
        needs enough points for the default spline degree.

        :param index_energy: index along the energy axis.
        """
        import scipy.interpolate

        x_coordinates = self.absolute_x_coordinates()
        y_coordinates = self.absolute_y_coordinates()
        e_field = self.E_field_as_numpy()

        def build_spline(values):
            return scipy.interpolate.RectBivariateSpline(y_coordinates,
                                                         x_coordinates,
                                                         values)

        horizontal = e_field[:, :, index_energy, 0]
        vertical = e_field[:, :, index_energy, 1]

        s_re_x = build_spline(horizontal.real)
        s_im_x = build_spline(horizontal.imag)
        s_re_y = build_spline(vertical.real)
        s_im_y = build_spline(vertical.imag)

        s = lambda x, y: (s_re_x(x, y) + 1j * s_im_x(x, y),
                          s_re_y(x, y) + 1j * s_im_y(x, y))

        return s

    def intensity(self, x, y, index_energy=0):
        """Return ``|E_horizontal|**2 + |E_vertical|**2`` at one grid point.

        :param x: index along the horizontal axis.
        :param y: index along the vertical axis.
        :param index_energy: index along the energy axis (default 0).
        """
        e_horizontal, e_vertical = self.efield(x, y, index_energy)
        return abs(e_horizontal) ** 2 + abs(e_vertical) ** 2

    def intensity_at_x(self, x, index_energy=0):
        """Return the intensity along y for the fixed column index ``x``."""
        return [self.intensity(x, y, index_energy)
                for y in range(self.dim_y())]

    def intensity_at_y(self, y, index_energy=0):
        """Return the intensity along x for the fixed row index ``y``."""
        return [self.intensity(x, y, index_energy)
                for x in range(self.dim_x())]

    def intensity_plane(self):
        """Return the intensity summed over both field components.

        :return: real array of shape ``(dim_y, dim_x, numberEnergies)``.
        """
        intensity = abs(self.E_field_as_numpy()) ** 2
        return intensity.sum(3)

    def FT(self):
        """Plot a discretised continuous Fourier transform (demonstration).

        This does not use any wavefront data: it transforms the fixed
        function f(t) = 1 / (t**2 + 1) and plots the numerical result
        against the analytical solution.  Opens a matplotlib window.

        Adapted from
        http://stackoverflow.com/questions/24077913/discretized-continuous-fourier-transform-with-numpy
        """
        import matplotlib.pyplot as pl

        # Discretise time t.
        t0 = -100.
        dt = 0.001
        t = numpy.arange(t0, -t0, dt)
        f = 1. / (t ** 2 + 1.)

        # FFT of the samples; the frequency normalisation factor is 2*pi/dt.
        g = numpy.fft.fft(f)
        w = numpy.fft.fftfreq(f.size) * 2 * numpy.pi / dt

        # To obtain a discretisation of the continuous Fourier transform,
        # g is multiplied by a phase factor.
        g *= dt * numpy.exp(-1j * w * t0) / numpy.sqrt(2 * numpy.pi)

        # g is complex; plot its real part explicitly, which is what the
        # real-valued analytical curve below is compared against.
        pl.scatter(w, g.real, color="r")
        pl.plot(w, numpy.exp(-numpy.abs(w)) * numpy.sqrt(numpy.pi / 2),
                color="g")

        pl.gca().set_xlim(-10, 10)
        pl.show()
        pl.close()

    @staticmethod
    def _intensity_of(e_horizontal, e_vertical):
        """Return the summed squared magnitude of two complex arrays."""
        return (e_horizontal.real ** 2 + e_horizontal.imag ** 2
                + e_vertical.real ** 2 + e_vertical.imag ** 2)

    def _srw_intensity_cuts(self):
        """Let SRW compute 1-D intensity cuts versus x and versus y.

        :return: tuple ``(cut_vs_x, cut_vs_y)`` of SRW float arrays.
        """
        from srwlib import srwl, array

        wfr = self.SRWWavefront()
        nx, ny, e_start = wfr.mesh.nx, wfr.mesh.ny, wfr.mesh.eStart

        # The numeric arguments are passed through unchanged from the
        # original test code; see the SRW documentation for their meaning.
        cut_vs_x = array('f', [0] * nx)  # array to take 1D intensity data (vs X)
        srwl.CalcIntFromElecField(cut_vs_x, wfr, 6, 0, 1, e_start, 0, 0)
        cut_vs_y = array('f', [0] * ny)  # array to take 1D intensity data (vs Y)
        srwl.CalcIntFromElecField(cut_vs_y, wfr, 6, 0, 2, e_start, 0, 0)

        return cut_vs_x, cut_vs_y

    def _plotTest(self):
        """Visually compare this class's intensities with SRW's own.

        Opens three plot windows in sequence: intensity versus x, intensity
        versus y, and a 3-D surface of the intensity at energy index 0.
        Requires ``srwlib``, ``pylab`` and ``matplotlib``.
        """
        import pylab

        x_coordinates = self.absolute_x_coordinates()
        y_coordinates = self.absolute_y_coordinates()

        # The interpolator takes (y coordinates, x coordinates).
        field_at = self.interpolate(0)

        srw_cut_vs_x, srw_cut_vs_y = self._srw_intensity_cuts()
        print('done')

        # I_y(x): interpolated at y = 0, sampled on the central row, and SRW.
        central_row = int(self.dim_y() / 2)
        sampled = self.intensity_at_y(central_row)
        e = field_at([0.0], x_coordinates)
        i = self._intensity_of(e[0], e[1])
        print(i)

        pylab.plot(x_coordinates, i[0, :],
                   x_coordinates, sampled,
                   x_coordinates, srw_cut_vs_x)
        pylab.show()

        # I_x(y): interpolated at x = 0, and SRW.
        e = field_at(y_coordinates, [0.0])
        i = self._intensity_of(e[0], e[1])
        print(i)

        pylab.plot(y_coordinates, i[:, 0], y_coordinates, srw_cut_vs_y)
        pylab.show()

        # I(x,y) at energy index 0, shape (dim_y, dim_x).
        e = self.E_field_as_numpy()
        i = self._intensity_of(e[:, :, 0, 0], e[:, :, 0, 1])
        print(i.shape)

        # Importing Axes3D is needed by older matplotlib versions to make
        # the '3d' projection available.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        import matplotlib.pyplot as plt

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        # meshgrid(x, y) yields grids of shape (dim_y, dim_x), matching the
        # intensity array, so no reshape is needed and the axes agree with
        # their labels.
        X, Y = numpy.meshgrid(x_coordinates, y_coordinates)

        ax.plot_surface(X, Y, i)

        ax.set_xlabel('X in plane')
        ax.set_ylabel('Y in plane')
        ax.set_zlabel('Intensity')

        plt.show()