from srwlib import *

class SRWAdapter(object):
    """Build SRW (srwlib) objects from simple undulator / beam parameters.

    The adapter holds no state of its own: every method constructs and
    returns new srwlib objects.
    """

    # Electron rest energy [GeV]; beam energy divided by this value gives the
    # relative energy (Lorentz factor gamma). The literal is kept unchanged so
    # numerical results are reproduced exactly.
    ELECTRON_REST_ENERGY_GEV = 0.51099890221e-03

    def undulator(self, undulator):
        """Return an ``SRWLMagFldU`` describing ``undulator``.

        A first-harmonic field component is added for each plane whose
        deflection parameter K is positive: the vertical component with
        symmetry flag 1 and the horizontal one with symmetry flag -1.

        :param undulator: object exposing ``K_vertical()``, ``K_horizontal()``,
            ``B_vertical()``, ``B_horizontal()``, ``periodLength()`` and
            ``periodNumber()``.
        :return: ``SRWLMagFldU`` built from the selected harmonics, the period
            length and the number of periods. If neither K is positive the
            harmonic list passed to SRW is empty.
        """
        # Local import keeps this block self-contained; diagnostics go to the
        # logging system instead of being printed unconditionally to stdout.
        import logging
        logger = logging.getLogger(__name__)

        magnetic_fields = []

        if undulator.K_vertical() > 0.0:
            logger.debug("Kv %s", undulator.K_vertical())
            logger.debug("Bv %s", undulator.B_vertical())
            # Positional SRWLMagFldH arguments, presumably (harmonic number,
            # plane, peak field, phase, symmetry, transverse coefficient) --
            # verify against the srwlib signature.
            magnetic_fields.append(
                SRWLMagFldH(1, 'v', undulator.B_vertical(), 0, 1, 1))

        if undulator.K_horizontal() > 0.0:
            logger.debug("Kh %s", undulator.K_horizontal())
            logger.debug("Bh %s", undulator.B_horizontal())
            # Same layout as above, but with symmetry flag -1 for the
            # horizontal component.
            magnetic_fields.append(
                SRWLMagFldH(1, 'h', undulator.B_horizontal(), 0, -1, 1))

        return SRWLMagFldU(magnetic_fields,
                           undulator.periodLength(),
                           undulator.periodNumber())

    def magnetFieldFromUndulator(self, undulator):
        """Wrap the SRW undulator built from ``undulator`` in a field container.

        :param undulator: same kind of object accepted by :meth:`undulator`.
        :return: ``SRWLMagFldC`` holding the single undulator element placed at
            x = y = z = 0.
        """
        srw_undulator = self.undulator(undulator)

        # Container of all field elements; one centre coordinate per element.
        return SRWLMagFldC([srw_undulator],
                           array('d', [0]), array('d', [0]), array('d', [0]))

    def zeroEmittanceElectronBeam(self, E_in_GeV, current=0.2):
        """Return an ``SRWLPartBeam`` for a filament (zero-emittance) beam.

        Only the first-order moments are set: the beam starts at
        x = y = z = 0 with zero transverse angles.

        :param E_in_GeV: electron beam energy [GeV].
        :param current: average beam current [A]; defaults to the previously
            hard-coded 0.2 A.
        :return: the configured ``SRWLPartBeam``.
        """
        elecBeam = SRWLPartBeam()
        elecBeam.Iavg = current  # Average current [A]
        # Initial transverse and longitudinal coordinates [m].
        elecBeam.partStatMom1.x = 0
        elecBeam.partStatMom1.y = 0
        elecBeam.partStatMom1.z = 0
        # Initial relative transverse velocities.
        elecBeam.partStatMom1.xp = 0
        elecBeam.partStatMom1.yp = 0
        # Relative energy: E / (m_e c^2).
        elecBeam.partStatMom1.gamma = E_in_GeV / self.ELECTRON_REST_ENERGY_GEV

        return elecBeam

    def createQuadraticSRWWavefront(self, grid_size, grid_length, z_start, electron_beam, energy_number, energy_start, energy_end):
        """Allocate an ``SRWLWfr`` on a square transverse grid.

        :param grid_size: number of points in both x and y.
        :param grid_length: half-width of the grid [m]; the mesh spans
            ``[-grid_length, grid_length]`` in both x and y.
        :param z_start: longitudinal position [m] at which SR is calculated.
        :param electron_beam: ``SRWLPartBeam`` attached to the wavefront.
        :param energy_number: number of photon-energy points.
        :param energy_start: initial photon energy [eV].
        :param energy_end: final photon energy [eV].
        :return: the allocated and configured ``SRWLWfr``.
        """
        wfr = SRWLWfr()
        # Numbers of points vs photon energy, horizontal and vertical
        # position (may be modified by the library!).
        wfr.allocate(energy_number, grid_size, grid_size)
        wfr.mesh.zStart = float(z_start)  # Longitudinal position [m]
        wfr.mesh.eStart = energy_start    # Initial photon energy [eV]
        wfr.mesh.eFin = energy_end        # Final photon energy [eV]
        wfr.mesh.xStart = -grid_length    # Initial horizontal position [m]
        wfr.mesh.xFin = grid_length       # Final horizontal position [m]
        wfr.mesh.yStart = -grid_length    # Initial vertical position [m]
        wfr.mesh.yFin = grid_length       # Final vertical position [m]

        wfr.partBeam = electron_beam

        return wfr

    def createQuadraticSRWWavefrontSingleEnergy(self, grid_size, grid_length, z_start, electron_beam, energy):
        """Single-photon-energy variant of :meth:`createQuadraticSRWWavefront`."""
        return self.createQuadraticSRWWavefront(grid_size, grid_length, z_start, electron_beam, 1, energy, energy)

    def createBeamlineOneToOneSourceImage(self, wfr):
        """Return a lens + drift beamline that images the source 1:1.

        The source is at distance ``zStart`` from the lens. A thin lens of
        focal length ``zStart / 2`` then forms the image at a further
        ``zStart`` (1/f = 1/zStart + 1/zStart), which is the drift appended
        after the lens.

        :param wfr: wavefront whose ``mesh.zStart`` sets the geometry.
        :return: ``SRWLOptC`` containing the lens and the drift together with
            their propagation parameters.
        """
        fx = wfr.mesh.zStart / 2                # Lens focal lengths
        fy = wfr.mesh.zStart / 2
        optLens = SRWLOptL(fx, fy)              # Lens
        optDrift = SRWLOptD(wfr.mesh.zStart)    # Drift space

        propagParLens = [1, 1, 1., 0, 0, 1., 1., 1., 1., 0, 0, 0]
        propagParDrift = [1, 1, 1., 0, 0, 2., 2., 2., 2., 0, 0, 0]
        #Wavefront Propagation Parameters:
        #[0]: Auto-Resize (1) or not (0) Before propagation
        #[1]: Auto-Resize (1) or not (0) After propagation
        #[2]: Relative Precision for propagation with Auto-Resizing (1. is nominal)
        #[3]: Allow (1) or not (0) for semi-analytical treatment of the quadratic (leading) phase terms at the propagation
        #[4]: Do any Resizing on Fourier side, using FFT, (1) or not (0)
        #[5]: Horizontal Range modification factor at Resizing (1. means no modification)
        #[6]: Horizontal Resolution modification factor at Resizing
        #[7]: Vertical Range modification factor at Resizing
        #[8]: Vertical Resolution modification factor at Resizing
        #[9]: Type of wavefront Shift before Resizing (not yet implemented)
        #[10]: New Horizontal wavefront Center position after Shift (not yet implemented)
        #[11]: New Vertical wavefront Center position after Shift (not yet implemented)

        # "Beamline": container of optical elements together with the
        # corresponding wavefront propagation instructions.
        return SRWLOptC([optLens, optDrift],
                        [propagParLens, propagParDrift])

    def normalPrecisionParameter(self):
        """Return the SR-calculation precision parameter list.

        :return: ``[meth, relPrec, zStartInteg, zEndInteg, npTraj, useTermin,
            sampFactNxNyForProp]`` with the "auto-undulator" method.
        """
        meth = 1                 # SR calculation method: 0- "manual", 1- "auto-undulator", 2- "auto-wiggler"
        relPrec = 0.01           # relative precision
        zStartInteg = 0          # longitudinal position to start integration (effective if < zEndInteg)
        zEndInteg = 0            # longitudinal position to finish integration (effective if > zStartInteg)
        npTraj = 20000           # Number of points for trajectory calculation
        useTermin = 1            # Use "terminating terms" (i.e. asymptotic expansions at zStartInteg and zEndInteg) or not (1 or 0 respectively)
        sampFactNxNyForProp = 1  # sampling factor for adjusting nx, ny (effective if > 0)

        return [meth, relPrec, zStartInteg, zEndInteg, npTraj, useTermin, sampFactNxNyForProp]
