""" """


import numpy as np

from fluiddyn.simul.base.output import OutputBase







class OutputBaseSW1l(OutputBase):

    @staticmethod
    def _complete_info_solver(info_solver):
        """Register the SW1l output classes in the ContainerXML *info_solver*.

        A ``classes`` child is created under ``info_solver.classes.Output``;
        then one child per output tag is added to it, each carrying the
        ``module_name`` and ``class_name`` attributes listed in the table
        below (registration order is the table order).

        This is a static method!
        """
        output_info = info_solver.classes.Output
        output_info.set_child('classes')
        classes = output_info.classes

        base_name_mod = 'fluiddyn.simul.solvers.sw1l.output'

        # (tag, module_name, class_name), registered in this order.
        # NOTE(review): SpatialMeans, SpectralEnergyBudget and Increments
        # are looked up in the generic ``fluiddyn.simul.base.output``
        # package although their class names are SW1l-specific -- confirm
        # these classes really live there.
        children = (
            ('PrintStdOut',
             base_name_mod+'.print_stdout', 'PrintStdOutSW1l'),
            ('PhysFields',
             'fluiddyn.simul.base.output.phys_fields', 'PhysFieldsBase'),
            ('Spectra',
             base_name_mod+'.spectra', 'SpectraSW1l'),
            ('SpatialMeans',
             'fluiddyn.simul.base.output.spatial_means', 'SpatialMeansSW1l'),
            ('SpectralEnergyBudget',
             'fluiddyn.simul.base.output.spect_energy_budget',
             'SpectralEnergyBudgetSW1l'),
            ('Increments',
             'fluiddyn.simul.base.output.increments', 'IncrementsSW1l'),
            ('ProbaDensityFunc',
             'fluiddyn.simul.base.output.prob_dens_func', 'ProbaDensityFunc'),
            ('TimeSignalsK',
             'fluiddyn.simul.base.output.time_signalsK', 'TimeSignalsK'),
        )

        for tag, module_name, class_name in children:
            classes.set_child(
                tag,
                attribs={'module_name': module_name,
                         'class_name': class_name})


    @staticmethod
    def _complete_params_with_default(params):
        """Complete the *params* container with the SW1l output parameters.

        The generic output parameters are first added by
        ``OutputBase._complete_params_with_default``; this method then sets
        the SW1l saving periods, the default field to plot, and one child
        per output with its default options.

        This is a static method!
        """
        OutputBase._complete_params_with_default(params)

        output = params.output

        # NOTE(review): 'time_signals_fft' uses False (== 0) where the other
        # periods are numbers; presumably it means "do not save" -- confirm.
        periods = {
            'spectra': 0,
            'spatial_means': 0.5,
            'spect_energy_budg': 0.5,
            'increments': 0.5,
            'pdf': 0.5,
            'time_signals_fft': False,
        }
        output.periods_save.set_attribs(periods)

        output.phys_fields.field_to_plot = 'rot'

        # Outputs whose only option is the plotting flag (off by default).
        for tag in ('spectra', 'spatial_means', 'spect_energy_budg',
                    'increments', 'pdf'):
            output.set_child(tag, attribs={'has_to_plot': False})

        output.set_child(
            'time_signals_fft',
            attribs={'nb_shells_time_sigK': 4,
                     'nb_k_per_shell_time_sigK': 4})









    def linear_eigenmode_from_values_1k(self, ux_fft, uy_fft, eta_fft,
                                        kx, ky):
        """Compute the linear-mode variables for a single wavevector.

        Parameters
        ----------
        ux_fft, uy_fft, eta_fft
            Fourier coefficients of the velocity components and of eta at
            the wavevector ``(kx, ky)``.
        kx, ky
            Components of the wavevector.

        Returns
        -------
        q_fft, div_fft, ageo_fft
            ``q = rot - f*eta``, ``div = i(kx*ux + ky*uy)`` and
            ``ageo = f*rot/c2 + k**2*eta``, with ``rot = i(kx*uy - ky*ux)``.
        """
        params = self.sim.params
        f = params.f

        div_fft = 1j*(kx*ux_fft + ky*uy_fft)
        rot_fft = 1j*(kx*uy_fft - ky*ux_fft)
        k2 = kx**2 + ky**2

        q_fft = rot_fft - f*eta_fft
        ageo_fft = f*rot_fft/params.c2 + k2*eta_fft
        return q_fft, div_fft, ageo_fft


    def omega_from_wavenumber(self, k):
        """Return the linear wave frequency ``sqrt(f**2 + c2*k**2)``.

        *k* may be a scalar or an array of wavenumber moduli (the operation
        is elementwise through NumPy).
        """
        params = self.sim.params
        omega2 = params.f**2 + params.c2*k**2
        return np.sqrt(omega2)





    def compute_enstrophy_fft(self):
        """Return the spectral enstrophy density ``|rot_fft|**2 / 2``."""
        enstrophy_fft = np.abs(self.sim.state('rot_fft'))**2/2
        return enstrophy_fft

    def compute_PV_fft(self):
        """Compute Ertel and Charney (QG) potential vorticity in spectral space.

        Returns
        -------
        ErtelPV_fft
            ``fft2((f + rot) / (1 + eta))``, evaluated in physical space and
            then transformed.
        CharneyPV_fft
            ``rot_fft - f*eta_fft``, evaluated directly in spectral space.
        """
        state = self.sim.state
        f = self.sim.params.f

        # Ertel PV involves a division by (1 + eta), so it is formed from
        # the physical fields before being transformed.
        rot = state('rot')
        eta = state.state_phys['eta']
        ErtelPV_fft = self.fft2((f + rot)/(1. + eta))

        # Charney PV is linear in the fields and built from their spectra.
        CharneyPV_fft = state('rot_fft') - f*state('eta_fft')

        return ErtelPV_fft, CharneyPV_fft

    def compute_PE_fft(self):
        """Return the Ertel and Charney potential enstrophy spectra.

        Both are ``|PV_fft|**2 / 2`` of the corresponding potential
        vorticity returned by :meth:`compute_PV_fft`, in the same order
        ``(Ertel, Charney)``.
        """
        ErtelPV_fft, CharneyPV_fft = self.compute_PV_fft()
        ErtelPE_fft = np.abs(ErtelPV_fft)**2/2
        CharneyPE_fft = np.abs(CharneyPV_fft)**2/2
        return ErtelPE_fft, CharneyPE_fft

    def compute_CharneyPE_fft(self):
        """Return the Charney (QG) potential enstrophy spectrum.

        Computed as ``|q_fft|**2 / 2`` with ``q_fft = rot_fft - f*eta_fft``,
        without forming the (more expensive) Ertel potential vorticity.
        """
        state = self.sim.state
        CharneyPV_fft = state('rot_fft') - self.sim.params.f*state('eta_fft')
        return abs(CharneyPV_fft)**2/2


    def compute_energies(self):
        """Sum the three spectra of ``compute_energies_fft`` over wavenumbers.

        Returns a tuple ``(energyK, energyA, energyKr)``, each obtained by
        applying ``self.sum_wavenumbers`` to the matching spectrum.
        ``compute_energies_fft`` and ``sum_wavenumbers`` are not defined in
        this class (presumably provided by a base class or subclass).
        """
        spectra = self.compute_energies_fft()
        energyK_fft, energyA_fft, energyKr_fft = spectra
        sum_k = self.sum_wavenumbers
        return (sum_k(energyK_fft),
                sum_k(energyA_fft),
                sum_k(energyKr_fft))

    def compute_energiesKA(self):
        """Sum the two spectra of ``compute_energiesKA_fft`` over wavenumbers.

        Returns a tuple ``(energyK, energyA)`` obtained with
        ``self.sum_wavenumbers``.
        """
        spectra = self.compute_energiesKA_fft()
        energyK_fft, energyA_fft = spectra
        sum_k = self.sum_wavenumbers
        return (sum_k(energyK_fft),
                sum_k(energyA_fft))

    def compute_energy(self):
        """Return the total energy, i.e. the K part plus the A part.

        Both parts come from ``compute_energiesKA_fft`` and are summed over
        wavenumbers with ``self.sum_wavenumbers`` before being added.
        """
        energyK_fft, energyA_fft = self.compute_energiesKA_fft()
        energyK = self.sum_wavenumbers(energyK_fft)
        energyA = self.sum_wavenumbers(energyA_fft)
        return energyK + energyA

    def compute_enstrophy(self):
        """Return the enstrophy summed over all wavenumbers."""
        return self.sum_wavenumbers(self.compute_enstrophy_fft())









    def compute_lin_energies_fft(self):
        """Compute the spectral energies of the linear modes.

        The state ``(ux_fft, uy_fft, eta_fft)`` is split by the operator into
        the linear variables ``(q_fft, div_fft, ageo_fft)``. Each variable is
        mapped back to primitive fields, and its quadratic energy density
        ``0.5*(|ux|**2 + |uy|**2 + c2*|eta|**2)`` is computed. The divergent
        part has no eta field (``vecfft_from_divfft`` returns only a
        velocity), so only its kinetic term is kept.

        Returns
        -------
        energy_glin_fft, energy_dlin_fft, energy_alin_fft
            Energy spectra of the parts reconstructed from ``q_fft``,
            ``div_fft`` and ``ageo_fft`` respectively (in that order).
        """
        oper = self.oper
        c2 = self.sim.params.c2

        ux_fft = self.sim.state('ux_fft')
        uy_fft = self.sim.state('uy_fft')
        eta_fft = self.sim.state('eta_fft')

        q_fft, div_fft, ageo_fft = \
            oper.qdafft_from_uxuyetafft(ux_fft, uy_fft, eta_fft)

        # Divergent part: velocity only.
        udx_fft, udy_fft = oper.vecfft_from_divfft(div_fft)
        energy_dlin_fft = 0.5*(np.abs(udx_fft)**2
                               + np.abs(udy_fft)**2)

        # Part reconstructed from q.
        ugx_fft, ugy_fft, etag_fft = oper.uxuyetafft_from_qfft(q_fft)
        energy_glin_fft = 0.5*(np.abs(ugx_fft)**2
                               + np.abs(ugy_fft)**2
                               + c2*np.abs(etag_fft)**2)

        # Part reconstructed from ageo.
        uax_fft, uay_fft, etaa_fft = oper.uxuyetafft_from_afft(ageo_fft)
        energy_alin_fft = 0.5*(np.abs(uax_fft)**2
                               + np.abs(uay_fft)**2
                               + c2*np.abs(etaa_fft)**2)

        return energy_glin_fft, energy_dlin_fft, energy_alin_fft
