"""
Time stepping with Cython (:mod:`fluiddyn.simul.base.time_stepping_cy`)
=======================================================================

.. currentmodule:: fluiddyn.simul.base.time_stepping_cy

Provides:

.. autoclass:: ExactCoefsDiss
   :members:
   :private-members:

.. autoclass:: ExactCoefsExactLin
   :members:
   :private-members:


.. autoclass:: TimeSteppingPseudoSpectralCython
   :members:
   :private-members:

.. autoclass:: TimeSteppingExactLinCython
   :members:
   :private-members:



"""


cimport numpy as np
import numpy as np
np.import_array()

from time import time, sleep
import datetime
import os
import matplotlib.pyplot as plt
import cython

from libc.math cimport exp

from fluiddyn.simul.operators.setofvariables import SetOfVariables


# Python dtype objects (DTYPE*) and matching C types (DTYPE*_t) used for
# the physical and Fourier spaces.
DTYPEb = np.uint8
ctypedef np.uint8_t DTYPEb_t
# ``np.int`` was only an alias of the builtin ``int``; it was deprecated in
# NumPy 1.20 and removed in 1.24 (AttributeError at import). ``np.int_`` is
# the dtype that ``int`` maps to (``np.dtype(int) == np.dtype(np.int_)``).
DTYPEi = np.int_
ctypedef np.int_t DTYPEi_t
DTYPEf = np.float64
ctypedef np.float64_t DTYPEf_t
DTYPEc = np.complex128
ctypedef np.complex128_t DTYPEc_t

# Basically, you use the _t ones when you need to declare a type
# (e.g. cdef foo_t var, or np.ndarray[foo_t, ndim=...]. Ideally someday
# we won't have to make this distinction, but currently one is a C type
# and the other is a python object representing a numpy type (a dtype),
# and there's currently no way to identify the two without special
# compiler support.
# - Robert Bradshaw


cdef extern from "complex.h":
    # C99 complex exponential, declared with numpy's complex128 C type
    # (original prototype: double complex cexp(double complex z) nogil).
    np.complex128_t cexp(np.complex128_t z) nogil








class ExactCoefsDiss(object):
    """Compute and cache the exact coefficients used by the Cython RK4.

    For a real array ``time_stepping.freq_lin`` indexed as ``[i0, i1]``
    with ``i0 < nK0_loc`` and ``i1 < nK1_loc``, :meth:`compute` fills

    - ``exact[i0, i1] = exp(-freq_lin[i0, i1]*dt)`` (full time step),
    - ``exact2[i0, i1] = exp(-freq_lin[i0, i1]*dt/2)`` (half time step).

    If ``sim.params.time_stepping.USE_CFL`` is true, the time step may
    change between steps, so ``get_updated_coefs`` recomputes the
    coefficients whenever ``time_stepping.deltat`` differs from the value
    used last. Otherwise the coefficients are computed once here and
    ``get_updated_coefs`` simply returns them.
    """

    def __init__(self, time_stepping):
        self.time_stepping = time_stepping
        sim = time_stepping.sim
        self.nk = sim.state.state_fft.nb_variables
        self.n0 = sim.oper.nK0_loc
        self.n1 = sim.oper.nK1_loc
        self.exact = self.empty()
        self.exact2 = self.empty()

        if sim.params.time_stepping.USE_CFL:
            self.get_updated_coefs = self.get_updated_coefs_CLF
            # NaN compares unequal to every float, so the first call of
            # get_updated_coefs_CLF always fills the (uninitialized) arrays,
            # even when deltat happens to be 0.
            self.dt_old = np.nan
        else:
            self.compute(time_stepping.deltat)
            self.get_updated_coefs = self.get_coefs

    def empty(self):
        """Return an uninitialized float64 array of shape (n0, n1)."""
        return np.empty([self.n0, self.n1], dtype=np.float64)

    def compute(self, double dt):
        """Fill ``exact`` and ``exact2`` for the time step *dt*.

        ``exact = exp(-freq_lin*dt)`` and ``exact2 = exp(-freq_lin*dt/2)``
        are written in place; *dt* is then stored in ``self.dt_old``.
        """
        cdef Py_ssize_t i0, i1, n0, n1
        cdef np.ndarray[double, ndim=2] exact, exact2, f_lin

        n0 = self.n0
        n1 = self.n1
        exact = self.exact
        exact2 = self.exact2
        f_lin = self.time_stepping.freq_lin

        for i0 in xrange(n0):
            for i1 in xrange(n1):
                exact[i0, i1]  = exp(-f_lin[i0, i1]*dt)
                exact2[i0, i1] = exp(-f_lin[i0, i1]*dt/2)
        self.dt_old = dt

    def get_updated_coefs_CLF(self):
        """Return ``(exact, exact2)``, recomputed if ``deltat`` changed.

        Used when the time step is adapted with the CFL condition
        (``USE_CFL``); the "CLF" spelling of the name is kept for
        compatibility.
        """
        dt = self.time_stepping.deltat
        if self.dt_old != dt:
            self.compute(dt)
        return self.exact, self.exact2

    def get_coefs(self):
        """Return the cached ``(exact, exact2)`` without recomputing."""
        return self.exact, self.exact2




