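"""NS2D solver (:mod:`fluiddyn.simul.solvers.ns2d.solver`)
=========================================================

This module defines :class:`Simul`, a pseudo-spectral solver of the 2D
incompressible Navier-Stokes equations written for the vorticity, and the
``info_solver`` object that specifies which classes (state, initial
fields, output, forcing) the solver uses.

"""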

from fluiddyn.simul.base.solver import SimulBase
from fluiddyn.simul.operators.setofvariables import SetOfVariables
from fluiddyn.simul.base.info_solver_params import InfoSolverBase


info_solver = InfoSolverBase(tag='solver')

# Package holding all the NS2D modules.
ns2d = 'fluiddyn.simul.solvers.ns2d'
info_solver.module_name = ns2d + '.solver'
info_solver.class_name = 'Simul'
info_solver.short_name = 'NS2D'

classes = info_solver.classes

# Each entry is (tag of the class in info_solver.classes,
# module inside the ns2d package, name of the class in that module).
# The entries are processed in the same order as the original assignments.
for _tag, _module, _class_name in (
        ('State', 'state', 'StateNS2D'),
        ('InitFields', 'init_fields', 'InitFieldsNS2D'),
        ('Output', 'output', 'Output'),
        ('Forcing', 'forcing', 'ForcingNS2D')):
    _component = getattr(classes, _tag)
    _component.module_name = ns2d + '.' + _module
    _component.class_name = _class_name
# Do not leak the loop temporaries into the module namespace.
del _tag, _module, _class_name, _component

# Must run last: it relies on the class descriptions set above.
info_solver.complete_with_classes()




class Simul(SimulBase):
    """Pseudo-spectral solver of the 2D incompressible Navier-Stokes equations.

    The spectral state holds the vorticity ``rot_fft``. The non-diffusive
    tendency contains the advection of vorticity and a ``-beta*uy`` term
    controlled by ``params.beta``. When ``params.FORCING`` is true, it
    also contains the forcing.

    """

    @staticmethod
    def _complete_params_with_default(params):
        """Complete the *params* container with the solver defaults.

        The defaults of :class:`SimulBase` are added first, followed by the
        NS2D-specific attribute ``beta`` (default ``0.``).
        """
        SimulBase._complete_params_with_default(params)
        params.set_attribs({'beta': 0.})

    def __init__(self, params):
        """Initialize the solver from *params* with the NS2D ``info_solver``."""
        super(Simul, self).__init__(params, info_solver)

    def tendencies_non_diff(self, state_fft=None):
        """Return the non-diffusive tendencies in spectral space.

        Parameters
        ----------
        state_fft : SetOfVariables, optional
            Spectral state containing ``'rot_fft'``. If None, the current
            solver state is used: the vorticity comes from
            ``self.state.state_fft`` and the velocity from
            ``self.state.state_phys``. Otherwise the velocity is recomputed
            from ``state_fft['rot_fft']``.

        Returns
        -------
        SetOfVariables
            Built from ``self.state.state_fft``. Its ``'rot_fft'`` entry is
            the dealiased transform of
            ``-ux*d_x(rot) - uy*(d_y(rot) + beta)``.
            ``self.forcing.forcing_fft`` is added when ``params.FORCING``
            is true.
        """
        oper = self.oper

        # Vorticity (spectral) and velocity (physical) of the state used.
        if state_fft is not None:
            rot_fft = state_fft['rot_fft']
            ux_fft, uy_fft = oper.vecfft_from_rotfft(rot_fft)
            ux = oper.ifft2(ux_fft)
            uy = oper.ifft2(uy_fft)
        else:
            state = self.state
            rot_fft = state.state_fft['rot_fft']
            ux = state.state_phys['ux']
            uy = state.state_phys['uy']

        # Vorticity gradient brought back to physical space.
        gradx_rot_fft, grady_rot_fft = oper.gradfft_from_fft(rot_fft)
        gradx_rot = oper.ifft2(gradx_rot_fft)
        grady_rot = oper.ifft2(grady_rot_fft)

        # Advection of vorticity. Adding beta to d_y(rot) produces the
        # extra -beta*uy term.
        Frot = -ux*gradx_rot - uy*(grady_rot + self.params.beta)
        Frot_fft = oper.fft2(Frot)
        # Return value unused: presumably dealiases Frot_fft in place.
        oper.dealiasing(Frot_fft)

        tendencies_fft = SetOfVariables(
            otherEV=self.state.state_fft,
            name_type_variables='tendencies_non_diff')
        tendencies_fft['rot_fft'] = Frot_fft

        if self.params.FORCING:
            tendencies_fft += self.forcing.forcing_fft

        return tendencies_fft













if __name__ == "__main__":

    import numpy as np

    import fluiddyn as fld

    from fluiddyn.simul.base.info_solver_params import create_params

    params = create_params(info_solver)
    params.short_name_type_run = 'test'

    # Square domain of side 2*pi with nh x nh grid points
    # (chained assignment sets nx before ny and Lx before Ly).
    nh = 32
    Lh = 2*np.pi
    oper_params = params.oper
    oper_params.nx = oper_params.ny = nh
    oper_params.Lx = oper_params.Ly = Lh

    # params.oper.type_fft = 'FFTWPY'

    # nu_8 is scaled with the grid spacing and the forcing rate.
    # NOTE(review): 2.*10e-1 == 2.0; if 2e-1 was intended, fix the factor.
    delta_x = oper_params.Lx/oper_params.nx
    params.nu_8 = 2.*10e-1*params.forcing.forcing_rate**(1./3)*delta_x**8

    params.time_stepping.t_end = 2.
    params.init_fields.type_flow_init = 'NOISE'

    out_params = params.output
    out_params.periods_print.print_stdout = 0.25

    for name_output in ('phys_fields', 'spectra',
                        'spect_energy_budg', 'increments'):
        setattr(out_params.periods_save, name_output, 0.5)

    out_params.periods_plot.phys_fields = 0.
    out_params.phys_fields.field_to_plot = 'rot'

    # Set has_to_plot to 0 (False) to disable the plots of an output.
    for name_output in ('spectra', 'spatial_means',
                        'spect_energy_budg', 'increments'):
        getattr(out_params, name_output).has_to_plot = 1

    sim = Simul(params)

    sim.time_stepping.start()
    sim.output.phys_fields.plot()

    fld.show()
