"""
The module :mod:`output` supplies the class :class:`OutputBase`, which
handles the output, and the helper class :class:`SpecificOutput`.

"""

from __future__ import print_function

import h5py
import matplotlib.pyplot as plt
import datetime
import os
import shutil
import numpy as np


import fluiddyn

from fluiddyn.util import mpi

from fluiddyn.io import FLUIDDYN_PATH_SIM, FLUIDDYN_PATH_SCRATCH

from fluiddyn.util.util import time_as_str, print_memory_usage

from fluiddyn.simul.util.util import load_params_simul


class OutputBase(object):
    """Handle the output of a simulation.

    An instance chooses (or creates) the directory of the run, instantiates
    the output classes declared in ``info_solver.classes.Output`` and calls
    their online methods at each time step.
    """

    @staticmethod
    def _complete_info_solver(info_solver):
        """Complete the ContainerXML *info_solver*.

        Add the child ``classes`` to ``info_solver.classes.Output`` and
        declare in it the output classes ``PrintStdOut`` and ``PhysFields``.

        This is a static method!
        """
        info_solver.classes.Output.set_child('classes')
        classes = info_solver.classes.Output.classes

        classes.set_child(
            'PrintStdOut',
            attribs={'module_name': 'fluiddyn.simul.base.output.print_stdout',
                     'class_name': 'PrintStdOutBase'})

        classes.set_child(
            'PhysFields',
            attribs={'module_name': 'fluiddyn.simul.base.output.phys_fields',
                     'class_name': 'PhysFieldsBase'})

    @staticmethod
    def _complete_params_with_default(params):
        """Complete the *params* container with the default output parameters.

        This is a static method!
        """
        params.set_child('output', attribs={
            'period_show_plot': 1,
            'ONLINE_PLOT_OK': True})
        params.output.set_child('periods_save',
                                attribs={'phys_fields': 0})
        params.output.set_child('periods_print',
                                attribs={'print_stdout': 0.5})
        params.output.set_child('periods_plot',
                                attribs={'phys_fields': 0})
        params.output.set_child('phys_fields',
                                attribs={'field_to_plot': 'ux'})

    def __init__(self, sim):
        params = sim.params
        self.sim = sim
        self.params = params.output

        self.SAVE = params.SAVE
        self.name_solver = sim.info.solver['short_name']

        # initialisation name_run and path_run
        list_for_name_run = self.create_list_for_name_run()
        list_for_name_run.append(time_as_str())
        self.name_run = '_'.join(list_for_name_run)
        self.sim.name_run = self.name_run

        if not params.NEW_DIR_RESULTS:
            try:
                self.path_run = params.path_run
            except AttributeError:
                params.NEW_DIR_RESULTS = True
                print('Strange: params.NEW_DIR_RESULTS == False'
                      ' but no params.path_run')
            else:
                # Only meaningful when there is a previous directory: if
                # SAVE, check that its resolution matches this simulation.
                if self.SAVE:
                    self._check_resolution_of_previous_run(params)

        if params.NEW_DIR_RESULTS:
            # Results go first in the scratch directory (if defined);
            # end_of_simul moves them into FLUIDDYN_PATH_SIM.
            if FLUIDDYN_PATH_SCRATCH is not None:
                path_base = FLUIDDYN_PATH_SCRATCH
            else:
                path_base = FLUIDDYN_PATH_SIM
            self.path_run = os.path.join(path_base, self.sim.name_run)

            if mpi.rank == 0:
                params._set_attr_xml('path_run', self.path_run)
                if not os.path.exists(self.path_run):
                    os.makedirs(self.path_run)

        dico_classes = sim.info.solver.classes.Output.import_classes()
        PrintStdOut = dico_classes['PrintStdOut']
        self.print_stdout = PrintStdOut(self)

    def _check_resolution_of_previous_run(self, params):
        """Compare the resolution with the run saved in ``self.path_run``.

        On process 0, load the parameters of the previous simulation; if nx
        or ny differ, ``params.NEW_DIR_RESULTS`` is set to True. The value
        is then broadcast to all processes (collective call).

        Raises
        ------
        ValueError
            If the parameters of the previous run cannot be loaded.
        """
        if mpi.rank == 0:
            try:
                params_dir = load_params_simul(path_dir=self.path_run)
            except Exception:
                raise ValueError(
                    'Strange, no info_simul.h5 in self.path_run')

            if (params.oper.nx != params_dir.oper.nx
                    or params.oper.ny != params_dir.oper.ny):
                params.NEW_DIR_RESULTS = True
                print("""
Warning: params.NEW_DIR_RESULTS is False but the resolutions of the simulation
         and of the simulation in the directory self.path_run are different
         we put params.NEW_DIR_RESULTS = True""")
        if mpi.nb_proc > 1:
            params.NEW_DIR_RESULTS = mpi.comm.bcast(params.NEW_DIR_RESULTS)

    @staticmethod
    def _str_length(length):
        """Return a short string representing a length of the domain.

        Integer multiples of pi are written ``'<n>pi'``; other values are
        written with at most 3 decimals, without trailing zeros and without
        a trailing decimal point (``10.`` gives ``'10'``).
        """
        ratio = float(length) / np.pi
        if ratio.is_integer():
            return repr(int(ratio)) + 'pi'
        # Two separate rstrip: rstrip('0.') would turn '100.000' into '1'.
        return '{:.3f}'.format(length).rstrip('0').rstrip('.')

    def create_list_for_name_run(self):
        """Return the list of strings used (joined by '_') to name the run.

        It contains the short name of the solver, the short name of the type
        of run (only if not empty) and a string ``'L=<Lx>x<Ly>_<nx>x<ny>'``
        describing the size and the resolution of the domain.
        """
        params = self.sim.params

        list_for_name_run = [self.name_solver]
        if params.short_name_type_run:
            list_for_name_run.append(params.short_name_type_run)

        list_for_name_run.append('L={}x{}_{}x{}'.format(
            self._str_length(params.oper.Lx),
            self._str_length(params.oper.Ly),
            params.oper.nx, params.oper.ny))
        return list_for_name_run

    def init_with_oper_and_state(self):
        """Finish the initialization once ``sim.oper`` and ``sim.state`` exist.

        Print information on the run, save the xml files describing the
        solver and the parameters (process 0, only if SAVE and a new
        directory is used) and instantiate the output classes other than
        ``PrintStdOut``, each stored as the attribute named by its ``_tag``.
        """
        sim = self.sim
        self.oper = sim.oper

        # Should go in the function of a child class
        self.vecfft_from_rotfft = self.oper.vecfft_from_rotfft
        self.sum_wavenumbers = self.oper.sum_wavenumbers
        self.fft2 = self.oper.fft2
        self.ifft2 = self.oper.ifft2
        self.rotfft_from_vecfft = self.oper.rotfft_from_vecfft

        if mpi.rank == 0:
            # print info on the run
            specifications = (
                ', ' + sim.params.time_stepping.type_time_scheme +
                ', ' + self.oper.type_fft + ' and ')
            if mpi.nb_proc == 1:
                specifications += 'sequenciel,\n'
            else:
                specifications += 'parallel ({0} proc.)\n'.format(mpi.nb_proc)
            self.print_stdout(
                '\nsolver ' + self.name_solver + specifications +
                'nx = {0:6d} ; ny = {1:6d}\n'.format(
                    sim.params.oper.nx, sim.params.oper.ny) +
                'Lx = {0:6.2f} ; Ly = {1:6.2f}\n'.format(
                    sim.params.oper.Lx, sim.params.oper.Ly) +
                'path_run =\n' + self.path_run + '\n' +
                'type_flow_init = ' + sim.params.init_fields.type_flow_init)

        if mpi.rank == 0 and self.SAVE and sim.params.NEW_DIR_RESULTS:
            # save info on the run
            comment = (
                'This file has been created by'
                ' the Python program FluidDyn ' + fluiddyn.__version__ +
                '.\n\nIt should not be modified '
                '(except for adding xml comments).')
            self.sim.info.solver.xml_save(
                path_file=self.path_run + '/info_solver.xml',
                comment=comment)
            self.sim.params.xml_save(
                path_file=self.path_run + '/params_simul.xml',
                comment=comment)

        if mpi.rank == 0:
            plt.ion()
            self.print_stdout('Initialization outputs:')

        self.print_stdout.complete_init_with_state()

        dico_classes = sim.info.solver.classes.Output.import_classes()
        # This class has already been instantiated.
        dico_classes.pop('PrintStdOut')

        for Class in dico_classes.values():
            print(Class, Class._tag)
            setattr(self, Class._tag, Class(self))

        print_memory_usage(
            '\nMemory usage at the end of init. (equiv. seq.)')
        self.print_size_in_Mo(self.sim.state.state_fft, 'state_fft')

    @staticmethod
    def _active_names(periods):
        """Return the names of the outputs with a non-zero period.

        *periods* is one of the containers ``periods_print``,
        ``periods_plot`` or ``periods_save``.
        """
        return [name for name in periods.xml_attrib.keys()
                if periods.__dict__[name] != 0]

    def one_time_step(self):
        """Call the online methods of all outputs with a non-zero period."""
        for name in self._active_names(self.params.periods_print):
            getattr(self, name).online_print()

        for name in self._active_names(self.params.periods_plot):
            getattr(self, name).online_plot()

        if self.SAVE:
            for name in self._active_names(self.params.periods_save):
                getattr(self, name).online_save()

    def figure_axe(self, numfig=None, size_axe=None):
        """Create a figure with one axe (process 0 only).

        Parameters
        ----------
        numfig : optional
            If given, this figure is reused and cleared.
        size_axe : list, optional
            ``[left, bottom, width, height]`` of the axe, in figure
            coordinates (default ``[0.12, 0.1, 0.85, 0.84]``).

        Returns
        -------
        (fig, axe) on process 0, None on the other processes.
        """
        if mpi.rank == 0:
            if size_axe is None:
                x_left_axe = 0.12
                z_bottom_axe = 0.1
                width_axe = 0.85
                height_axe = 0.84
                size_axe = [x_left_axe, z_bottom_axe,
                            width_axe, height_axe]
            if numfig is None:
                fig = plt.figure()
            else:
                fig = plt.figure(numfig)
                fig.clf()
            axe = fig.add_axes(size_axe)
            return fig, axe

    def end_of_simul(self, total_time):
        """Finalize the outputs at the end of the simulation.

        Print the total time, save the physical fields and close the files
        (if SAVE) and, on process 0, move the directory of the run into
        FLUIDDYN_PATH_SIM if it is not already there.
        """
        self.print_stdout(
            'Computation completed in {0:8.6g} s\n'.format(total_time) +
            'path_run =\n' + self.path_run)
        if self.SAVE:
            self.phys_fields.save()
        if mpi.rank == 0 and self.SAVE:
            self.print_stdout.close()

            for name in self._active_names(self.params.periods_save):
                specific_output = getattr(self, name)
                if hasattr(specific_output, 'close_file'):
                    specific_output.close_file()

        if (not self.path_run.startswith(FLUIDDYN_PATH_SIM)
                and mpi.rank == 0):
            new_path_run = os.path.join(FLUIDDYN_PATH_SIM, self.sim.name_run)
            print('move result directory in directory:\n' + new_path_run)
            shutil.move(self.path_run, FLUIDDYN_PATH_SIM)
            self.path_run = new_path_run

    def compute_energy(self):
        """Return the energy (0 in this base class)."""
        return 0.

    def print_size_in_Mo(self, arr, string=None):
        """Print the size of *arr* in Mo, summed over all processes."""
        if string is None:
            string = 'Size of ndarray (equiv. seq.)'
        else:
            string = 'Size of ' + string + ' (equiv. seq.)'
        mem = arr.nbytes * 1.e-6
        if mpi.nb_proc > 1:
            mem = mpi.comm.allreduce(mem, op=mpi.MPI.SUM)
        self.print_stdout(string.ljust(30) + ': {0} Mo'.format(mem))