class ExactCoefsExactLin(ExactCoefsDiss):
    """Exact RK4 coefficients for a complex, per-variable linear operator.

    The coefficients are complex arrays of shape ``(nk, n0, n1)``, indexed
    like the state, and the linear frequencies are read from
    ``time_stepping.freq_lin.data``. For a time step ``dt``:

    - ``exact[ik, i0, i1] = cexp(-freq[ik, i0, i1]*dt)``,
    - ``exact2[ik, i0, i1] = cexp(-freq[ik, i0, i1]*dt/2)``.
    """

    def empty(self):
        """Allocate an uninitialized complex array of shape (nk, n0, n1)."""
        shape = [self.nk, self.n0, self.n1]
        return np.empty(shape, dtype=DTYPEc)

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def compute(self, double dt):
        """Fill ``exact``/``exact2`` in place for *dt*; store it in dt_old."""
        cdef Py_ssize_t ivar, j0, j1, n_var, size0, size1
        cdef np.ndarray[DTYPEc_t, ndim=3] full, half, freq

        n_var = self.nk
        size0 = self.n0
        size1 = self.n1
        full = self.exact
        half = self.exact2
        freq = self.time_stepping.freq_lin.data

        for ivar in range(n_var):
            for j0 in range(size0):
                for j1 in range(size1):
                    # full-step and half-step coefficients
                    full[ivar, j0, j1] = cexp(-freq[ivar, j0, j1]*dt)
                    half[ivar, j0, j1] = cexp(-freq[ivar, j0, j1]*dt/2)

        self.dt_old = dt






