#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# eea_int.py - This file is part of the PySptools package.
#

"""
PPI, NFINDR, ATGP, FIPPI classes
"""


import os.path as osp
import numpy as np
import pysptools.formatting as fmt
import eea


class Check(object):
    """ Validate inputs for the endmembers extraction algorithms classes """

    err1 = 'in {0}.extract(), M is not a numpy.array'
    err2 = 'in {0}.extract(), q is not of type int'
    err3 = 'in {0}.extract(), M have {1} dimension(s), expected 3 dimensions'
    err5 = 'in {0}.extract(), normalize have {1}, expected bool type'
    err10 = 'in {0} class, call extract before calling {1}'
    err11 = 'in {0} class, suffix is not of str type'
    err13 = 'in {0}.extract(), mask is not a numpy.array'
    err14 = 'in {0}.extract(), mask have {1} dimension(s), expected 2 dimensions'
    err15 = 'in {0}.plot(), length of info[\'wavelength\'] is {1}, expected length of {2}'
    err16 = 'in {0}.extract(), transform is not a numpy.array'
    err17 = 'in {0}.extract(), transform have {1} dimension(s), expected 3 dimensions'
    # The check in transform() requires components == q - 1; the message
    # must state that same relation.
    err18 = 'in {0}.extract(), q equal {1} and transform have {2} components, expected components == q - 1'

    def __init__(self, label):
        # Name of the owning class (e.g. 'NFINDR'), used as the prefix of
        # every error message.
        self.label = label

    def plot_input(self, E, suffix, info, method_name):
        """
        Validate the state and arguments of a plot()/display() call.

        Parameters:
            E: `numpy array or None`
                Endmembers produced by extract(), None if extract() was not called.
            suffix: `string or None`
                Optional suffix for the file name / title.
            info: `dictionary`
                Must contain 'wavelength', of length E.shape[1].
            method_name: `string`
                Name of the calling method, used in the error message.

        Raises:
            RuntimeError: if E is None, or if the wavelength list length
                differs from the number of bands in E.
            TypeError: if suffix is neither None nor a str.
        """
        # `is None`, not `== None`: for a numpy array `==` is elementwise
        # and its truth value is ambiguous.
        if E is None:
            raise RuntimeError(self.err10.format(self.label, method_name))
        if suffix is not None and type(suffix) is not str:
            raise TypeError(self.err11.format(self.label))
        if E.shape[1] != len(info['wavelength']):
            raise RuntimeError(self.err15.format(self.label, len(info['wavelength']), E.shape[1]))

    def extract_input(self, M, p, normalize):
        """
        Validate the common extract() arguments.

        Raises:
            RuntimeError: if M is not a 3D numpy array or p is not an int.
            TypeError: if normalize is not a bool.
        """
        if type(M) is not np.ndarray:
            raise RuntimeError(self.err1.format(self.label))
        if type(p) is not int:
            raise RuntimeError(self.err2.format(self.label))
        if M.ndim != 3:
            raise RuntimeError(self.err3.format(self.label, M.ndim))
        if type(normalize) is not bool:
            raise TypeError(self.err5.format(self.label, type(normalize)))

    def transform(self, q, transform):
        """
        Validate an optional pre-transformed cube (None is accepted).

        Raises:
            RuntimeError: if transform is not a 3D numpy array whose last
                dimension equals q - 1.
        """
        if transform is None:
            return
        if type(transform) is not np.ndarray:
            raise RuntimeError(self.err16.format(self.label))
        if transform.ndim != 3:
            raise RuntimeError(self.err17.format(self.label, transform.ndim))
        if q - 1 != transform.shape[2]:
            raise RuntimeError(self.err18.format(self.label, q, transform.shape[2]))

    def mask(self, mask):
        """
        Validate an optional 2D mask (None is accepted).

        Raises:
            RuntimeError: if mask is not a numpy array or is not 2D.
        """
        if mask is None:
            return
        if type(mask) is not np.ndarray:
            raise RuntimeError(self.err13.format(self.label))
        if mask.ndim != 2:
            raise RuntimeError(self.err14.format(self.label, mask.ndim))


