# -*- coding: utf-8 -*-
"""
===========================================================================
Nanotube structure generators (:mod:`sknano.nanogen._nanotube_generators`)
===========================================================================

.. currentmodule:: sknano.nanogen._nanotube_generators

.. todo::

   Add methods to perform fractional translation and cartesian translation
   before structure generation.

.. todo::

   Handle different units and perform unit conversions for output coordinates.

.. todo::

   Replace arrays with class attributes for readability.

"""
from __future__ import division, print_function, absolute_import
__docformat__ = 'restructuredtext'

import copy
#import itertools
#import sys
import warnings
warnings.filterwarnings('ignore')  # to suppress the Pint UnicodeWarning

#from pint import UnitRegistry
#ureg = UnitRegistry()
#Qty = ureg.Quantity

import numpy as np

from ..chemistry import Atom, Atoms
from ..structure_io import DATAWriter, XYZWriter, default_structure_format, \
    supported_structure_formats
from ..tools import plural_word_check, rotation_matrix
from ..tools.refdata import CCbond

from ._nanotubes import Nanotube
from ._structure_generator import StructureGenerator

__all__ = ['NanotubeGenerator', 'MWNTGenerator']


class NanotubeGenerator(Nanotube, StructureGenerator):
    u"""Class for generating nanotube structures.

    Parameters
    ----------
    n, m : int
        Chiral indices defining the nanotube chiral vector
        :math:`\\mathbf{C}_{h} = n\\mathbf{a}_{1} + m\\mathbf{a}_{2} = (n, m)`.
    nx, ny : int, optional
        Number of repeat unit cells in the :math:`x, y` dimensions.
        Forwarded unchanged to the ``Nanotube`` base class.
    nz : int, optional
        Number of repeat unit cells in the :math:`z` direction, along
        the *length* of the nanotube.
    element1, element2 : {str, int}, optional
        Element symbol or atomic number of basis
        :py:class:`~sknano.chemistry.Atoms` 1 and 2
    bond : float, optional
        :math:`\\mathrm{a}_{\\mathrm{CC}} =` distance between
        nearest neighbor atoms. Must be in units of **Angstroms**.
    Lx, Ly : float, optional
        Lengths in the :math:`x, y` dimensions. Forwarded unchanged to the
        ``Nanotube`` base class.
    Lz : float, optional
        Length of nanotube in units of **nanometers**.
        Overrides the `nz` value.

        .. versionadded:: 0.2.5

    tube_length : float, optional
        Length of nanotube in units of **nanometers**.
        Overrides the `nz` value. Ignored if `Lz` is also given.

        .. deprecated:: 0.2.5
           Use `Lz` instead

    fix_Lz : bool, optional
        Generate the nanotube with length as close to the specified
        :math:`L_z` as possible. If `True`, then
        non integer :math:`n_z` cells are permitted.

        .. versionadded:: 0.2.6

    autogen : bool, optional
        if `True`, automatically call
        :py:meth:`~NanotubeGenerator.generate_unit_cell`,
        followed by :py:meth:`~NanotubeGenerator.generate_structure_data`.
    verbose : bool, optional
        if `True`, show verbose output

    Examples
    --------
    First, load the :py:class:`~sknano.nanogen.NanotubeGenerator` class.

    >>> from sknano.nanogen import NanotubeGenerator

    Now let's generate a :math:`\\mathbf{C}_{\\mathrm{h}} = (10, 5)`
    SWCNT unit cell.

    >>> nt = NanotubeGenerator(n=10, m=5)
    >>> nt.save_data(fname='10,5_unit_cell.xyz')

    The rendered structure looks like (orthographic view):

    .. image:: /images/10,5_unit_cell_orthographic_view.png

    and the perspective view:

    .. image:: /images/10,5_unit_cell_perspective_view.png

    """

    def __init__(self, n=int, m=int, nx=1, ny=1, nz=1,
                 element1='C', element2='C',
                 bond=CCbond, Lx=None, Ly=None, Lz=None,
                 tube_length=None, fix_Lz=False,
                 autogen=True, verbose=False):

        # Deprecated alias: only used when the caller did not pass `Lz`.
        if tube_length is not None and Lz is None:
            Lz = tube_length

        super(NanotubeGenerator, self).__init__(
            n=n, m=m, nx=nx, ny=ny, nz=nz,
            element1=element1, element2=element2,
            bond=bond, Lx=Lx, Ly=Ly, Lz=Lz, fix_Lz=fix_Lz,
            with_units=False, verbose=verbose)

        if autogen:
            self.generate_unit_cell()
            self.generate_structure_data()

    def generate_unit_cell(self):
        """Generate the nanotube unit cell.

        Populates ``self._unit_cell`` with ``2 * N`` atoms. For each
        ``i = 1..N`` it places one `element1` atom and one `element2` atom
        on a cylinder of radius ``rt``. Each atom's :math:`z` coordinate is
        wrapped by subtracting ``T`` while it exceeds ``T - eps``.

        The geometric quantities ``M``, ``T``, ``N`` and ``rt`` are read from
        instance attributes (presumably computed by the ``Nanotube`` base
        class).
        """
        eps = 0.01
        n = self._n
        m = self._m
        bond = self._bond
        M = self._M
        T = self._T
        N = self._N
        rt = self._rt
        e1 = self._element1
        e2 = self._element2
        verbose = self._verbose

        aCh = Nanotube.compute_chiral_angle(n=n, m=m, rad2deg=False)

        # Axial (tau) and azimuthal (psi) step between successive atom-1
        # sites.
        tau = M * T / N
        psi = 2 * np.pi / N

        # Axial (dtau) and azimuthal (dpsi) offset of basis atom 2 relative
        # to basis atom 1.
        dtau = bond * np.sin(np.pi / 6 - aCh)
        dpsi = bond * np.cos(np.pi / 6 - aCh) / rt

        if verbose:
            print('dpsi: {}'.format(dpsi))
            print('dtau: {}\n'.format(dtau))

        self._unit_cell = Atoms()

        # `range` instead of `xrange`, so this runs on Python 2 and 3.
        for i in range(1, N + 1):
            x1 = rt * np.cos(i * psi)
            y1 = rt * np.sin(i * psi)
            z1 = i * tau

            # Fold z back into the unit cell along the tube axis.
            while z1 > T - eps:
                z1 -= T

            atom1 = Atom(e1, x=x1, y=y1, z=z1)
            atom1.rezero_coords()

            if verbose:
                print('Basis Atom 1:\n{}'.format(atom1))

            self._unit_cell.append(atom1)

            x2 = rt * np.cos(i * psi + dpsi)
            y2 = rt * np.sin(i * psi + dpsi)
            z2 = i * tau - dtau
            while z2 > T - eps:
                z2 -= T

            atom2 = Atom(e2, x=x2, y=y2, z=z2)
            atom2.rezero_coords()

            if verbose:
                print('Basis Atom 2:\n{}'.format(atom2))

            self._unit_cell.append(atom2)

    def generate_structure_data(self):
        """Generate structure data.

        Tiles the unit cell along :math:`z`. It makes ``ceil(nz)`` copies,
        each translated by a multiple of ``T``, and stores the result in
        ``self._structure_atoms``. Non-integer ``nz`` is rounded up; any
        excess length is clipped later in :py:meth:`save_data` when `fix_Lz`
        is set.
        """
        self._structure_atoms = Atoms()
        for nz in range(int(np.ceil(self._nz))):
            dr = np.array([0.0, 0.0, nz * self.T])
            for uc_atom in self._unit_cell:
                nt_atom = Atom(uc_atom.symbol)
                nt_atom.r = uc_atom.r + dr
                self._structure_atoms.append(nt_atom)

    def save_data(self, fname=None, structure_format=None,
                  rotation_angle=None, rot_axis=None, deg2rad=True,
                  center_CM=True):
        """Save structure data.

        Parameters
        ----------
        fname : {None, str}, optional
            file name string
        structure_format : {None, str}, optional
            chemical file format of saved structure data. Must be one of:

                - xyz
                - data

            If `None`, then guess based on `fname` file extension or
            default to `xyz` format.
        rotation_angle : {None, float}, optional
            Angle of rotation
        rot_axis : {'x', 'y', 'z'}, optional
            Rotation axis
        deg2rad : bool, optional
            Convert `rotation_angle` from degrees to radians.
        center_CM : bool, optional
            Center center-of-mass on origin.

        Notes
        -----
        Centering, :math:`L_z` clipping and rotation are called on the
        stored ``self._structure_atoms`` and their return values are not
        used. They therefore appear to modify the stored atoms in place, so
        repeated calls may compound rotations.

        """
        if (fname is None and structure_format not in
                supported_structure_formats) or \
                (fname is not None and not
                    fname.endswith(supported_structure_formats) and
                    structure_format not in supported_structure_formats):
            structure_format = default_structure_format

        if fname is None:
            # Default name: zero-padded chirality + cell count, e.g.
            # '1005r_1cell.xyz'.
            chirality = '{}{}r'.format('{}'.format(self._n).zfill(2),
                                       '{}'.format(self._m).zfill(2))
            nz_fmt = '{}' if self._assume_integer_unit_cells else '{:.2f}'
            nz = ''.join((nz_fmt.format(self._nz),
                          plural_word_check('cell', self._nz)))
            fname = '_'.join((chirality, nz))
            fname += '.' + structure_format
        else:
            if fname.endswith(supported_structure_formats) and \
                    structure_format is None:
                # Infer the format from the file extension.
                for ext in supported_structure_formats:
                    if fname.endswith(ext):
                        structure_format = ext
                        break
            else:
                if structure_format is None or \
                        structure_format not in supported_structure_formats:
                    structure_format = default_structure_format

        self._fname = fname

        if center_CM:
            self._structure_atoms.center_CM()

        if self._L0 is not None and self._fix_Lz:
            # _L0 is in nm; the factor 10 converts it to Angstroms, and the
            # limit is symmetric about the (centered) origin.
            self._structure_atoms.clip_bounds(
                abs_limit=(10 * self._L0 + 0.5) / 2, r_indices=[2])

        if rotation_angle is not None:
            R_matrix = rotation_matrix(rotation_angle,
                                       rot_axis=rot_axis,
                                       deg2rad=deg2rad)
            self._structure_atoms.rotate(R_matrix)

        if structure_format == 'data':
            DATAWriter.write(fname=self._fname, atoms=self._structure_atoms)
        else:
            XYZWriter.write(fname=self._fname, atoms=self._structure_atoms)