class TimeSteppingPseudoSpectralCython(object):
    """Cython RK4 time stepping with a real (dissipative) linear term.

    This class only defines :meth:`_time_step_RK4`. That method reads
    ``self.sim``, ``self.deltat`` and ``self.exact_coefs``, none of which
    are set here; presumably they are provided by the class this one is
    combined with (TODO confirm against the pure Python time stepping).
    """

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def _time_step_RK4(self):
        r"""Advance in time *sim.state.state_fft* with the Runge-Kutta 4 method.

        See :ref:`the pure python RK4 function <rk4timescheme>` for the
        presentation of the time scheme.

        For this function, the coefficient :math:`\sigma` is real and
        represents the dissipation.

        With ``E = exact`` and ``E2 = exact2`` (2D real arrays returned by
        ``self.exact_coefs.get_updated_coefs()`` and applied to every
        variable index ``ik``), ``S`` the state data and ``N`` the function
        ``sim.tendencies_non_diff``, the loops below compute::

            T1 = N()      (no argument; presumably the current state)
            A1 = (S + dt/2*T1)*E2
            T2 = N(A1)
            A2 = E2*S + dt/2*T2
            T3 = N(A2)
            A3 = E*S + dt*E2*T3
            T4 = N(A3)
            S <- (S + dt/6*T1)*E + dt/3*E2*T2 + dt/3*E2*T3 + dt/6*T4

        The state data array is overwritten in place; intermediate sets of
        variables are deleted as soon as they are no longer needed.
        """
        # cdef DTYPEf_t dt = self.deltat
        cdef double dt = self.deltat

        cdef Py_ssize_t i0, i1, ik, nk, n0, n1

        # cdef np.ndarray[DTYPEf_t, ndim=2] exact, exact2
        # This is strange, if I use DTYPEf_t and complex.h => bug
        # (hence the plain ``double`` buffer type below).
        cdef np.ndarray[double, ndim=2] exact, exact2

        # datas: state data (overwritten at the very end);
        # datat: tendencies of the current stage;
        # datatemp: accumulator for the new state;
        # datatemp2: intermediate state passed to the next tendency call.
        cdef np.ndarray[DTYPEc_t, ndim=3] datas,datat
        cdef np.ndarray[DTYPEc_t, ndim=3] datatemp,datatemp2

        sim = self.sim

        state_fft = sim.state.state_fft

        nk = state_fft.nb_variables
        n0 = sim.oper.nK0_loc
        n1 = sim.oper.nK1_loc

        # Full-step (exact) and half-step (exact2) coefficients; presumably
        # exp(-sigma*dt) and exp(-sigma*dt/2), see self.exact_coefs.
        exact, exact2 =self.exact_coefs.get_updated_coefs()

        # Stage 1: tendencies of the current state.
        tendencies_fft_1 = sim.tendencies_non_diff()

        ## alternatively, this
        # state_fft_temp = (self.state_fft + dt/6*tendencies_fft_1)*exact
        # state_fft_np12_approx1 = (self.state_fft + dt/2*tendencies_fft_1)*exact2
        ## or this (slightly faster...)

        datas = state_fft.data
        datat = tendencies_fft_1.data

        state_fft_temp = SetOfVariables(otherEV=state_fft)
        datatemp = state_fft_temp.data

        state_fft_np12_approx1 = SetOfVariables(otherEV=state_fft)
        datatemp2 = state_fft_np12_approx1.data

        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datatemp[ik, i0, i1] = (
                        datas[ik, i0, i1] + dt/6*datat[ik, i0, i1]
                        )*exact[i0, i1]
                    datatemp2[ik, i0, i1] = (
                        datas[ik, i0, i1] + dt/2*datat[ik, i0, i1]
                        )*exact2[i0, i1]

        # Stage 2: tendencies at the first half-step approximation.
        del(tendencies_fft_1)
        tendencies_fft_2 = sim.tendencies_non_diff(state_fft_np12_approx1)
        del(state_fft_np12_approx1)

        ## alternatively, this
        # state_fft_temp += dt/3*exact2*tendencies_fft_2
        # state_fft_np12_approx2 = exact2*self.state_fft + dt/2*tendencies_fft_2
        ## or this (slightly faster...)

        datat = tendencies_fft_2.data

        state_fft_np12_approx2 = SetOfVariables(otherEV=state_fft)
        datatemp2 = state_fft_np12_approx2.data

        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datatemp[ik, i0, i1] = (
                        datatemp[ik, i0, i1]
                        + dt/3*exact2[i0, i1]*datat[ik, i0, i1]
                        )
                    datatemp2[ik, i0, i1] = (
                        exact2[i0, i1]*datas[ik, i0, i1]
                        +  dt/2*datat[ik, i0, i1]
                        )

        # Stage 3: tendencies at the second half-step approximation.
        del(tendencies_fft_2)
        tendencies_fft_3 = sim.tendencies_non_diff(state_fft_np12_approx2)
        del(state_fft_np12_approx2)

        ## alternatively, this
        # state_fft_temp += dt/3*exact2*tendencies_fft_3
        # state_fft_np1_approx = exact*self.state_fft + dt*exact2*tendencies_fft_3
        ## or this (slightly faster...)

        datat = tendencies_fft_3.data

        state_fft_np1_approx = SetOfVariables(otherEV=state_fft)
        datatemp2 = state_fft_np1_approx.data

        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datatemp[ik, i0, i1] = (
                        datatemp[ik, i0, i1]
                        + dt/3*exact2[i0, i1]*datat[ik, i0, i1]
                        )
                    datatemp2[ik, i0, i1] = (
                        exact[i0, i1]*datas[ik, i0, i1]
                        +  dt*exact2[i0, i1]*datat[ik, i0, i1]
                        )

        # Stage 4: tendencies at the full-step approximation.
        del(tendencies_fft_3)
        tendencies_fft_4 = sim.tendencies_non_diff(state_fft_np1_approx)
        del(state_fft_np1_approx)

        ## alternatively, this
        # self.state_fft = state_fft_temp + dt/6*tendencies_fft_4
        ## or this (slightly faster... may be not...)

        datat = tendencies_fft_4.data

        # Final update: overwrite the state data in place.
        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datas[ik, i0, i1] = (
                        datatemp[ik, i0, i1]
                        + dt/6*datat[ik, i0, i1]
                        )