def _plot_end_members(path, E, info, utype, is_normalized, suffix=None):
    """
    Plot the endmembers to PNG files using matplotlib.

    Rows of E are drawn five per figure. Each figure is saved in `path` as
    'emembers_<utype>__<n>.png', or 'emembers_<utype>__<n>_<suffix>.png'
    when a suffix is given.

    Parameters:
        path: `string`
            Output directory.
        E: `numpy array`
            Endmembers, one spectrum per row.
        info: `dictionary`
            'wavelength' gives the x values; 'z plot titles' is a pair
            (x axis label, y axis label). 'wavelength units' is set to
            'Unknown' if missing.
        utype: `string`
            Algorithm name used in the file names.
        is_normalized: `bool`
            If True, ' - normalized' is appended to the y axis label.
        suffix: `string [default None]`
            Suffix added to the file names.

    Raises:
        RuntimeError: if info has no 'wavelength' entry.
        IOError: if a figure cannot be written under `path`.
    """
    import matplotlib.pyplot as plt

    if 'wavelength units' not in info:
        info['wavelength units'] = 'Unknown'
    if 'wavelength' not in info:
        # A bare string is not a valid exception; raise a real one.
        raise RuntimeError('in _plot_end_members, no wavelength defined')

    def setup_axes():
        # Labels, title and grid are wiped by plt.clf(), so they are
        # re-applied for every new figure.
        plt.xlabel(info['z plot titles'][0])
        ylabel = info['z plot titles'][1]
        if is_normalized:
            ylabel += ' - normalized'
        plt.ylabel(ylabel)
        plt.title('Spectral Profile')
        plt.grid(True)

    def save_figure(legend, n_graph):
        # Add the legend for the current batch and write the figure to disk.
        plt.legend(legend, loc='upper left', framealpha=0.5)
        if suffix is None:
            fout = osp.join(path, 'emembers_{0}__{1}.png'.format(utype, n_graph))
        else:
            fout = osp.join(path, 'emembers_{0}__{1}_{2}.png'.format(utype, n_graph, suffix))
        try:
            plt.savefig(fout)
        except IOError:
            raise IOError('in _plot_end_members, no such file or directory: {0}'.format(path))

    plt.ioff()
    setup_axes()
    n_graph = 1
    legend = []
    for i in range(E.shape[0]):
        plt.plot(info['wavelength'], E[i])
        legend.append('EM{0}'.format(i + 1))
        if (i + 1) % 5 == 0:
            save_figure(legend, n_graph)
            legend = []
            n_graph += 1
            plt.clf()
            setup_axes()
    # Flush the last, partially filled figure.
    if E.shape[0] % 5 != 0:
        save_figure(legend, n_graph)
    plt.clf()


def _display_end_members(U, info, utype, is_normalized, suffix):
    """
    Display endmembers using matplotlib to the IPython Notebook.

    Rows of U are drawn five per figure, and each figure is shown with
    plt.show().

    Parameters:
        U: `numpy array`
            Endmembers, one spectrum per row.
        info: `dictionary`
            'wavelength' gives the x values; 'z plot titles' is a pair
            (x axis label, y axis label). 'wavelength units' is set to
            'Unknown' if missing.
        utype: `string`
            Algorithm name shown in the titles.
        is_normalized: `bool`
            If True, ' - normalized' is appended to the y axis label.
        suffix: `string or None`
            Extra text appended to the titles.

    Raises:
        RuntimeError: if info has no 'wavelength' entry.
    """
    import matplotlib.pyplot as plt

    if 'wavelength units' not in info:
        info['wavelength units'] = 'Unknown'
    if 'wavelength' not in info:
        # A bare string is not a valid exception; raise a real one.
        raise RuntimeError('in _display_end_members, no wavelength defined')

    def show_figure(legend, n_graph):
        # Decorate the current figure, then display it.
        plt.legend(legend, loc='upper left', framealpha=0.5)
        plt.xlabel(info['z plot titles'][0])
        ylabel = info['z plot titles'][1]
        if is_normalized:
            ylabel += ' - normalized'
        plt.ylabel(ylabel)
        if suffix is None:
            plt.title('Spectral Profile {0} - {1}'.format(n_graph, utype))
        else:
            plt.title('Spectral Profile {0} - {1} - {2}'.format(n_graph, utype, suffix))
        plt.grid(True)
        plt.show()

    n_graph = 1
    legend = []
    for i in range(U.shape[0]):
        plt.plot(info['wavelength'], U[i])
        legend.append('EM{0}'.format(i + 1))
        if (i + 1) % 5 == 0:
            show_figure(legend, n_graph)
            legend = []
            plt.clf()
            n_graph += 1
    # Show the last, partially filled figure.
    if U.shape[0] % 5 != 0:
        show_figure(legend, n_graph)
    plt.clf()


