"""Time stepping (:mod:`fluiddyn.simul.base.time_stepping`)
===========================================================

.. currentmodule:: fluiddyn.simul.base.time_stepping

Provides:

.. autoclass:: TimeSteppingBase
   :members:
   :private-members:

.. autoclass:: TimeSteppingPseudoSpectralPurePython
   :members:
   :private-members:

.. autoclass:: TimeSteppingPseudoSpectral
   :members:
   :private-members:

.. autoclass:: TimeSteppingExactLin
   :members:
   :private-members:

.. autoclass:: TimeSteppingFiniteDiffPurePython
   :members:
   :private-members:

    .. todo:: Clean up and remove all non-universal things.


"""

import numpy as np
from time import time

from fluiddyn.util import mpi

import fluiddyn.simul.base.time_stepping_cy as time_stepping_cy
from fluiddyn.simul.operators.setofvariables import SetOfVariables




class TimeSteppingBase(object):
    """Universal time stepping class used for all solvers.

    Subclasses have to provide ``init_freq_lin`` (called at the end of
    :meth:`__init__`), ``one_time_step_computation`` (called by
    :meth:`one_time_step`) and the Runge-Kutta methods
    ``_time_step_RK2`` / ``_time_step_RK4`` (selected in :meth:`__init__`
    from ``params.time_stepping.type_time_scheme``).

    """
    @staticmethod
    def _complete_params_with_default(params):
        """Complete the *params* container with the time-stepping defaults.

        A child ``time_stepping`` is added with the attributes:

        - ``USE_T_END`` (True): stop at ``t_end`` instead of ``it_end``,
        - ``t_end`` (10.): final time,
        - ``it_end`` (10): number of time steps when ``USE_T_END`` is False,
        - ``USE_CFL`` (True): adapt ``deltat`` with a CFL condition,
        - ``type_time_scheme`` ('RK4'): ``'RK2'`` or ``'RK4'``,
        - ``deltat0`` (0.2): initial time step (kept constant when
          ``USE_CFL`` is False).
        """
        attribs = {'USE_T_END': True,
                   't_end': 10.,
                   'it_end': 10,
                   'USE_CFL': True,
                   'type_time_scheme': 'RK4',
                   'deltat0': 0.2}
        params.set_child('time_stepping', attribs=attribs)

    def __init__(self, sim):
        params = sim.params
        params_ts = params.time_stepping

        self.params = params
        self.sim = sim

        self.it = 0
        self.t = 0

        # Upper bound of the time step. When the solver defines c2, kd2
        # and coef_dealiasing, it is derived from the frequency
        # sqrt(c2*(kd2 + (coef_dealiasing*kmax)**2)).
        try:
            # We should change this. This is not general enough.
            omega_max = np.sqrt(params.c2*(
                params.kd2
                + (params.coef_dealiasing*sim.oper.kmax)**2))
        except AttributeError:
            omega_max = 0.
        if omega_max > 0:
            self.deltat_max = 0.78*np.pi/omega_max
        else:
            self.deltat_max = 0.2

        # CFL coefficient associated with each supported scheme. The scheme
        # name is validated once, whatever the value of USE_CFL.
        cfl_coefs = {'RK2': 0.4, 'RK4': 1.0}
        type_scheme = params_ts.type_time_scheme
        if type_scheme not in cfl_coefs:
            raise ValueError('Problem name time_scheme')
        if params_ts.USE_CFL:
            self.CFL = cfl_coefs[type_scheme]
        # Bound method ``_time_step_RK2`` or ``_time_step_RK4``.
        self._time_step_RK = getattr(self, '_time_step_' + type_scheme)

        # Initial time step (later adapted if USE_CFL is True).
        self.deltat = params_ts.deltat0
        self.init_freq_lin()

    def start(self):
        """Loop to run the function :func:`one_time_step`.

        If ``params.time_stepping.USE_T_END`` is true, run till
        ``t >= t_end``, otherwise run till ``it >= it_end``. At the end,
        ``sim.output.end_of_simul`` receives the elapsed wall-clock time.
        """
        print_stdout = self.sim.output.print_stdout
        print_stdout(
            '*************************************\n' +
            'Beginning of the computation')
        if self.sim.output.SAVE:
            self.sim.output.phys_fields.save()
        time_beginning_simul = time()
        params_ts = self.params.time_stepping
        if params_ts.USE_T_END:
            print_stdout(
                '    compute until t = {0:10.6g}'.format(params_ts.t_end))
            while self.t < params_ts.t_end:
                self.one_time_step()
        else:
            print_stdout(
                '    compute until it = {0:8d}'.format(params_ts.it_end))
            while self.it < params_ts.it_end:
                self.one_time_step()
        total_time_simul = time() - time_beginning_simul
        self.sim.output.end_of_simul(total_time_simul)

    def one_time_step(self):
        """Perform one complete time step.

        Update ``deltat`` (if ``USE_CFL``), compute the forcing (if
        ``params.FORCING``), call ``sim.output.one_time_step()`` and finally
        advance the state with ``one_time_step_computation``.
        """
        if self.params.time_stepping.USE_CFL:
            self._compute_time_increment_CLF()
        if self.params.FORCING:
            self.sim.forcing.compute()
        self.sim.output.one_time_step()
        self.one_time_step_computation()

    def _compute_time_increment_CLF(self):
        """Compute the time increment deltat with a CFL condition.

        The candidate time step is ``CFL / temp`` with
        ``temp = max|ux|/deltax + max|uy|/deltay`` (maximum of ``temp``
        over the MPI processes), or ``deltat_max`` if ``temp`` is not
        positive; it is then bounded by ``deltat_max``. ``self.deltat`` is
        only changed when its relative difference with the candidate
        exceeds 2 %.
        """
        ux = self.sim.state('ux')
        uy = self.sim.state('uy')
        max_ux = abs(ux).max()
        max_uy = abs(uy).max()
        temp = (max_ux/self.sim.oper.deltax
                + max_uy/self.sim.oper.deltay)

        if mpi.nb_proc > 1:
            temp = mpi.comm.allreduce(temp, op=mpi.MPI.MAX)

        if temp > 0:
            deltat_CFL = self.CFL/temp
        else:
            deltat_CFL = self.deltat_max

        maybe_new_dt = min(deltat_CFL, self.deltat_max)
        normalize_diff = abs(self.deltat - maybe_new_dt)/maybe_new_dt

        # Small relative changes are ignored so that deltat stays constant
        # between significant changes of the velocity.
        if normalize_diff > 0.02:
            self.deltat = maybe_new_dt