class TimeSteppingExactLinCython(object):
    """Cython RK4 time stepping with a complex, per-variable linear term.

    The coefficients ``exact`` and ``exact2`` are complex 3D arrays indexed
    like the state (``[ik, i0, i1]``). The methods read ``self.sim``,
    ``self.deltat`` and ``self.exact_coefs``, none of which are set here;
    presumably they are provided by the class this one is combined with
    (TODO confirm against the pure Python time stepping).
    """

    def _time_step_RK2(self):
        """Not implemented for this class; use RK4 instead.

        Raises
        ------
        NotImplementedError
            Always. It subclasses ``Exception``, so callers that caught the
            generic ``Exception`` raised previously still catch it.
        """
        raise NotImplementedError('This method should be written... '
                                  'You can use RK4 instead.')

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def _time_step_RK4(self):
        r"""Advance in time *sim.state.state_fft* with the Runge-Kutta 4 method.

        See :ref:`the pure python RK4 function <rk4timescheme>` for the
        presentation of the time scheme.

        For this function, the coefficient :math:`\sigma` is complex.

        With ``E = exact`` and ``E2 = exact2`` (complex 3D arrays returned
        by ``self.exact_coefs.get_updated_coefs()``), ``S`` the state data
        and ``N`` the function ``sim.tendencies_non_diff``, the loops below
        compute (all products element-wise)::

            T1 = N()      (no argument; presumably the current state)
            A1 = (S + dt/2*T1)*E2
            T2 = N(A1)
            A2 = E2*S + dt/2*T2
            T3 = N(A2)
            A3 = E*S + dt*E2*T3
            T4 = N(A3)
            S <- (S + dt/6*T1)*E + dt/3*E2*T2 + dt/3*E2*T3 + dt/6*T4

        The state data array is overwritten in place.
        """
        cdef double dt = self.deltat
        cdef Py_ssize_t i0, i1, ik, nk, n0, n1
        cdef np.ndarray[DTYPEc_t, ndim=3] exact, exact2
        # datas: state data (overwritten at the very end);
        # datat: tendencies of the current stage;
        # datatemp: accumulator for the new state;
        # datatemp2: intermediate state passed to the next tendency call.
        cdef np.ndarray[DTYPEc_t, ndim=3] datas, datat
        cdef np.ndarray[DTYPEc_t, ndim=3] datatemp, datatemp2

        sim = self.sim

        state_fft = sim.state.state_fft

        nk = state_fft.nb_variables
        n0 = sim.oper.nK0_loc
        n1 = sim.oper.nK1_loc

        # Full-step (exact) and half-step (exact2) coefficients.
        exact, exact2 =self.exact_coefs.get_updated_coefs()

        # Stage 1: tendencies of the current state.
        tendencies_fft_1 = sim.tendencies_non_diff()

        ## alternatively, this
        # state_fft_temp = (self.state_fft + dt/6*tendencies_fft_1)*exact
        # state_fft_np12_approx1 = (self.state_fft + dt/2*tendencies_fft_1)*exact2
        ## or this (slightly faster...)

        datas = state_fft.data
        datat = tendencies_fft_1.data

        state_fft_temp = SetOfVariables(otherEV=state_fft)
        datatemp = state_fft_temp.data

        state_fft_np12_approx1 = SetOfVariables(otherEV=state_fft)
        datatemp2 = state_fft_np12_approx1.data

        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datatemp[ik, i0, i1] = (
                        datas[ik, i0, i1] + dt/6*datat[ik, i0, i1]
                        )*exact[ik, i0, i1]
                    datatemp2[ik, i0, i1] = (
                        datas[ik, i0, i1] + dt/2*datat[ik, i0, i1]
                        )*exact2[ik, i0, i1]

        # Stage 2: tendencies at the first half-step approximation.
        del(tendencies_fft_1)
        tendencies_fft_2 = sim.tendencies_non_diff(state_fft_np12_approx1)
        del(state_fft_np12_approx1)

        ## alternatively, this
        # state_fft_temp += dt/3*exact2*tendencies_fft_2
        # state_fft_np12_approx2 = exact2*self.state_fft + dt/2*tendencies_fft_2
        ## or this (slightly faster...)

        datat = tendencies_fft_2.data

        state_fft_np12_approx2 = SetOfVariables(otherEV=state_fft)
        datatemp2 = state_fft_np12_approx2.data

        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datatemp[ik, i0, i1] = (
                        datatemp[ik, i0, i1]
                        + dt/3*exact2[ik, i0, i1]*datat[ik, i0, i1]
                        )
                    datatemp2[ik, i0, i1] = (
                        exact2[ik, i0, i1]*datas[ik, i0, i1]
                        +  dt/2*datat[ik, i0, i1]
                        )

        # Stage 3: tendencies at the second half-step approximation.
        del(tendencies_fft_2)
        tendencies_fft_3 = sim.tendencies_non_diff(state_fft_np12_approx2)
        del(state_fft_np12_approx2)

        ## alternatively, this
        # state_fft_temp += dt/3*exact2*tendencies_fft_3
        # state_fft_np1_approx = exact*self.state_fft + dt*exact2*tendencies_fft_3
        ## or this (slightly faster...)

        datat = tendencies_fft_3.data

        state_fft_np1_approx = SetOfVariables(otherEV=state_fft)
        datatemp2 = state_fft_np1_approx.data

        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datatemp[ik, i0, i1] = (
                        datatemp[ik, i0, i1]
                        + dt/3*exact2[ik, i0, i1]*datat[ik, i0, i1]
                        )
                    datatemp2[ik, i0, i1] = (
                        exact[ik, i0, i1]*datas[ik, i0, i1]
                        +  dt*exact2[ik, i0, i1]*datat[ik, i0, i1]
                        )

        # Stage 4: tendencies at the full-step approximation.
        del(tendencies_fft_3)
        tendencies_fft_4 = sim.tendencies_non_diff(state_fft_np1_approx)
        del(state_fft_np1_approx)

        ## alternatively, this
        # self.state_fft = state_fft_temp + dt/6*tendencies_fft_4
        ## or this (slightly faster... may be not...)

        datat = tendencies_fft_4.data

        # Final update: overwrite the state data in place.
        for ik in xrange(nk):
            for i0 in xrange(n0):
                for i1 in xrange(n1):
                    datas[ik, i0, i1] = (
                        datatemp[ik, i0, i1]
                        + dt/6*datat[ik, i0, i1]
                        )

