

import numpy as np

from fluiddyn.util import mpi

from fluiddyn.simul.base.forcing import ForcingBase


class ForcingNS2D(ForcingBase):
    """Forcing for the 2D Navier-Stokes solver, acting on ``rot_fft``.

    Only the ``'rot_fft'`` component of the state is forced (presumably the
    vorticity in spectral space -- confirm). The coarse forcing
    ``self.forcingc_fft['rot_fft']`` is computed on process 0 only. It is then
    transferred to the full forcing ``self.forcing_fft`` by
    ``self.put_forcingc_in_forcing()``, which is provided by ``ForcingBase``.
    """

    def _apply_coarse_rot_forcing(self, compute_on_root):
        """Compute the coarse forcing of ``rot_fft`` and put it in the forcing.

        Parameters
        ----------
        compute_on_root : callable
            Called on process 0 only, as ``compute_on_root(rot_fft)``, where
            ``rot_fft`` is the coarse array returned by
            ``self.oper.coarse_seq_from_fft_loc``. Its return value is stored
            in ``self.forcingc_fft['rot_fft']``.
        """
        rot_fft = self.sim.state.state_fft['rot_fft']
        # Executed on every process, not only on process 0: presumably a
        # collective (MPI) operation -- TODO confirm before moving it under
        # the rank test.
        rot_fft = self.oper.coarse_seq_from_fft_loc(rot_fft,
                                                    self.shapeK_loc_coarse)
        if mpi.rank == 0:
            self.forcingc_fft['rot_fft'] = compute_on_root(rot_fft)

        # Also executed on every process, like the call above.
        self.put_forcingc_in_forcing()

    def compute_forcing_proportional(self):
        """Compute a forcing proportional to the flow.

        The coarse forcing is ``normalize_forcingc_proportional(rot_fft)``.
        The injection rate is then printed by ``verify_injection_rate``.
        """
        self._apply_coarse_rot_forcing(
            lambda rot_fft: self.normalize_forcingc_proportional(rot_fft))

        # verification
        self.verify_injection_rate()

    def compute_forcing_2nd_degree_eq(self):
        """Compute a forcing normalized by solving a second-degree equation.

        A raw forcing from ``forcingc_raw_each_time()`` is normalized with
        ``normalize_forcingc_2nd_degree_eq(raw_forcing, rot_fft)``.
        """
        def compute_on_root(rot_fft):
            Frot_fft = self.forcingc_raw_each_time()
            return self.normalize_forcingc_2nd_degree_eq(Frot_fft, rot_fft)

        self._apply_coarse_rot_forcing(compute_on_root)

        # NOTE(review): unlike the other compute_forcing_* methods, the
        # injection-rate verification is deliberately not called here;
        # confirm whether this is still intended.

    def compute_forcing_particular_k(self):
        """Compute a forcing decorrelated from the flow.

        A raw forcing from ``forcingc_raw_each_time()`` is normalized with
        ``normalize_forcingc_part_k(raw_forcing, rot_fft)``. The injection
        rate is then printed by ``verify_injection_rate``.
        """
        def compute_on_root(rot_fft):
            Frot_fft = self.forcingc_raw_each_time()
            return self.normalize_forcingc_part_k(Frot_fft, rot_fft)

        self._apply_coarse_rot_forcing(compute_on_root)

        # verification
        self.verify_injection_rate()

    def verify_injection_rate(self):
        """Print the injection rate of the forcing ``forcing_fft['rot_fft']``.

        Two contributions are summed over wavenumbers with
        ``self.oper.sum_wavenumbers``:

        - ``PZ_forcing1 = |F|**2 / 2 * deltat``, independent of the flow;
        - ``PZ_forcing2 = Re(conj(F) * rot)``, the flow-forcing correlation.

        Their sum equals ``(|rot + deltat*F|**2 - |rot|**2) / (2*deltat)``.
        It is therefore the injection rate over one step if the forcing is
        applied as an explicit increment ``deltat * F`` (assumption -- TODO
        confirm against the time stepping). Process 0 prints the total and
        the correlation part.
        """
        Frot_fft = self.forcing_fft['rot_fft']
        rot_fft = self.sim.state.state_fft['rot_fft']
        deltat = self.sim.time_stepping.deltat

        PZ_forcing1 = abs(Frot_fft)**2/2*deltat
        # Re(conj(F)*rot) == (conj(F)*rot + F*conj(rot)) / 2, so a single
        # complex product is enough.
        PZ_forcing2 = np.real(Frot_fft.conj()*rot_fft)
        PZ_forcing1 = self.oper.sum_wavenumbers(PZ_forcing1)
        PZ_forcing2 = self.oper.sum_wavenumbers(PZ_forcing2)
        if mpi.rank == 0:
            # Single-argument print(...) behaves identically in Python 2
            # and 3, unlike the Python-2-only print statement.
            print('PZ_f = {0:9.4e} ; PZ_f2 = {1:9.4e};'.format(
                PZ_forcing1+PZ_forcing2,
                PZ_forcing2))