class TimeSteppingPseudoSpectralPurePython(TimeSteppingBase):
    """Time stepping class for pseudo-spectral solvers.

    The state is advanced in spectral space (``sim.state.state_fft``).
    The linear terms, of frequency ``self.freq_lin``, are handled through
    exponential factors ``exp(-freq_lin*dt)`` while the other tendencies,
    given by ``sim.tendencies_non_diff``, are integrated with an RK2 or
    RK4 scheme.

    """
    def one_time_step_computation(self):
        """One time step.

        Advance ``state_fft`` with the selected Runge-Kutta scheme, dealias
        it, update the physical state and increment ``t`` and ``it``.

        Raises
        ------
        ValueError
            If the physical field ``ux`` contains NaN after the step.
        """
        self._time_step_RK()
        self.sim.oper.dealiasing(self.sim.state.state_fft)
        self.sim.state.statephys_from_statefft()
        self.t += self.deltat
        self.it += 1
        # np.min propagates NaN, so checking the minimum is enough.
        if np.isnan(np.min(self.sim.state('ux'))):
            raise ValueError(
                'nan at it = {0}, t = {1:.4f}'.format(self.it, self.t))

    def init_freq_lin(self):
        """Initialise the linear frequency and the exact coefficients.

        ``self.freq_lin`` is the total dissipation frequency returned by
        :meth:`compute_freq_diss`; ``self.exact_coefs`` is built with
        ``time_stepping_cy.ExactCoefsDiss``.
        """
        f_d, f_d_hypo = self.compute_freq_diss()
        self.freq_lin = f_d + f_d_hypo
        self.exact_coefs = time_stepping_cy.ExactCoefsDiss(self)

    def compute_freq_diss(self):
        """Compute the dissipation frequencies.

        Returns
        -------
        f_d : array
            ``nu_8*K8``, plus ``nu_4*K4`` and ``nu_2*K2`` when these
            coefficients are positive (operators taken from ``sim.oper``).
        f_d_hypo : array or float
            Hypo-dissipation ``nu_m4/K2_not0**2`` when ``nu_m4`` is
            positive, ``0.`` otherwise.
        """
        f_d = self.params.nu_8*self.sim.oper.K8
        if self.params.nu_4>0.:
            f_d += self.params.nu_4*self.sim.oper.K4
        if self.params.nu_2>0.:
            f_d += self.params.nu_2*self.sim.oper.K2
        if self.params.nu_m4>0.:
            f_d_hypo = self.params.nu_m4/self.sim.oper.K2_not0**2
            # mode K2 = 0 ! The entry [0, 0] is replaced by the mean of its
            # neighbours [0, 1] and [1, 0]; only done on the process of
            # rank 0 (presumably the one holding this mode - confirm).
            if mpi.rank==0:
                f_d_hypo[0, 0] = (f_d_hypo[0, 1]+f_d_hypo[1, 0])/2
        else:
            f_d_hypo = 0.
        return f_d, f_d_hypo




    def _time_step_RK2(self):
        r"""Advance in time with the Runge-Kutta 2 method.

        .. _rk2timescheme:

        Notes
        -----

        .. |p| mathmacro:: \partial

        We consider an equation of the form

        .. math:: \p_t S = -\sigma S + N(S),

        where :math:`\sigma` is the linear frequency ``self.freq_lin`` and
        :math:`N` is computed by ``sim.tendencies_non_diff``.

        The Runge-Kutta 2 method computes an approximation of the
        solution after a time increment :math:`dt`. We denote the
        initial time :math:`t = 0`.

        - Approximation 1:

          .. math:: \p_t \log S = -\sigma + \frac{N(S_0)}{S_0},

          Integrating from :math:`t` to :math:`t+dt/2`, it gives:

          .. |SA1halfdt| mathmacro:: S_{A1dt/2}

          .. math:: \SA1halfdt = (S_0 + N_0 dt/2) e^{-\frac{\sigma dt}{2}}.


        - Approximation 2:

          .. math::
             \p_t \log S = -\sigma
             + \frac{N(\SA1halfdt)}{ \SA1halfdt },

          Integrating from :math:`t` to :math:`t+dt` and retaining
          only the terms in :math:`dt^1` gives:

          .. math::
             S_{dtA2} = S_0 e^{-\sigma dt}
             + N(\SA1halfdt) dt e^{-\frac{\sigma dt}{2}}.

        The result is stored in ``sim.state.state_fft``.

        """
        dt = self.deltat
        sigma = self.freq_lin
        # exp(-sigma dt) and exp(-sigma dt/2)
        diss = np.exp(-dt*sigma)
        diss2 = np.exp(-dt/2*sigma)

        sim = self.sim
        tendencies_fft_n = sim.tendencies_non_diff()
        # approximation 1 at t + dt/2
        state_fft_n12 = (sim.state.state_fft + dt/2*tendencies_fft_n)*diss2
        tendencies_fft_n12 = sim.tendencies_non_diff(state_fft_n12)
        # approximation 2 at t + dt
        sim.state.state_fft = (sim.state.state_fft*diss
                               + dt*tendencies_fft_n12*diss2)


    def _time_step_RK4(self):
        r"""Advance in time with the Runge-Kutta 4 method.

        .. _rk4timescheme:

        We consider an equation of the form

        .. math:: \p_t S = -\sigma S + N(S),

        where :math:`\sigma` is the linear frequency ``self.freq_lin`` and
        :math:`N` is computed by ``sim.tendencies_non_diff``.

        The Runge-Kutta 4 method computes an approximation of the
        solution after a time increment :math:`dt`. We denote the
        initial time as :math:`t = 0`. This time scheme uses 4
        approximations. Only the terms in :math:`dt^1` are retained.

        - Approximation 1:

          .. math:: \p_t \log S = -\sigma + \frac{N(S_0)}{S_0},

          Integrating from :math:`t` to :math:`t+dt/2` gives:

          .. math:: \SA1halfdt = (S_0 + N_0 dt/2) e^{-\sigma \frac{dt}{2}}.

          Integrating from :math:`t` to :math:`t+dt` gives:

          .. math:: S_{A1dt} = (S_0 + N_0 dt) e^{-\sigma dt}.


        - Approximation 2:

          .. math::
             \p_t \log S = -\sigma
             + \frac{N(\SA1halfdt)}{ \SA1halfdt },

          Integrating from :math:`t` to :math:`t+dt/2` gives:

          .. |SA2halfdt| mathmacro:: S_{A2 dt/2}

          .. math::
             \SA2halfdt = S_0 e^{-\sigma \frac{dt}{2}}
             + N(\SA1halfdt) \frac{dt}{2}.

          Integrating from :math:`t` to :math:`t+dt` gives:

          .. math::
             S_{A2dt} = S_0 e^{-\sigma dt}
             + N(\SA1halfdt) e^{-\sigma \frac{dt}{2}} dt.


        - Approximation 3:

          .. math::
             \p_t \log S = -\sigma
             + \frac{N(\SA2halfdt)}{ \SA2halfdt },

          Integrating from :math:`t` to :math:`t+dt` gives:

          .. math::
             S_{A3dt} = S_0 e^{-\sigma dt}
             + N(\SA2halfdt) e^{-\sigma \frac{dt}{2}} dt.

        - Approximation 4:

          .. math::
             \p_t \log S = -\sigma
             + \frac{N(S_{A3dt})}{ S_{A3dt} },

          Integrating from :math:`t` to :math:`t+dt` gives:

          .. math::
             S_{A4dt} = S_0 e^{-\sigma dt} + N(S_{A3dt}) dt.


        The final result is a weighted average of the results of the 4
        approximations for the time :math:`t+dt`:

          .. math::
             \frac{1}{3} \left[
             \frac{1}{2} S_{A1dt}
             + S_{A2dt} + S_{A3dt}
             + \frac{1}{2} S_{A4dt}
             \right],

        which is equal to:

          .. math::
             S_0 e^{-\sigma dt}
             + \frac{dt}{3} \left[
             \frac{1}{2} N(S_0) e^{-\sigma dt}
             + N(\SA1halfdt) e^{-\sigma \frac{dt}{2}}
             + N(\SA2halfdt) e^{-\sigma \frac{dt}{2}}
             + \frac{1}{2} N(S_{A3dt})\right].

        The result is stored in ``sim.state.state_fft``.

        """

        dt = self.deltat

        f_d = self.freq_lin
        # exp(-sigma dt/2) and exp(-sigma dt)
        diss2 = np.exp(-dt/2*f_d)
        diss = np.exp(-dt*f_d)

        sim = self.sim

        tendencies_fft_0 = sim.tendencies_non_diff()

        # based on approximation 1
        # state_fft_temp accumulates the weighted average term by term.
        state_fft_temp = (sim.state.state_fft
                          + dt/6*tendencies_fft_0
                          )*diss
        state_fft_np12_approx1 = (sim.state.state_fft
                                  + dt/2*tendencies_fft_0
                                  )*diss2

        # Intermediate arrays are released as soon as they are no longer
        # needed (presumably to limit the peak memory usage).
        del(tendencies_fft_0)
        tendencies_fft_1 = sim.tendencies_non_diff(state_fft_np12_approx1)
        del(state_fft_np12_approx1)

        # based on approximation 2
        state_fft_temp += dt/3*diss2*tendencies_fft_1
        state_fft_np12_approx2 = (sim.state.state_fft*diss2
                                  + dt/2*tendencies_fft_1)

        del(tendencies_fft_1)
        tendencies_fft_2 = sim.tendencies_non_diff(state_fft_np12_approx2)
        del(state_fft_np12_approx2)

        # based on approximation 3
        state_fft_temp += dt/3*diss2*tendencies_fft_2
        state_fft_np1_approx = (sim.state.state_fft*diss
                                + dt*diss2*tendencies_fft_2)

        del(tendencies_fft_2)
        tendencies_fft_3 = sim.tendencies_non_diff(state_fft_np1_approx)
        del(state_fft_np1_approx)

        # result using the 4 approximations
        sim.state.state_fft = state_fft_temp + dt/6*tendencies_fft_3