class SpecificOutput(object):
    """Small class for features useful for specific outputs.

    Quantities returned by ``self.compute()`` are saved periodically in a
    hdf5 file containing a resizable dataset ``times`` and, for each
    quantity, a resizable 2D dataset with one row per saved time.

    The methods ``compute``, ``init_path_files``, ``init_online_plot`` and
    ``_online_plot`` are not defined here and have to be provided by child
    classes.
    """

    def __init__(self, output, name_file=None,
                 period_save=None, has_to_plot=None,
                 dico_arrays_1time=None):
        """
        Parameters
        ----------
        output : OutputBase
            The main output object (gives ``sim``, ``SAVE``, ``path_run``).
        name_file : str, optional
            Name of the hdf5 file in ``output.path_run``. If None,
            ``self.init_path_files()`` is called (presumably to set
            ``self.path_file``).
        period_save : float, optional
            Period (simulation time) between two saves; 0 means no save.
        has_to_plot : bool, optional
            Whether to plot online (forced to False if
            ``params.output.ONLINE_PLOT_OK`` is False).
        dico_arrays_1time : dict, optional
            Arrays saved only once, when the file is created.
        """
        self.period_save = period_save
        self.has_to_plot = has_to_plot

        sim = output.sim
        params = sim.params

        self.output = output
        self.sim = sim
        self.oper = sim.oper
        self.params = params

        if not params.output.ONLINE_PLOT_OK:
            self.has_to_plot = False

        self.period_show = params.output.period_show_plot
        self.t_last_show = 0.

        if name_file is not None:
            self.path_file = self.output.path_run + '/' + name_file
        else:
            self.init_path_files()

        if not output.SAVE:
            self.period_save = 0.

        if self.period_save == 0.:
            return

        self.init_files(dico_arrays_1time)
        if self.has_to_plot and mpi.rank == 0:
            self.init_online_plot()

    def init_files(self, dico_arrays_1time=None):
        """Create the file (or append to an existing one) with the first values.

        ``self.compute()`` is called on all processes; only process 0 writes
        and sets ``self.nb_saved_times``.
        """
        if dico_arrays_1time is None:
            dico_arrays_1time = {}
        dico_results = self.compute()
        if mpi.rank == 0:
            if not os.path.exists(self.path_file):
                self.create_file_from_dico_arrays(
                    self.path_file, dico_results, dico_arrays_1time)
                self.nb_saved_times = 1
            else:
                with h5py.File(self.path_file, 'r') as f:
                    dset_times = f['times']
                    self.nb_saved_times = dset_times.shape[0] + 1
                self.add_dico_arrays_to_file(self.path_file,
                                             dico_results)
        self.t_last_save = self.sim.time_stepping.t

    def online_save(self):
        """Save the values at one time. """
        tsim = self.sim.time_stepping.t
        if tsim - self.t_last_save >= self.period_save:
            self.t_last_save = tsim
            dico_results = self.compute()
            if mpi.rank == 0:
                self.add_dico_arrays_to_file(self.path_file,
                                             dico_results)
                self.nb_saved_times += 1
                if self.has_to_plot:
                    self._online_plot(dico_results)
                    if tsim - self.t_last_show >= self.period_show:
                        self.t_last_show = tsim
                        self.fig.canvas.draw()

    @staticmethod
    def _as_row(arr):
        """Return *arr* as a 2D array of shape ``(1, arr.size)``.

        Unlike ``ndarray.resize``, the argument is not modified, and views
        or arrays that do not own their data are accepted.
        """
        arr = np.asarray(arr)
        return arr.reshape(1, arr.size)

    def create_file_from_dico_arrays(self, path_file,
                                     dico_arrays, dico_arrays_1time):
        """Create the hdf5 file *path_file* with the first saved values.

        Nothing is written (only a message printed) if the file exists.
        The arrays of *dico_arrays* are stored as the first row of
        datasets resizable along their first axis; those of
        *dico_arrays_1time* are stored as fixed datasets.
        """
        if os.path.exists(path_file):
            print('file NOT created since it already exists!')
        elif mpi.rank == 0:
            with h5py.File(path_file, 'w') as f:
                f.attrs['date saving'] = str(datetime.datetime.now())
                f.attrs['name_solver'] = self.output.name_solver
                f.attrs['name_run'] = self.output.name_run

                self.sim.info.xml_to_hdf5(hdf5_parent=f)

                times = np.array([self.sim.time_stepping.t])
                f.create_dataset(
                    'times', data=times, maxshape=(None,))

                # .items() works with Python 2 and 3 (no iteritems in 3).
                for k, v in dico_arrays_1time.items():
                    f.create_dataset(k, data=v)

                for k, v in dico_arrays.items():
                    row = self._as_row(v)
                    f.create_dataset(
                        k, data=row, maxshape=(None, row.shape[1]))

    def _append_to_open_file(self, f, dico_arrays, index):
        """Write the current time and *dico_arrays* at row *index* of *f*.

        The datasets ``times`` and ``f[k]`` are resized to ``index + 1``
        rows before writing.
        """
        dset_times = f['times']
        dset_times.resize((index + 1,))
        dset_times[index] = self.sim.time_stepping.t
        for k, v in dico_arrays.items():
            dset_k = f[k]
            dset_k.resize((index + 1, np.size(v)))
            dset_k[index] = v

    def add_dico_arrays_to_file(self, path_file, dico_arrays):
        """Append the current time and *dico_arrays* to the file *path_file*.

        Raises
        ------
        ValueError
            If the file does not exist.
        """
        if not os.path.exists(path_file):
            raise ValueError('can not add dico arrays in nonexisting file!')
        elif mpi.rank == 0:
            with h5py.File(path_file, 'r+') as f:
                nb_saved_times = f['times'].shape[0]
                self._append_to_open_file(f, dico_arrays, nb_saved_times)

    def add_dico_arrays_to_open_file(self, f, dico_arrays, nb_saved_times):
        """Append the current time and *dico_arrays* to the open file *f*.

        The values are written at row *nb_saved_times* (process 0 only).
        """
        if mpi.rank == 0:
            self._append_to_open_file(f, dico_arrays, nb_saved_times)