class MWNTGenerator(NanotubeGenerator):
    u"""Class for generating single, multi-walled nanotubes.

    .. versionchanged:: 0.2.20

       `MWNTGenerator` no longer generates MWNT *bundles*, only *single*
       MWNTs. To generate bundled MWNT structure data, use the
       `MWNTBundleGenerator` class.

    .. versionadded:: 0.2.8

    Parameters
    ----------
    n, m : int
        Chiral indices defining the nanotube chiral vector
        :math:`\\mathbf{C}_{h} = n\\mathbf{a}_{1} + m\\mathbf{a}_{2} = (n, m)`.
    nx, ny, nz : int, optional
        Number of repeat unit cells in the :math:`x, y, z` dimensions.
    element1, element2 : {str, int}, optional
        Element symbol or atomic number of basis
        :py:class:`~sknano.chemistry.Atoms` 1 and 2
    bond : float, optional
        :math:`\\mathrm{a}_{\\mathrm{CC}} =` distance between
        nearest neighbor atoms. Must be in units of **Angstroms**.
    Lx, Ly, Lz : float, optional
        length of bundle in :math:`x, y, z` dimensions in **nanometers**.
        Overrides the :math:`n_x, n_y, n_z` cell values.
    fix_Lz : bool, optional
        Generate the nanotube with length as close to the specified
        :math:`L_z` as possible. If `True`, then
        non integer :math:`n_z` cells are permitted.
    max_shells : int, optional
        Maximum number of shells per MWNT. `None` means no limit.
    min_shell_diameter : float, optional
        Minimum shell diameter, in units of **Angstroms**.
    shell_spacing : float, optional
        Shell spacing in units of **Angstroms**. Default
        value is the van der Waals interaction distance of 3.4 Angstroms.
    inner_shell_Ch_type : {None, 'armchair', 'AC', 'zigzag', 'ZZ', 'achiral',
                           'chiral'}, optional
        If `None`, the chiralities of the inner shells are constrained only
        by their diameter and will be chosen randomly if more than one
        candidate chirality exists. If not `None`, then the inner
        shell chirality type will be added as a constraint. Unrecognized
        values are treated like `None`.
    autogen : bool, optional
        if `True`, automatically call
        :py:meth:`~MWNTGenerator.generate_unit_cell`,
        followed by :py:meth:`~MWNTGenerator.generate_structure_data`.
    verbose : bool, optional
        if `True`, show verbose output

    Examples
    --------

    >>> from sknano.nanogen import MWNTGenerator
    >>> mwnt = MWNTGenerator(n=40, m=40, max_shells=5, Lz=1.0, fix_Lz=True)
    >>> mwnt.save_data()

    .. image:: /images/5shell_mwnt_4040_outer_Ch_1cellx1cellx4.06cells-01.png

    """
    def __init__(self, n=int, m=int, nx=1, ny=1, nz=1,
                 element1='C', element2='C', bond=CCbond,
                 Lx=None, Ly=None, Lz=None, fix_Lz=False, max_shells=None,
                 min_shell_diameter=0.0, shell_spacing=3.4,
                 inner_shell_Ch_type=None, autogen=True, verbose=False):

        # autogen=False: the MWNT attributes below must be set before any
        # structure generation happens.
        super(MWNTGenerator, self).__init__(
            n=n, m=m, nx=nx, ny=ny, nz=nz, bond=bond, element1=element1,
            element2=element2, Lx=Lx, Ly=Ly, Lz=Lz, fix_Lz=fix_Lz,
            autogen=False, verbose=verbose)

        # `None` means "no shell limit".
        self._max_shells = max_shells
        if max_shells is None:
            self._max_shells = np.inf

        self._min_shell_diameter = min_shell_diameter
        self._shell_spacing = shell_spacing
        self._inner_shell_Ch_type = inner_shell_Ch_type

        self._Nshells_per_tube = 1
        self._Natoms_per_tube = 0

        if autogen:
            # The outermost shell uses the parent's (n, m) unit cell.
            super(MWNTGenerator, self).generate_unit_cell()
            self.generate_structure_data()

    def _generate_unit_cell(self, n=int, m=int):
        """Generate and return the unit cell of a MWNT shell.

        Uses the same construction as
        :py:meth:`NanotubeGenerator.generate_unit_cell`, but for arbitrary
        chiral indices `n`, `m` (computing ``N``, ``rt``, ``T``, ``tau`` and
        ``psi`` from them) instead of the instance's own indices.

        Returns
        -------
        :py:class:`~sknano.chemistry.Atoms`
            ``2 * N`` atoms: one `element1` and one `element2` atom per
            ``i = 1..N``.
        """
        eps = 0.01
        bond = self._bond
        e1 = self._element1
        e2 = self._element2

        N = Nanotube.compute_N(n=n, m=m)
        aCh = Nanotube.compute_chiral_angle(n=n, m=m, rad2deg=False)
        rt = Nanotube.compute_rt(n=n, m=m, bond=bond, with_units=False)
        T = Nanotube.compute_T(n=n, m=m, bond=bond, with_units=False)

        tau = Nanotube.compute_tau(n=n, m=m, bond=bond, with_units=False)
        dtau = bond * np.sin(np.pi / 6 - aCh)

        psi = Nanotube.compute_psi(n=n, m=m)
        dpsi = bond * np.cos(np.pi / 6 - aCh) / rt

        unit_cell = Atoms()

        for i in range(1, N + 1):
            x1 = rt * np.cos(i * psi)
            y1 = rt * np.sin(i * psi)
            z1 = i * tau

            # Fold z back into the unit cell along the tube axis.
            while z1 > T - eps:
                z1 -= T

            atom1 = Atom(e1, x=x1, y=y1, z=z1)
            atom1.rezero_coords()

            unit_cell.append(atom1)

            x2 = rt * np.cos(i * psi + dpsi)
            y2 = rt * np.sin(i * psi + dpsi)
            z2 = i * tau - dtau
            while z2 > T - eps:
                z2 -= T

            atom2 = Atom(e2, x=x2, y=y2, z=z2)
            atom2.rezero_coords()

            unit_cell.append(atom2)

        return unit_cell

    @staticmethod
    def _Ch_type_mask(Ch, Ch_type):
        """Return a boolean mask selecting rows of `Ch` of type `Ch_type`.

        Parameters
        ----------
        Ch : :class:`~numpy:numpy.ndarray`
            Integer array of shape ``(N, 2)`` whose rows are ``(n, m)``.
        Ch_type : {None, 'armchair', 'AC', 'zigzag', 'ZZ', 'achiral',
                   'chiral'}
            Chirality type constraint. Any other value (including `None`)
            selects every row.

        Returns
        -------
        :class:`~numpy:numpy.ndarray`
            Boolean array of shape ``(N,)``.
        """
        armchair = Ch[:, 0] == Ch[:, 1]
        zigzag = np.logical_or(Ch[:, 0] == 0, Ch[:, 1] == 0)
        if Ch_type in ('AC', 'armchair'):
            return armchair
        elif Ch_type in ('ZZ', 'zigzag'):
            return zigzag
        elif Ch_type == 'achiral':
            return np.logical_or(armchair, zigzag)
        elif Ch_type == 'chiral':
            # Chiral means n != m, n != 0 and m != 0 (i.e. neither
            # armchair nor zigzag).
            return np.logical_not(np.logical_or(armchair, zigzag))
        return np.ones(len(Ch), dtype=bool)

    def generate_structure_data(self):
        """Generate structure data.

        Builds the outer shell with the parent implementation, then adds
        inner shells one at a time. Each inner shell's target diameter is
        the previous target minus ``2 * shell_spacing``. Its chirality is
        drawn at random from all ``(n, m)`` with ``0 <= n, m <= 200``
        (excluding those with both ``n <= 2`` and ``m <= 2``) whose diameter
        lies within 0.05 Angstrom of the target and whose type matches
        `inner_shell_Ch_type`. Shells are added until `max_shells` is
        reached or the target drops below `min_shell_diameter`. Finally the
        atoms are clipped along :math:`z`.

        .. todo::

           Load the diameter and chirality data from file instead of
           generating it every time.

        """
        # Reset the shell count so that calling this method again starts
        # from the outer shell rather than continuing the previous total
        # (which would stop shell generation early).
        self._Nshells_per_tube = 1

        # Table of candidate chiralities and their diameters.
        Ch_list = [(n, m) for n in range(201) for m in range(201)
                   if n > 2 or m > 2]
        dt = np.asarray([Nanotube.compute_dt(n=n, m=m, bond=self._bond)
                         for n, m in Ch_list])
        Ch = np.asarray(Ch_list)

        # The chirality-type constraint does not depend on the target
        # diameter, so it is computed once rather than per search step.
        Ch_type_mask = self._Ch_type_mask(Ch, self._inner_shell_Ch_type)

        self._min_shell_diameter = max(self._min_shell_diameter, dt.min())

        super(MWNTGenerator, self).generate_structure_data()

        swnt0 = copy.deepcopy(self._structure_atoms)
        self._structure_atoms = Atoms(atoms=swnt0, deepcopy=True)
        self._structure_atoms.center_CM()

        Lzmin = self._Lz
        next_dt = self._dt - 2 * self._shell_spacing
        while self._Nshells_per_tube < self._max_shells and \
                next_dt >= self._min_shell_diameter:

            # Indices into Ch/dt of chiralities near the target diameter.
            # The target is lowered by delta_dt until a match is found or it
            # drops below the minimum shell diameter.
            candidates = []
            delta_dt = 0.05
            while len(candidates) == 0 and \
                    next_dt >= self._min_shell_diameter:
                candidates = np.flatnonzero(
                    np.logical_and(np.abs(dt - next_dt) <= delta_dt,
                                   Ch_type_mask))
                # NOTE: also applied on the step that finds candidates, so
                # the next shell's target is shifted by delta_dt.
                next_dt -= delta_dt

            if len(candidates) == 0:
                break

            idx = candidates[np.random.choice(np.arange(len(candidates)))]
            n, m = Ch[idx]
            T = Nanotube.compute_T(n=n, m=m, bond=self._bond,
                                   with_units=False)
            Lz = Nanotube.compute_Lz(
                n=n, m=m, bond=self._bond, nz=self._nz)
            Lzmin = min(Lzmin, Lz)

            # generate unit cell for new shell chiral indices
            shell_unit_cell = self._generate_unit_cell(n=n, m=m)

            if self._verbose:
                # Report the chosen shell's actual diameter, not the
                # (already decremented) search target.
                print('new shell:\n'
                      'n, m = {}, {}\n'.format(n, m) +
                      'dt: {:.4f}\n'.format(dt[idx]) +
                      'shell_unit_cell.Natoms: {}\n'.format(
                          shell_unit_cell.Natoms))

            shell = Atoms()
            for nz in range(int(np.ceil(self._nz))):
                dr = np.array([0.0, 0.0, nz * T])
                for uc_atom in shell_unit_cell:
                    atom = Atom(uc_atom.symbol)
                    atom.r = uc_atom.r + dr
                    shell.append(atom)
            shell.center_CM()
            self._structure_atoms.extend(shell.atoms)
            self._Nshells_per_tube += 1
            next_dt -= 2 * self._shell_spacing

        # Clip along z to the requested length (fix_Lz) or else to the
        # shortest shell length. Lengths are in nm; 10x converts them to
        # Angstroms.
        if self._L0 is not None and self._fix_Lz:
            self._structure_atoms.clip_bounds(
                abs_limit=(10 * self._L0 + 0.5) / 2, r_indices=[2])
        else:
            self._structure_atoms.clip_bounds(
                abs_limit=(10 * Lzmin + 0.5) / 2, r_indices=[2])

        self._Natoms_per_tube = self._structure_atoms.Natoms

        if self._verbose:
            print('Nshells_per_tube: {}'.format(self._Nshells_per_tube))
            print('Natoms_per_tube: {}'.format(self._Natoms_per_tube))

    def save_data(self, fname=None, structure_format=None,
                  rotation_angle=None, rot_axis=None, deg2rad=True,
                  center_CM=True):
        """Save structure data.

        Parameters
        ----------
        fname : {None, str}, optional
            file name string
        structure_format : {None, str}, optional
            chemical file format of saved structure data. Must be one of:

                - xyz
                - data

            If `None`, then guess based on `fname` file extension or
            default to `xyz` format.
        rotation_angle : {None, float}, optional
            Angle of rotation
        rot_axis : {'x', 'y', 'z'}, optional
            Rotation axis
        deg2rad : bool, optional
            Convert `rotation_angle` from degrees to radians.
        center_CM : bool, optional
            Center center-of-mass on origin.

        """
        if (fname is None and structure_format not in
                supported_structure_formats) or \
                (fname is not None and not
                    fname.endswith(supported_structure_formats) and
                    structure_format not in supported_structure_formats):
            structure_format = default_structure_format

        if fname is None:
            # Default name: shell count, outer chirality and cell count.
            Nshells = '{}shell_mwnt'.format(self._Nshells_per_tube)
            chirality = '{}{}_outer_Ch'.format('{}'.format(self._n).zfill(2),
                                               '{}'.format(self._m).zfill(2))
            nz_fmt = '{}' if self._assume_integer_unit_cells else '{:.2f}'
            nz = ''.join((nz_fmt.format(self._nz),
                          plural_word_check('cell', self._nz)))
            fname = '_'.join((Nshells, chirality, nz))
            fname += '.' + structure_format

        else:
            if fname.endswith(supported_structure_formats) and \
                    structure_format is None:
                # Infer the format from the file extension.
                for ext in supported_structure_formats:
                    if fname.endswith(ext):
                        structure_format = ext
                        break
            else:
                if structure_format is None or \
                        structure_format not in supported_structure_formats:
                    structure_format = default_structure_format

        super(MWNTGenerator, self).save_data(
            fname=fname, structure_format=structure_format,
            rotation_angle=rotation_angle, rot_axis=rot_axis,
            deg2rad=deg2rad, center_CM=center_CM)