class TimeSteppingPseudoSpectral(
        time_stepping_cy.TimeSteppingPseudoSpectralCython,
        TimeSteppingPseudoSpectralPurePython):
    """Pseudo-spectral time stepping accelerated with Cython.

    Methods defined in
    ``time_stepping_cy.TimeSteppingPseudoSpectralCython`` come first in the
    method resolution order; everything else is inherited from
    :class:`TimeSteppingPseudoSpectralPurePython`.
    """






class TimeSteppingExactLin(time_stepping_cy.TimeSteppingExactLinCython,
                           TimeSteppingPseudoSpectralPurePython):
    """Pseudo-spectral time stepping (Cython) with exact linear terms.

    For every variable of the state, the linear frequency ``freq_lin`` is
    the dissipation frequency plus ``sim.freq_lin_exact(key)``, stored in a
    :class:`SetOfVariables`. The upper bound of the time step is
    ``1.2*pi/omega_max`` (or 0.2 when ``omega_max`` cannot be computed).
    """

    def __init__(self, sim):
        super(TimeSteppingExactLin, self).__init__(sim)
        prms = sim.params
        try:
            # frequency sqrt(c2*(kd2 + (coef_dealiasing*kmax)**2)); the
            # attributes only exist for some solvers
            omega_max = np.sqrt(prms.c2*(
                prms.kd2 + (prms.coef_dealiasing*sim.oper.kmax)**2))
        except AttributeError:
            omega_max = 0.
        self.deltat_max = 1.2*np.pi/omega_max if omega_max > 0 else 0.2

    def init_freq_lin(self):
        """Initialise the linear frequency array.

        Sets ``self.freq_lin`` (one entry per key of ``sim.state.state_fft``)
        and ``self.exact_coefs`` (a ``time_stepping_cy.ExactCoefsExactLin``).
        """
        f_d, f_d_hypo = self.compute_freq_diss()
        freq_diss = f_d + f_d_hypo

        state_fft = self.sim.state.state_fft
        lin_freqs = SetOfVariables(
            otherEV=state_fft,
            name_type_variables='freq_lin_fft')
        for name in state_fft.keys:
            lin_freqs[name] = freq_diss + self.sim.freq_lin_exact(name)

        self.freq_lin = lin_freqs
        self.exact_coefs = time_stepping_cy.ExactCoefsExactLin(self)












