from srwlib import *
import os
import copy

from orangecontrib.srw.drivers.AbstractDriver import AbstractDriver
from orangecontrib.srw.drivers.srw.SRWDriverData import SRWDriverData
from orangecontrib.srw.drivers.DriverSettingAttribute import DriverSettingAttribute
from orangecontrib.srw.drivers.DriverSettings import DriverSettings

from orangecontrib.srw.util.OpticalElementSourceGaussian import OpticalElementSourceGaussian
from orangecontrib.srw.util.OpticalElementAperture import Disc, Rectangle
from orangecontrib.srw.util.OpticalElementScreen import OpticalElementScreen
from orangecontrib.srw.util.OpticalElementSpace import OpticalElementSpace
from orangecontrib.srw.util.Polarization import LinearVertical, LinearHorizontal


class SRWDriver(AbstractDriver):
    """
    Driver that translates the widget optical elements into srwlib objects
    (SRWLGsnBm, SRWLOptD, SRWLOptA, SRWLOptL, ...) and propagates an SRW
    wavefront (SRWLWfr) through them. Results are wrapped in SRWDriverData.
    """

    def __init__(self):
        # NOTE(review): not read by any method of this class; confirm whether
        # external code relies on it before removing.
        self.srw_wavefront = None

    def addToWavefront(self, optical_element, in_data, typename):
        """
        Appends " - <typename>:<element name>" to the string form of the input
        wavefront and returns the result as new driver data.

        The resulting string is also printed to stdout.
        """
        wavefront = "%s - %s:%s" % (in_data.wavefront(), typename, optical_element.name())

        out_data = self.createData()
        print(wavefront)
        out_data.setWavefront(wavefront)

        return out_data

    def hashInputData(self, optical_element, in_data):
        """
        Hashes the input data (only the wavefront it carries is hashed).
        """
        return hash(in_data.wavefront())

    def createData(self):
        """
        Factory method for driver data.
        """
        return SRWDriverData()

    @staticmethod
    def _srwPolarizationCode(polarization):
        """
        Maps a polarization object to the integer code stored in SRWLGsnBm.polar.

        :param polarization: instance compared by equality against fresh
                             instances of the polarization classes.
        :return: 1 (linear horizontal), 2 (linear vertical), 3 (linear 45 deg),
                 4 (linear 135 deg), 5 (circular right) or 6 (circular left).
        :raises Exception: if the polarization matches none of them.
        """
        # Looked up by name so that a polarization class missing from the module
        # only makes that polarization unsupported instead of raising NameError
        # (only LinearVertical/LinearHorizontal are imported at module level).
        from orangecontrib.srw.util import Polarization as polarization_module

        candidates = (("LinearVertical", 2),
                      ("LinearHorizontal", 1),
                      ("Linear45Degree", 3),
                      ("Linear135Degree", 4),
                      ("CircularRight", 5),
                      ("CircularLeft", 6))

        for class_name, srw_code in candidates:
            polarization_class = getattr(polarization_module, class_name, None)
            if polarization_class is not None and polarization == polarization_class():
                return srw_code

        raise Exception("Polarisation %s not handled." % polarization)

    def _createSRWGaussianBeam(self, gaussian_beam):
        """
        Converts an OpticalElementSourceGaussian into an SRWLGsnBm.

        Numeric values are passed as floats: truncating them with int() would
        zero small position/angle offsets and round the photon energy.

        :raises Exception: if the beam polarization is not supported.
        """
        srw_gaussian_beam = SRWLGsnBm()

        srw_gaussian_beam.x  = float(gaussian_beam.x())
        srw_gaussian_beam.y  = float(gaussian_beam.y())
        srw_gaussian_beam.z  = float(gaussian_beam.z())
        srw_gaussian_beam.xp = float(gaussian_beam.xp())
        srw_gaussian_beam.yp = float(gaussian_beam.yp())
        srw_gaussian_beam.avgPhotEn = float(gaussian_beam.averagePhotonEnergy())
        srw_gaussian_beam.pulseEn   = gaussian_beam.pulseEnergy()
        srw_gaussian_beam.repRate   = float(gaussian_beam.repititionRate())

        srw_gaussian_beam.polar = self._srwPolarizationCode(gaussian_beam.polarization())

        srw_gaussian_beam.sigX  = gaussian_beam.sigmaX()
        srw_gaussian_beam.sigY  = gaussian_beam.sigmaY()
        srw_gaussian_beam.sigT  = gaussian_beam.sigmaT()
        srw_gaussian_beam.mx    = 0
        srw_gaussian_beam.my    = 0

        return srw_gaussian_beam

    @staticmethod
    def _createInitialWavefront(srw_gaussian_beam):
        """
        Allocates the initial SRWLWfr on which the Gaussian beam field is computed.

        The mesh is a single photon energy (the beam's average photon energy)
        over a 1 mm x 1 mm transverse window centred on the axis, at z = 300 m.
        """
        srw_wavefront = SRWLWfr()             #Initial Electric Field Wavefront
        srw_wavefront.allocate(1, 5000, 5000) #Numbers of points vs Photon Energy (1), Horizontal and Vertical Positions (dummy)
        srw_wavefront.mesh.zStart = 300       #Longitudinal Position [m] at which Electric Field has to be calculated, i.e. the position of the first optical element
        srw_wavefront.mesh.eStart = srw_gaussian_beam.avgPhotEn #Initial Photon Energy [eV]
        srw_wavefront.mesh.eFin   = srw_gaussian_beam.avgPhotEn #Final Photon Energy [eV]

        firstHorAp  = 1.e-03 #First Aperture [m]
        firstVertAp = 1.e-03 #[m]
        srw_wavefront.mesh.xStart = -0.5*firstHorAp  #Initial Horizontal Position [m]
        srw_wavefront.mesh.xFin   =  0.5*firstHorAp  #Final Horizontal Position [m]
        srw_wavefront.mesh.yStart = -0.5*firstVertAp #Initial Vertical Position [m]
        srw_wavefront.mesh.yFin   =  0.5*firstVertAp #Final Vertical Position [m]

        # Some information about the source in the Wavefront structure.
        srw_wavefront.partBeam.partStatMom1.x  = srw_gaussian_beam.x
        srw_wavefront.partBeam.partStatMom1.y  = srw_gaussian_beam.y
        srw_wavefront.partBeam.partStatMom1.z  = srw_gaussian_beam.z
        srw_wavefront.partBeam.partStatMom1.xp = srw_gaussian_beam.xp
        srw_wavefront.partBeam.partStatMom1.yp = srw_gaussian_beam.yp

        return srw_wavefront

    def _calculateDataSource(self, optical_element, in_data):
        """
        Calculates output data from input data for a source.

        Only OpticalElementSourceGaussian is supported: it is converted into an
        SRWLGsnBm whose electric field is computed on a freshly allocated SRWLWfr.

        :param optical_element: the source element.
        :param in_data: not used for sources.
        :return: driver data holding the computed SRWLWfr.
        :raises Exception: if the source type or its polarization is not supported.
        """
        if not isinstance(optical_element, OpticalElementSourceGaussian):
            raise Exception("Source not supported %s" % str(optical_element))

        srw_gaussian_beam = self._createSRWGaussianBeam(optical_element)
        srw_wavefront = self._createInitialWavefront(srw_gaussian_beam)

        sampFactNxNyForProp = 5 #sampling factor for adjusting nx, ny (effective if > 0)
        arPrecPar = [sampFactNxNyForProp]

        # Calculate initial wavefront.
        srwl.CalcElecFieldGaussian(srw_wavefront, srw_gaussian_beam, arPrecPar)

        out_data = self.createData()
        out_data.setWavefront(srw_wavefront)
        return out_data

    def propagateWavefront(self, optical_element, in_data, srw_optical_element):
        """
        Propagates a copy of the input wavefront through one SRW optical element.

        :param optical_element: widget-side element; its driverSettings() provide
                                the propagation parameters (the driver defaults are
                                used when it returns None).
        :param in_data: driver data holding the incoming SRWLWfr; it is deep-copied
                        so the input is left unmodified.
        :param srw_optical_element: srwlib element (e.g. SRWLOptD, SRWLOptA, SRWLOptL).
        :return: new driver data holding the propagated wavefront.
        """
        driver_settings = optical_element.driverSettings()

        if driver_settings is None:
            print("%s uses default settings" % optical_element)
            driver_settings = self.driverSettings()

        srw_preferences = self._driverSettingsToSRWParameters(driver_settings)

        optical_beamline = SRWLOptC([srw_optical_element],
                                    [srw_preferences])

        wavefront = copy.deepcopy(in_data.wavefront())

        srwl.PropagElecField(wavefront, optical_beamline)
        out_data = self.createData()
        out_data.setWavefront(wavefront)
        return out_data

    def _calculateDataSpace(self, optical_element, in_data):
        """
        Calculates output data from input data for a drift space
        (SRWLOptD of the element's length).
        """
        srw_drift_space = SRWLOptD(optical_element.length())
        return self.propagateWavefront(optical_element, in_data, srw_drift_space)

    def _calculateDataAperture(self, optical_element, in_data):
        """
        Calculates output data from input data for an aperture.

        A Disc becomes a circular SRWLOptA with the disc diameter in both
        directions; a Rectangle becomes a rectangular SRWLOptA of width x height.

        :raises Exception: for any other aperture type.
        """
        aperture_type = optical_element.apertureType()
        if aperture_type == Disc():
            srw_aperture = SRWLOptA('c', 'a',
                                    optical_element.diameter(),
                                    optical_element.diameter())
        elif aperture_type == Rectangle():
            srw_aperture = SRWLOptA('r', 'a',
                                    optical_element.width(),
                                    optical_element.height())
        else:
            raise Exception("Aperture type %s not handled" % str(aperture_type))

        return self.propagateWavefront(optical_element, in_data, srw_aperture)

    def _calculateDataLens(self, optical_element, in_data):
        """
        Calculates output data from input data for a lens
        (SRWLOptL with the element's horizontal and vertical focal lengths).
        """
        srw_lens = SRWLOptL(optical_element.focalX(),
                            optical_element.focalY())
        return self.propagateWavefront(optical_element, in_data, srw_lens)

    def _calculateDataScreen(self, optical_element, in_data):
        """
        Calculates output data from input data for a screen: the input data is
        returned unchanged as a deep copy.
        """
        return copy.deepcopy(in_data)

    def driverSettings(self):
        """
        Returns the default SRW wavefront propagation parameters.

        The attribute order matters: _driverSettingsToSRWParameters emits the
        values in this order, matching the SRW propagation parameter list:

        [0]: Auto-Resize (1) or not (0) Before propagation
        [1]: Auto-Resize (1) or not (0) After propagation
        [2]: Relative Precision for propagation with Auto-Resizing (1. is nominal)
        [3]: Allow (1) or not (0) for semi-analytical treatment of the quadratic (leading) phase terms at the propagation
        [4]: Do any Resizing on Fourier side, using FFT, (1) or not (0)
        [5]: Horizontal Range modification factor at Resizing (1. means no modification)
        [6]: Horizontal Resolution modification factor at Resizing
        [7]: Vertical Range modification factor at Resizing
        [8]: Vertical Resolution modification factor at Resizing
        [9]: Type of wavefront Shift before Resizing (not yet implemented)
        [10]: New Horizontal wavefront Center position after Shift (not yet implemented)
        [11]: New Vertical wavefront Center position after Shift (not yet implemented)
        """
        attributes = [
            DriverSettingAttribute("Auto resize before propagation",
                                   "Auto resize before propagation",
                                   bool,
                                   False),
            DriverSettingAttribute("Auto resize after propagation",
                                   "Auto resize after propagation",
                                   bool,
                                   False),
            DriverSettingAttribute("Relative precision for auto resize",
                                   "Relative Precision for propagation with Auto-Resizing (1. is nominal)",
                                   float,
                                   1.0),
            DriverSettingAttribute("Semi-analytical treatment of leading phase",
                                   "Allow (1) or not (0) for semi-analytical treatment of the quadratic (leading) phase terms at the propagation",
                                   bool,
                                   True),
            DriverSettingAttribute("Resize using FFT",
                                   "Do any Resizing on Fourier side, using FFT",
                                   bool,
                                   False),
            DriverSettingAttribute("Horizontal scaling",
                                   "Horizontal Range modification factor at Resizing (1. means no modification)",
                                   float,
                                   1.0),
            DriverSettingAttribute("Horizontal resolution",
                                   "Horizontal Resolution modification factor at Resizing",
                                   float,
                                   1.0),
            DriverSettingAttribute("Vertical scaling",
                                   "Vertical Range modification factor at Resizing (1. means no modification)",
                                   float,
                                   1.0),
            DriverSettingAttribute("Vertical resolution",
                                   "Vertical Resolution modification factor at Resizing",
                                   float,
                                   1.0),
            DriverSettingAttribute("Type of wavefront Shift",
                                   "Type of wavefront Shift before Resizing (not yet implemented)",
                                   int,
                                   0),
            DriverSettingAttribute("New Horizontal wavefront Center",
                                   "New Horizontal wavefront Center position after Shift (not yet implemented)",
                                   int,
                                   0),
            DriverSettingAttribute("New Vertical wavefront Center",
                                   "New Vertical wavefront Center position after Shift (not yet implemented)",
                                   int,
                                   0),
        ]

        return DriverSettings(attributes)

    def _driverSettingsToSRWParameters(self, driver_settings):
        """
        Converts driver settings into the flat SRW propagation parameter list.

        Values are read from driver_settings in the order of the default
        settings returned by driverSettings(); booleans become 0/1 because the
        SRW parameters are (1)/(0) flags.
        """
        srw_parameters = []
        for name in self.driverSettings().names():
            value = driver_settings.valueByName(name)

            if isinstance(value, bool):
                value = int(value)

            srw_parameters.append(value)

        return srw_parameters

    @staticmethod
    def _plotMeshAxes(mesh):
        """
        Returns the horizontal and vertical plot axes of an SRW mesh as
        [start, end, number of points] lists, positions scaled by 1e6 (m -> um).
        """
        plot_mesh_x = [1e+06*mesh.xStart, 1e+06*mesh.xFin, mesh.nx]
        plot_mesh_y = [1e+06*mesh.yStart, 1e+06*mesh.yFin, mesh.ny]
        return plot_mesh_x, plot_mesh_y

    def calculateIntensity3D(self, in_data):
        """
        Calculates 3D intensity distribution.

        :return: [flat 'f' array of nx*ny intensity values, x axis, y axis];
                 see _plotMeshAxes for the axis format.
        """
        wfr = in_data.wavefront()
        mesh3 = copy.deepcopy(wfr.mesh)
        arI3 = array('f', [0]*mesh3.nx*mesh3.ny) #"flat" array to take 2D intensity data
        srwl.CalcIntFromElecField(arI3, wfr, 6, 0, 3, mesh3.eStart, 0, 0) #extracts intensity
        plotMesh3x, plotMesh3y = self._plotMeshAxes(mesh3)

        return [arI3, plotMesh3x, plotMesh3y]

    def calculatePhase3D(self, in_data):
        """
        Calculates 3D phase distribution.

        :return: [flat 'd' array of nx*ny phase values, x axis, y axis];
                 see _plotMeshAxes for the axis format.
        """
        wfr = in_data.wavefront()
        mesh3 = copy.deepcopy(wfr.mesh)

        arP3 = array('d', [0]*mesh3.nx*mesh3.ny) #"flat" array to take 2D phase data (note it should be 'd')
        srwl.CalcIntFromElecField(arP3, wfr, 0, 4, 3, mesh3.eStart, 0, 0) #extracts radiation phase
        plotMesh3x, plotMesh3y = self._plotMeshAxes(mesh3)

        return [arP3, plotMesh3x, plotMesh3y]