class PPI(object):
    """
    Performs the pixel purity index algorithm for endmember finding.
    """

    def __init__(self):
        self.check = Check('PPI')
        self.E = None               # endmembers from the last extract()
        self.w = None               # cube width, used to unflatten indices
        self.idx = None             # flat pixel indices of the endmembers
        self.idx3D = None           # (row, column) pixel positions
        self.is_normalized = False  # whether the last extract() normalized M

    def extract(self, M, q, numSkewers=10000, normalize=False):
        """
        Extract the endmembers.

        Parameters:
            M: `numpy array`
                A HSI cube (m x n x p).

            q: `int`
                Number of endmembers to find.

            numSkewers: `int [default 10000]`
                Number of "skewer" vectors to project data onto.
                In general, recommendation from the literature is 10000 skewers.

            normalize: `boolean [default False]`
                If True, M is normalized before doing the endmembers extraction.

        Returns: `numpy array`
                Recovered endmembers (N x p).
        """
        self.check.extract_input(M, q, normalize)
        if normalize:
            M = fmt.normalize(M)
        # Reflect this call only, so a previous normalized run does not
        # leak into the plot labels of a later non-normalized one.
        self.is_normalized = normalize
        h, self.w, numBands = M.shape
        M = np.reshape(M, (self.w * h, numBands))
        self.E, self.idx = eea.PPI(M, q, numSkewers)
        # Flat index i in the (h*w, p) matrix maps to pixel (i // w, i % w).
        self.idx3D = [(i // self.w, i % self.w) for i in self.idx]
        return self.E

    def get_idx(self):
        """
        Returns: `list of tuple`
            (row, column) positions in the HSI cube of the induced
            endmembers, or None if extract() has not been called.
        """
        return self.idx3D

    def plot(self, path, info, suffix=None):
        """
        Plot the endmembers.

        Parameters:
            path: `string`
              The path where to put the plot.

            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the file name.
        """
        self.check.plot_input(self.E, suffix, info, 'plot')
        _plot_end_members(path, self.E, info, 'PPI', self.is_normalized, suffix)

    def display(self, info, suffix=None):
        """
        Display the endmembers to a IPython Notebook.

        Parameters:
            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the title.
        """
        self.check.plot_input(self.E, suffix, info, 'display')
        _display_end_members(self.E, info, 'PPI', self.is_normalized, suffix)


class NFINDR(object):
    """
    N-FINDR endmembers induction algorithm.
    """

    def __init__(self):
        self.check = Check('NFINDR')
        self.E = None               # endmembers from the last extract()
        self.Et = None              # second value returned by the NFINDR call
        self.w = None               # cube width, used to unflatten indices
        self.idx = None             # flat pixel indices of the endmembers
        self.it = None              # iteration count reported by NFINDR
        self.idx3D = None           # (row, column) pixel positions
        self.is_normalized = False  # whether the last extract() normalized M

    def extract(self, M, q, transform=None, maxit=None, normalize=False, ATGP_init=False, mask=None):
        """
        Extract the endmembers.

        Parameters:
            M: `numpy array`
                A HSI cube (m x n x p).

            q: `int`
                The number of endmembers to be induced.

            transform: `numpy array [default None]`
                The transformed 'M' cube by MNF (m x n x components). In this
                case the number of components must == q-1. If None, the built-in
                call to PCA is used to transform M in q-1 components.

            maxit: `int [default None]`
                The maximum number of iterations. Default is 3*p.

            normalize: `boolean [default False]`
                If True, M is normalized before doing the endmembers induction.

            ATGP_init: `boolean [default False]`
                Use ATGP to generate the first endmembers set instead
                of a random selection.

            mask: `numpy array [default None]`
                A binary mask, when *True* the corresponding signal is part of the
                endmembers search.

        Returns: `numpy array`
            Set of induced endmembers (N x p).

        References:
            Winter, M. E., "N-FINDR: an algorithm for fast autonomous spectral
            end-member determination in hyperspectral data", presented at the Imaging
            Spectrometry V, Denver, CO, USA, 1999, vol. 3753, pgs. 266-275.

        Note:
            The division by (factorial(p-1)) is an invariant for this algorithm,
            for this reason it is skipped.
        """
        self.check.extract_input(M, q, normalize)
        self.check.transform(q, transform)
        self.check.mask(mask)
        import nfindr
        if normalize:
            M = fmt.normalize(M)
        # Reflect this call only, so a previous normalized run does not
        # leak into the plot labels of a later non-normalized one.
        self.is_normalized = normalize
        h, self.w, numBands = M.shape
        n_pixels = self.w * h
        M = np.reshape(M, (n_pixels, numBands))
        # `is not None`: `!=` against an ndarray is elementwise and raises
        # "truth value is ambiguous" when used as a condition.
        if transform is not None:
            transform = np.reshape(transform, (n_pixels, transform.shape[2]))
        if mask is not None:
            mask1 = np.reshape(mask, (n_pixels, 1))
            self.E, self.Et, self.idx, self.it = eea.mNFINDR(M, mask1, q, transform, maxit, ATGP_init)
        else:
            self.E, self.Et, self.idx, self.it = nfindr.NFINDR(M, q, transform, maxit, ATGP_init)
        # Flat index i in the (h*w, p) matrix maps to pixel (i // w, i % w).
        self.idx3D = [(i // self.w, i % self.w) for i in self.idx]
        return self.E

    def get_idx(self):
        """
        Returns: `list of tuple`
            (row, column) positions in the HSI cube of the induced
            endmembers, or None if extract() has not been called.
        """
        return self.idx3D

    def get_iterations(self):
        """
        Returns : int
            The number of iterations.
        """
        return self.it

    def get_endmembers_transform(self):
        """
        Returns the second value produced by the last extract() call
        (presumably the endmembers in the transformed space), or None if
        extract() has not been called.
        """
        return self.Et

    def plot(self, path, info, suffix=None):
        """
        Plot the endmembers.

        Parameters:
            path: `string`
                The path where to put the plot.

            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the file name.
        """
        self.check.plot_input(self.E, suffix, info, 'plot')
        _plot_end_members(path, self.E, info, 'NFINDR', self.is_normalized, suffix)

    def display(self, info, suffix=None):
        """
        Display the endmembers to a IPython Notebook.

        Parameters:
            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the title.
        """
        self.check.plot_input(self.E, suffix, info, 'display')
        _display_end_members(self.E, info, 'NFINDR', self.is_normalized, suffix)


class ATGP(object):
    """
    Automatic target generation process endmembers induction algorithm.
    """

    def __init__(self):
        self.check = Check('ATGP')
        self.E = None               # endmembers from the last extract()
        self.w = None               # cube width, used to unflatten indices
        self.idx = None             # flat pixel indices of the endmembers
        self.idx3D = None           # (row, column) pixel positions
        self.is_normalized = False  # whether the last extract() normalized M

    def extract(self, M, q, normalize=False, mask=None):
        """
        Extract the endmembers.

        Parameters:
            M: `numpy array`
                A HSI cube (m x n x p).

            q: `int`
                Number of endmembers to be induced (positive integer > 0).

            normalize: `boolean [default False]`
                Normalize M before unmixing.

            mask: `numpy array [default None]`
                A binary mask, if True the corresponding signal is part of the
                endmembers search.

        Returns: `numpy array`
            Set of induced endmembers (N x p).

        References:
            A. Plaza y C.-I. Chang, "Impact of Initialization on Design of Endmember
            Extraction Algorithms", Geoscience and Remote Sensing, IEEE Transactions on,
            vol. 44, no. 11, pgs. 3397-3407, 2006.
        """
        self.check.extract_input(M, q, normalize)
        self.check.mask(mask)
        if normalize:
            M = fmt.normalize(M)
        # Reflect this call only, so a previous normalized run does not
        # leak into the plot labels of a later non-normalized one.
        self.is_normalized = normalize
        h, self.w, numBands = M.shape
        n_pixels = self.w * h
        M = np.reshape(M, (n_pixels, numBands))
        # `is not None`: `!=` against an ndarray is elementwise and raises
        # "truth value is ambiguous" when used as a condition.
        if mask is not None:
            mask1 = np.reshape(mask, (n_pixels, 1))
            self.E, self.idx = eea.mATGP(M, q, mask1)
        else:
            self.E, self.idx = eea.ATGP(M, q)
        # Flat index i in the (h*w, p) matrix maps to pixel (i // w, i % w).
        self.idx3D = [(i // self.w, i % self.w) for i in self.idx]
        return self.E

    def get_idx(self):
        """
        Returns: `list of tuple`
            (row, column) positions in the HSI cube of the induced
            endmembers, or None if extract() has not been called.
        """
        return self.idx3D

    def plot(self, path, info, suffix=None):
        """
        Plot the endmembers.

        Parameters:
            path: `string`
                The path where to put the plot.

            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the file name.
        """
        self.check.plot_input(self.E, suffix, info, 'plot')
        _plot_end_members(path, self.E, info, 'ATGP', self.is_normalized, suffix)

    def display(self, info, suffix=None):
        """
        Display the endmembers to a IPython Notebook.

        Parameters:
            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the title.
        """
        self.check.plot_input(self.E, suffix, info, 'display')
        _display_end_members(self.E, info, 'ATGP', self.is_normalized, suffix)


class FIPPI(object):
    """
    Fast Iterative Pixel Purity Index (FIPPI) endmembers
    induction algorithm.
    """

    def __init__(self):
        self.check = Check('FIPPI')
        self.E = None               # endmembers from the last extract()
        self.w = None               # cube width, used to unflatten indices
        self.idx = None             # flat pixel indices of the endmembers
        self.idx3D = None           # (row, column) pixel positions
        self.is_normalized = False  # whether the last extract() normalized M

    def extract(self, M, q=None, maxit=None, normalize=False):
        """
        Extract the endmembers.

        Parameters:
            M: `numpy array`
                A HSI cube (m x n x p).

            q: `int [default None]`
                Number of endmembers to be induced, if None use
                HfcVd to determine the number of endmembers to induce.

            maxit: `int [default None]`
                Maximum number of iterations. Default = 3*q.

            normalize: `boolean [default False]`
                Normalize M before unmixing.

        Returns: `numpy array`
            Set of induced endmembers (N x p).

        References:
            Chang, C.-I., "A fast iterative algorithm for implementation of pixel purity index",
            Geoscience and Remote Sensing Letters, IEEE, vol. 3, no. 1, pags. 63-67, 2006.
        """
        # q is optional here (None is forwarded to eea.FIPPI), but
        # Check.extract_input insists on an int. When q is None, pass a
        # placeholder so only M and normalize are effectively validated.
        self.check.extract_input(M, 1 if q is None else q, normalize)
        if normalize:
            M = fmt.normalize(M)
        # Reflect this call only, so a previous normalized run does not
        # leak into the plot labels of a later non-normalized one.
        self.is_normalized = normalize
        h, self.w, numBands = M.shape
        M = np.reshape(M, (self.w * h, numBands))
        self.E, self.idx = eea.FIPPI(M, q=q, maxit=maxit)
        # Flat index i in the (h*w, p) matrix maps to pixel (i // w, i % w).
        self.idx3D = [(i // self.w, i % self.w) for i in self.idx]
        return self.E

    def get_idx(self):
        """
        Returns: `list of tuple`
            (row, column) positions in the HSI cube of the induced
            endmembers, or None if extract() has not been called.
        """
        return self.idx3D

    def plot(self, path, info, suffix=None):
        """
        Plot the endmembers.

        Parameters:
            path: `string`
                The path where to put the plot.

            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the file name.
        """
        self.check.plot_input(self.E, suffix, info, 'plot')
        _plot_end_members(path, self.E, info, 'FIPPI', self.is_normalized, suffix)

    def display(self, info, suffix=None):
        """
        Display the endmembers to a IPython Notebook.

        Parameters:
            info: `dictionary`
                * info['wavelength'] : a wavelengths list (1D python list).
                * info['z plot titles'] : the (x axis label, y axis label) pair.
                * info['wavelength units'] : set to 'Unknown' if not specified.

            suffix: `string [default None]`
                Suffix to add to the title.
        """
        self.check.plot_input(self.E, suffix, info, 'display')
        _display_end_members(self.E, info, 'FIPPI', self.is_normalized, suffix)