class TimeSteppingFiniteDiffPurePython(TimeSteppingBase):
    """
    Time stepping class for finite-difference solvers.

    Work in progress: :meth:`invert_to_get_solution` is not implemented
    yet and only the RK2 scheme is written.

    NOTE(review): ``TimeSteppingBase.__init__`` calls ``init_freq_lin`` and
    selects ``_time_step_RK4`` for ``type_time_scheme == 'RK4'``; neither
    is defined in this class.

    """
    def one_time_step_computation(self):
        """One time step.

        Advance the state with the selected Runge-Kutta scheme and
        increment ``t`` and ``it``.

        Raises
        ------
        ValueError
            If the physical field ``ux`` contains NaN after the step.
        """
        self._time_step_RK()
        self.t += self.deltat
        self.it += 1
        if np.isnan(np.min(self.sim.state('ux'))):
            raise ValueError(
                'nan at it = {0}, t = {1:.4f}'.format(self.it, self.t)
                )

    def _time_step_RK2(self):
        r"""Advance in time the variables with the Runge-Kutta 2 method.

        .. _rk2timeschemeFiniteDiff:

        Notes
        -----

        .. Look at Simson KTH documentation...
           (http://www.mech.kth.se/~mattias/simson-user-guide-v4.0.pdf)

        The Runge-Kutta 2 method computes an approximation of the
        solution after a time increment :math:`dt`. We denote the
        initial time :math:`t = 0`.

        For the finite difference schemes, we consider an equation of the
        form

        .. math:: \p_t S = L S + N(S),

        The linear term can be treated with an implicit method while
        the nonlinear term has to be treated with an explicit method
        (see for example `Explicit and implicit methods
        <http://en.wikipedia.org/wiki/Explicit_and_implicit_methods>`_).

        - Approximation 1:

          For the first step where the nonlinear term is approximated
          as :math:`N(S) \simeq N(S_0)`, we obtain

          .. math::
             \left( 1 - \frac{dt}{4} L \right) S_{A1dt/2}
             \simeq \left( 1 + \frac{dt}{4} L \right) S_0 + N(S_0)dt/2

          Once the right-hand side has been computed, a linear
          equation has to be solved. It is not efficient to invert the
          matrix :math:`1 - \frac{dt}{4} L` so other methods have to
          be used, as the `Thomas algorithm
          <http://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm>`_,
          or algorithms based on the LU or the QR decompositions.

        - Approximation 2:

          The nonlinear term is then approximated as :math:`N(S)
          \simeq N(S_{A1dt/2})`, which gives

          .. math::
             \left( 1 - \frac{dt}{2} L \right) S_{A2dt}
             \simeq \left( 1 + \frac{dt}{2} L \right) S_0 + N(S_{A1dt/2})dt

        The new state :math:`S_{A2dt}` is stored in
        ``sim.state.state_phys`` and also returned.

        """
        dt = self.deltat
        sim = self.sim

        # approximation 1: half step with N(S) ~ N(S_0)
        tendenciesNL_0 = sim.tendencies_non_linear()
        state_0 = sim.state.state_phys
        rhs_A1dt2 = self.right_hand_side(state_0, tendenciesNL_0, dt/2)
        S_A1dt2 = self.invert_to_get_solution(rhs_A1dt2)

        # approximation 2: full step starting from S_0 (not from S_A1dt2)
        # with N(S) ~ N(S_A1dt2)
        tendenciesNL_1 = sim.tendencies_non_linear(S_A1dt2)
        rhs_A2dt = self.right_hand_side(state_0, tendenciesNL_1, dt)
        S_A2dt = self.invert_to_get_solution(rhs_A2dt)

        # one_time_step_computation ignores the return value, so the new
        # state has to be stored here.
        sim.state.state_phys = S_A2dt
        return S_A2dt

    def right_hand_side(self, S, N, dt):
        """Return the explicit right-hand side ``S + dt/2*L(S) + dt*N``.

        ``L(S)`` is given by ``sim.tendencies_linear(S)`` and *N* is the
        nonlinear tendency (see :meth:`_time_step_RK2`).
        """
        return S + dt/2*self.sim.tendencies_linear(S) + dt*N

    def invert_to_get_solution(self, rhs):
        """Solve the implicit linear system for *rhs* (not implemented).

        Raises
        ------
        ValueError
            Always, since the linear solve is not yet implemented.
        """
        raise ValueError('NOT YET IMPLEMENTED.')
