#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# cls_int.py - This file is part of the PySptools package.
#

"""
Spectral Angle Mapper class
Spectral Information Divergence class
Normalized cross correlation class
"""


import os.path as osp
import numpy as np
import pysptools.formatting as fmt
from . import cls


class Check(object):
    """ Validate inputs for the classifiers classes.

    Every method returns None when the input is acceptable and raises
    otherwise. `label` is the classifier name inserted in the messages.
    """

    err1 = 'in {0}.classify(), M is not a numpy.array'
    err2 = 'in {0}.classify(), E is not a numpy.array'
    err3 = 'in {0}.classify(), M have {1} dimension(s), expected 3 dimensions'
    # 1-D (single spectrum) and 2-D (library) E are both accepted below.
    err4 = 'in {0}.classify(), E have {1} dimension(s), expected 1 or 2 dimensions'
    err5 = 'in {0}.classify(), threshold have {1}, expected float or list type'
    err6 = 'in {0}.classify(), threshold have length {1}, expected length {2}'
    err7 = 'in {0}.classify(), threshold value is {1}, expected value between 0.0 and 1.0'
    err8 = 'in {0}.classify(), threshold value is {1} at index {2}, expected value between 0.0 and 1.0'
    err9 = 'in {0}.classify(), threshold indexing at {1} is out of range, expected value between 1 and {2}'
    err10 = 'in {0} class, call classify before calling {1}'
    err11 = 'in {0} class, suffix is not of str type'
    err12 = 'in {0}.classify(), the M spectrum length is different to the E spectrum length'

    def __init__(self, label):
        self.label = label

    def cmap_exist(self, cmap, method_name):
        """
        Ensure a class map was produced before `method_name` is used.

        Raises:
            RuntimeError: if cmap is None.
        """
        # Identity test: `cmap == None` on a numpy array compares element
        # by element and cannot be used as an `if` condition.
        if cmap is None:
            raise RuntimeError(self.err10.format(self.label, method_name))

    def plot_input(self, cmap, suffix, method_name):
        """
        Validate the common inputs of the plot and display methods.

        Raises:
            RuntimeError: if cmap is None.
            TypeError: if suffix is neither None nor a str.
        """
        if cmap is None:
            raise RuntimeError(self.err10.format(self.label, method_name))
        if suffix is not None and not isinstance(suffix, str):
            raise TypeError(self.err11.format(self.label))

    def index(self, idx, em_nbr):
        """
        Validate a 1-based spectrum index; the string 'all' is always valid.

        Raises:
            IndexError: if idx is outside 1..em_nbr.
        """
        if idx == 'all':
            return
        if idx < 1 or idx > em_nbr:
            raise IndexError(self.err9.format(self.label, idx, em_nbr))

    def classify_input(self, M, E, threshold):
        """
        Validate the arguments of classify().

        Parameters:
            M: `numpy array`
              Must have 3 dimensions.

            E: `numpy array`
              Must have 1 dimension (a single spectrum) or 2 dimensions
              (one spectrum per row).

            threshold: `float or list`
              Values must lie between 0.0 and 1.0; a list must hold one
              value per spectrum.

        Raises:
            RuntimeError: wrong array type, wrong number of dimensions or
                mismatched spectrum lengths between M and E.
            TypeError: threshold is neither a float nor a list.
            ValueError: wrong threshold list length or value out of range.
        """
        if not isinstance(M, np.ndarray):
            raise RuntimeError(self.err1.format(self.label))
        if not isinstance(E, np.ndarray):
            raise RuntimeError(self.err2.format(self.label))
        if M.ndim != 3:
            raise RuntimeError(self.err3.format(self.label, M.ndim))
        if E.ndim not in (1, 2):
            raise RuntimeError(self.err4.format(self.label, E.ndim))

        # A 1-D E is one spectrum; its shape[0] is the band count, not the
        # number of spectra, so it must not be used to size the list.
        em_nbr = 1 if E.ndim == 1 else E.shape[0]

        # Exact type tests are kept on purpose: the rest of the module
        # dispatches on `type(threshold) is float` / `is list`.
        if type(threshold) is float:
            if threshold < 0.0 or threshold > 1.0:
                raise ValueError(self.err7.format(self.label, threshold))
        elif type(threshold) is list:
            if len(threshold) != em_nbr:
                raise ValueError(self.err6.format(self.label, len(threshold), em_nbr))
            for i, value in enumerate(threshold):
                if value < 0.0 or value > 1.0:
                    raise ValueError(self.err8.format(self.label, value, i))
        else:
            raise TypeError(self.err5.format(self.label, type(threshold)))

        # The last axis of both M and E is the spectrum (band) axis.
        if M.shape[2] != E.shape[-1]:
            raise RuntimeError(self.err12.format(self.label))


class Output(object):
    """ Add plot and display capacity to the classifiers classes.

    `label` is the classifier name ('SAM', 'SID', 'NormXCorr', ...); it is
    used in file names and titles, and it selects the comparison direction
    and the scale inversion in get_single_map().

    All drawing goes through the pyplot current figure, which is cleared
    after each plot or display.
    """

    # Color maps accepted by the class map methods; any other name falls
    # back on 'jet'.
    _COLOR_MAPS = ('Accent', 'Dark2', 'Paired', 'Pastel1',
                   'Pastel2', 'Set1', 'Set2', 'Set3')

    def __init__(self, label):
        self.label = label

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _signatures(lib_idx, em_nbr):
        """ Expand lib_idx into the list of 1-based spectrum numbers to draw. """
        if lib_idx == 'all':
            return list(range(1, em_nbr + 1))
        return [lib_idx]

    def _file_name(self, stem, suffix, signo=None):
        """ Build '<stem>_<label>[_<signo>][_<suffix>].png'. """
        parts = [stem, '{0}'.format(self.label)]
        if signo is not None:
            parts.append('{0}'.format(signo))
        if suffix is not None:
            parts.append('{0}'.format(suffix))
        return '_'.join(parts) + '.png'

    def _title(self, suffix, signo=None):
        """ Build '<label> Class Map[ - EM<signo>][ - <suffix>]'. """
        title = '{0} Class Map'.format(self.label)
        if signo is not None:
            title += ' - EM{0}'.format(signo)
        if suffix is not None:
            title += ' - {0}'.format(suffix)
        return title

    @staticmethod
    def _save_figure(plt, path, fname):
        """ Save the current figure as path/fname, then clear it. """
        fout = osp.join(path, fname)
        try:
            plt.savefig(fout)
        except IOError:
            raise IOError('in classifiers.output, no such file or directory: {0}'.format(path))
        plt.clf()

    @staticmethod
    def _show_figure(plt, title):
        """ Title, show and clear the current figure. """
        plt.title(title)
        plt.show()
        plt.clf()

    @staticmethod
    def _draw_gradient(plt, image):
        """ Draw a gradient image with the spectral color map and an unticked colorbar. """
        plt.imshow(image)
        # The map historically named 'spectral' is registered as
        # 'nipy_spectral' in newer matplotlib releases; try the new name
        # first and keep the old one as a fallback.
        try:
            plt.set_cmap('nipy_spectral')
        except (ValueError, KeyError):
            plt.set_cmap('spectral')
        cbar = plt.colorbar()
        cbar.set_ticks([])

    def _draw_class_map(self, plt, cmap, sig_nbr, colorMap, composite, ylabel):
        """
        Draw a class map with one discrete color per class and a labelled
        colorbar on the current figure.

        When composite is given, a black entry labelled 'None' is placed in
        front of the composite labels; otherwise classes are labelled
        1..sig_nbr.
        """
        from matplotlib import colors
        name = colorMap if colorMap in self._COLOR_MAPS else 'jet'
        sig_nbr = int(sig_nbr)
        bounds = list(range(sig_nbr + 1))
        if composite is not None:
            color, _ = self._custom_listed_color_map(name, len(bounds) + 1, firstBlack=True)
        else:
            color, _ = self._custom_listed_color_map(name, len(bounds), firstBlack=False)
        norm = colors.BoundaryNorm(bounds, color.N)
        img = plt.imshow(cmap, cmap=color, interpolation=None, norm=norm)
        cbar = plt.colorbar(img, cmap=color, norm=norm, boundaries=bounds,
                            ticks=[x + 0.5 for x in range(sig_nbr)])
        if composite is not None:
            cbar.set_ticklabels(['None'] + list(composite))
        else:
            cbar.set_ticklabels(list(range(1, sig_nbr + 1)))
        # Artist.get_axes() is gone from current matplotlib; the `axes`
        # attribute is the supported accessor.
        axes = img.axes
        axes.set_ylabel('{0} #'.format(ylabel), rotation=270, labelpad=70)
        axes.yaxis.set_label_position("right")

    # ------------------------------------------------------------------
    # single maps
    # ------------------------------------------------------------------

    def plot_single_map(self, path, cmap, dist_map, lib_idx, em_nbr, threshold, constrained, suffix=None):
        """
        Plot individual classified map. One for each spectrum.

        Parameters:
            path: `string`
              The path where to put the plot.

            cmap, dist_map, threshold, constrained:
              Forwarded to get_single_map().

            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            em_nbr: `int`
              Number of spectra, used when lib_idx is 'all'.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        for signo in self._signatures(lib_idx, em_nbr):
            self._plot_single_map1(path, cmap, signo, dist_map, threshold, constrained, suffix=suffix)

    def _plot_single_map1(self, path, cmap, signo, dist_map, threshold, constrained, suffix=None):
        """ Save the map of spectrum `signo` as cmap_<label>_<signo>[_<suffix>].png. """
        import matplotlib.pyplot as plt
        plt.ioff()
        grad = self.get_single_map(signo, cmap, dist_map, threshold, constrained)
        self._draw_gradient(plt, grad)
        self._save_figure(plt, path, self._file_name('cmap', suffix, signo))

    def display_single_map(self, cmap, dist_map, lib_idx, em_nbr, threshold, constrained, suffix=None):
        """
        Display individual classified map. One for each spectrum.
        Same parameters as plot_single_map(), without path; suffix is
        added to the title.
        """
        for signo in self._signatures(lib_idx, em_nbr):
            self._display_single_map1(cmap, signo, dist_map, threshold, constrained, suffix=suffix)

    def _display_single_map1(self, cmap, signo, dist_map, threshold, constrained, suffix=None):
        """ Show the map of spectrum `signo`. """
        import matplotlib.pyplot as plt
        grad = self.get_single_map(signo, cmap, dist_map, threshold, constrained)
        self._draw_gradient(plt, grad)
        self._show_figure(plt, self._title(suffix, signo))

    def get_single_map(self, signo, cmap, dist_map, threshold, constrained, inverse_scale=True):
        """
        Extract the individual map of one spectrum.

        Parameters:
            signo: `int`
              1-based spectrum number; layer signo - 1 of dist_map is used.

            cmap: `numpy array`
              The class map; only read when constrained is true.

            dist_map: `numpy array`
              Indexed as dist_map[:, :, signo - 1].

            threshold: `float or list`
              Fraction of the layer value range; a list is indexed by
              signo - 1. Only read when constrained is false.

            constrained: `boolean`
                * True, keep the layer values where cmap equals signo.
                * False, keep the layer values below the threshold limit
                  (above the limit when label is 'NormXCorr').

            inverse_scale: `boolean [default True]`
              For the 'SAM' and 'SID' labels, replace every non zero value
              v by 1 - v.

        Returns: `numpy array`
            A new array; dist_map is not modified.

        Raises:
            TypeError: constrained is false and threshold is neither a
                float nor a list.
        """
        layer = dist_map[:, :, signo - 1]
        if constrained:
            grad = layer * (cmap == signo)
        else:
            # isinstance also accepts float subclasses such as
            # numpy.float64, which the exact type test left unhandled.
            if isinstance(threshold, float):
                ratio = threshold
            elif isinstance(threshold, list):
                ratio = threshold[signo - 1]
            else:
                raise TypeError('in {0}.get_single_map(), threshold have {1}, expected float or list type'.format(self.label, type(threshold)))
            amin = np.min(layer)
            amax = np.max(layer)
            limit = amin + (amax - amin) * ratio
            if self.label == 'NormXCorr':
                grad = (layer > limit) * layer
            else:
                grad = (layer < limit) * layer
        # inverse the scale for SAM and SID,
        # not needed for NormXCorr
        if inverse_scale and self.label in ('SAM', 'SID'):
            grad = np.where(grad != 0, 1 - grad, grad)
        return grad

    # ------------------------------------------------------------------
    # single spectrum class map
    # ------------------------------------------------------------------

    def plot1(self, path, cmap, suffix):
        """ Save cmap as a gradient image named cmap_<label>[_<suffix>].png. """
        import matplotlib.pyplot as plt
        plt.ioff()
        self._draw_gradient(plt, cmap)
        self._save_figure(plt, path, self._file_name('cmap', suffix))

    def display1(self, cmap, suffix):
        """ Show cmap as a gradient image. """
        import matplotlib.pyplot as plt
        self._draw_gradient(plt, cmap)
        self._show_figure(plt, self._title(suffix))

    # ------------------------------------------------------------------
    # class maps
    # ------------------------------------------------------------------

    def _custom_listed_color_map(self, name, N, firstBlack=False):
        """
        Build a discrete color map of N-1 colors from the color map 'name',
        optionally with the black color added in front.

        Returns: `tuple`
            (matplotlib ListedColormap holding N-1 colors, the full list of
            (r, g, b) tuples including the optional leading black).
        """
        import matplotlib.pyplot as plt
        from matplotlib import colors
        # Sample the registered color map instead of rebuilding it from
        # cm.datad with colors.makeMappingArray: that helper and the
        # red/green/blue segment data of the qualitative maps are not
        # available in current matplotlib releases.
        base = plt.get_cmap(name)
        samples = base(np.linspace(0.0, 1.0, N - 1))
        rgb = [tuple(float(c) for c in row[:3]) for row in samples]
        if firstBlack:
            rgb = [(0, 0, 0)] + rgb  # the black color
        # Truncating here is what ListedColormap(rgb, N=N-1) did.
        return colors.ListedColormap(rgb[:N - 1]), rgb

    def plot_class_map(self, path, cmap, sig_nbr, colorMap, suffix, composite=None):
        """
        Save a class map with a labelled colorbar as cmap_<label>[_<suffix>].png.

        Parameters:
            path: `string`
              The path where to put the plot.

            cmap: `numpy array`
              The class map.

            sig_nbr: `int`
              Number of color classes.

            colorMap: `string`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              any other value fall back on "jet".

            suffix: `string or None`
              Add a suffix to the file name.

            composite: `list [default None]`
              Class labels; when given a black 'None' class is put in front.
        """
        import matplotlib.pyplot as plt
        plt.ioff()
        self._draw_class_map(plt, cmap, sig_nbr, colorMap, composite, 'spectrum')
        self._save_figure(plt, path, self._file_name('cmap', suffix))

    def display_class_map(self, cmap, sig_nbr, colorMap, suffix, composite=None):
        """
        Display a class map with a labelled colorbar. Same parameters as
        plot_class_map(), without path; suffix is added to the title.
        """
        import matplotlib.pyplot as plt
        self._draw_class_map(plt, cmap, sig_nbr, colorMap, composite, 'spectrum')
        self._show_figure(plt, self._title(suffix))

    def plot_histo(self, path, cmap, em_nbr, suffix):
        """
        Save the histogram of the class map values as
        histo_<label>[_<suffix>].png, with one bin per integer in 0..em_nbr.
        """
        import matplotlib.pyplot as plt
        plt.ioff()
        farray = np.ndarray.flatten(cmap)
        plt.hist(farray, bins=list(range(em_nbr + 2)), align='left')
        self._save_figure(plt, path, self._file_name('histo', suffix))

    def plot(self, path, cmap, ylabel='spectrum', colorMap='Accent', suffix=None):
        """
        Plot a classification map using matplotlib.

        Parameters:
            path: `string`
              The path where to put the plot.

            cmap: `numpy array`
                A classified map, (m x n x 1),
                the classes start at 0.

            ylabel: `string [default 'spectrum']`
                y axis label.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it fall back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        import matplotlib.pyplot as plt
        # classes start at zero, so the class count is the highest class + 1
        sig_nbr = int(np.max(cmap)) + 1
        self._draw_class_map(plt, cmap, sig_nbr, colorMap, None, ylabel)
        self._save_figure(plt, path, self._file_name('cmap', suffix))

    def display(self, cmap, ylabel='spectrum', colorMap='Accent', suffix=None):
        """
        Display a classification map using matplotlib.

        Parameters:
            cmap: `numpy array`
                A classified map, (m x n x 1),
                the classes start at 0.

            ylabel: `string [default 'spectrum']`
                y axis label.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it fall back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        import matplotlib.pyplot as plt
        # classes start at zero, so the class count is the highest class + 1
        sig_nbr = int(np.max(cmap)) + 1
        self._draw_class_map(plt, cmap, sig_nbr, colorMap, None, ylabel)
        self._show_figure(plt, self._title(suffix))


def _single_value_min(data, threshold):
    """
    Use a threshold to extract the minimum value along
    the data y axis.
    """
    amin = np.min(data)
    amax = np.max(data)
    limit = amin + (amax - amin) * threshold
    return data < limit

def _single_value_max(data, threshold):
    """
    Use a threshold to extract the minimum value along
    the data y axis.
    """
    amin = np.min(data)
    amax = np.max(data)
    limit = amax - (amax - amin) * threshold
    return data > limit


class SAM(object):
    """Classify a HSI cube using the spectral angle mapper algorithm
    and a spectral library."""

    def __init__(self):
        self.output = Output('SAM')
        self.check = Check('SAM')
        # class map, set by classify()
        self.cmap = None
        # angles array, set by classify()
        self.angles = None
        # spectra number
        self.em_nbr = None
        # a float or a list of float
        self.threshold = None

    def classify(self, M, E, threshold=0.1):
        """
        Classify the HSI cube M with the spectral library E.

        Parameters:
            M: `numpy array`
              A HSI cube (m x n x p).

            E: `numpy array`
              A spectral library (N x p), or a single spectrum (p).

            threshold: `float [default 0.1] or list`
             * If float, threshold is applied on all the spectra.
             * If a list, individual threshold is applied on each
               spectrum, in this case the list must have the same
               number of threshold values than the number of spectra.
             * Threshold have values between 0.0 and 1.0.

        Returns: `numpy array`
              A class map (m x n x 1).
        """
        self.check.classify_input(M, E, threshold)

        single = E.ndim == 1
        self.em_nbr = 1 if single else E.shape[0]
        self.threshold = threshold
        h, w, numBands = M.shape
        M = np.reshape(M, (w*h, numBands))
        Mn = fmt.normalize(M)
        En = fmt.normalize(E)
        if single:
            cmap, angles = self._class_single_pixel(Mn, En, threshold)
        else:
            cmap, angles = cls.SAM_classifier(Mn, En, threshold)

        self.cmap = np.reshape(cmap, (h, w))
        if single:
            self.angles = np.reshape(angles, (h, w))
        else:
            self.angles = np.reshape(angles, (h, w, self.em_nbr))
        return self.cmap

    def _class_single_pixel(self, M, E, threshold):
        """
        Classify the pixels of M (one per row) against the single spectrum E.

        Returns: `tuple`
            (cmap, angles), both of shape (pixels number, 1); cmap holds
            the angle where it is under the threshold limit, 0 elsewhere.
        """
        import pysptools.distance as dst
        # The builtin float replaces the numpy.float alias, which has been
        # removed from NumPy.
        angles = np.ndarray((M.shape[0], 1), dtype=float)
        for i in range(M.shape[0]):
            angles[i] = dst.SAM(M[i], E)
        cmap = _single_value_min(angles, threshold) * angles
        return cmap, angles

    def get_angles_map(self):
        """
        Returns: `numpy array`
            The angles array (m x n x spectra number),
            (m x n) when classify was called with a single spectrum.
        """
        self.check.cmap_exist(self.cmap, 'get_angles_map')
        return self.angles

    def get_angles_stats(self):
        """
        Returns: `dic`
             Angles stats, a (min, max) tuple for each spectrum,
             keyed by the zero based spectrum index.
        """
        self.check.cmap_exist(self.cmap, 'get_angles_stats')
        # In single spectrum mode classify() stores a 2-D angles array,
        # there is no third axis to index.
        if self.angles.ndim == 2:
            return {0: (np.min(self.angles), np.max(self.angles))}
        mm = {}
        for i in range(self.em_nbr):
            layer = self.angles[:, :, i]
            mm[i] = (np.min(layer), np.max(layer))
        return mm

    def get_single_map(self, lib_idx, constrained=True):
        """
        Get individual classified map. See plot_single_map for
        a description.

        Parameters:
            lib_idx: `int`
                A number between 1 and the number of spectra in the library.

            constrained: `boolean [default True]`
                See plot_single_map for a description.

        Returns: `numpy array`
            The individual map (m x n x 1) associated to the lib_idx endmember.
        """
        if self.cmap is None:
            raise RuntimeError(self.check.err10.format(self.check.label, 'get_single_map'))
        self.check.index(lib_idx, self.em_nbr)
        # The second argument is the class map: the constrained map is
        # built from the pixels where the class map equals lib_idx.
        return self.output.get_single_map(lib_idx, self.cmap, self.angles, self.threshold, constrained, inverse_scale=False)

    def plot_single_map(self, path, lib_idx, constrained=True, suffix=None):
        """
        Plot individual classified map. One for each spectrum.
        Note that each individual map is constrained by the others.
        This function is useful to see the individual map that compose
        the final class map returned by the classify method. It help
        to define the spectra library. See the constrained parameter below.

        Parameters:
            path: `string`
              The path where to put the plot.

            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            constrained: `boolean [default True]`
                * If constrained is True, print the individual maps as they compose the
                  final class map. Any potential intersection is removed in favor of
                  the lower value level for SAM and SID, or the nearest to 1 for NormXCorr. Use
                  this one to understand the final class map.
                * If constrained is False, print the individual maps without intersection
                  removed, as they are generated. Use this one to have the real match.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot_single_map')
        self.check.index(lib_idx, self.em_nbr)
        self.output.plot_single_map(path, self.cmap, self.angles, lib_idx, self.em_nbr, self.threshold, constrained, suffix)

    def display_single_map(self, lib_idx, constrained=True, suffix=None):
        """
        Display individual classified map to a IPython Notebook. One for each spectrum.
        Note that each individual map is constrained by the others.
        This function is useful to see the individual map that compose
        the final class map returned by the classify method. It help
        to define the spectra library. See the constrained parameter below.

        Parameters:
            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            constrained: `boolean [default True]`
                * If constrained is True, print the individual maps as they compose the
                  final class map. Any potential intersection is removed in favor of
                  the lower value level for SAM and SID, or the nearest to 1 for NormXCorr. Use
                  this one to understand the final class map.
                * If constrained is False, print the individual maps without intersection
                  removed, as they are generated. Use this one to have the real match.

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        self.check.plot_input(self.cmap, suffix, 'display_single_map')
        self.check.index(lib_idx, self.em_nbr)
        self.output.display_single_map(self.cmap, self.angles, lib_idx, self.em_nbr, self.threshold, constrained, suffix)

    def plot(self, path, colorMap='Accent', suffix=None):
        """
        Plot the class map.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it fall back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot')
        if self.em_nbr == 1:
            self.output.plot1(path, self.cmap, suffix=suffix)
        elif self.threshold is not None:
            # one extra class in front of the spectra, labelled 'None'
            # by the output
            sigSet = [x+1 for x in range(self.em_nbr)]
            self.output.plot_class_map(path, self.cmap, self.em_nbr+1, colorMap, suffix, composite=sigSet)
        else:
            self.output.plot_class_map(path, self.cmap, self.em_nbr, colorMap, suffix, composite=None)

    def display(self, colorMap='Accent', suffix=None):
        """
        Display the class map to a IPython Notebook.

        Parameters:
            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it fall back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        self.check.plot_input(self.cmap, suffix, 'display')
        if self.em_nbr == 1:
            self.output.display1(self.cmap, suffix=suffix)
        elif self.threshold is not None:
            # one extra class in front of the spectra, labelled 'None'
            # by the output
            sigSet = [x+1 for x in range(self.em_nbr)]
            self.output.display_class_map(self.cmap, self.em_nbr+1, colorMap, suffix, composite=sigSet)
        else:
            self.output.display_class_map(self.cmap, self.em_nbr, colorMap, suffix, composite=None)

    def plot_histo(self, path, suffix=None):
        """
        Plot the histogram.

        Parameters:
            path: `string`
              The path where to put the plot.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot_histo')
        self.output.plot_histo(path, self.cmap, self.em_nbr, suffix)


class SID(object):
    """Classify a HSI cube using the spectral information divergence
    algorithm and a spectral library."""

    def __init__(self):
        self.output = Output('SID')
        self.check = Check('SID')
        # class map produced by the last call to classify()
        self.cmap = None
        # SID values produced by the last call to classify()
        self.sid = None
        # endmembers number
        self.em_nbr = None
        self.threshold = None

    def classify(self, M, E, threshold=0.1):
        """
        Classify the HSI cube M with the spectral library E.

        Parameters:
            M: `numpy array`
              A HSI cube (m x n x p).

            E: `numpy array`
              A spectral library (N x p).

            threshold: `float [default 0.1] or list`
             * If float, threshold is applied on all the spectra.
             * If a list, individual threshold is applied on each
               spectrum, in this case the list must have the same
               number of threshold values than the number of spectra.
             * Threshold have values between 0.0 and 1.0.

        Returns: `numpy array`
              A class map (m x n).
        """
        self.check.classify_input(M, E, threshold)

        # A one dimensional E is a single spectrum rather than a library.
        single_spectrum = E.ndim == 1
        self.em_nbr = 1 if single_spectrum else E.shape[0]
        self.threshold = threshold

        h, w, numBands = M.shape
        # Flatten the cube to a list of pixels for the classifiers.
        M = np.reshape(M, (w*h, numBands))
        Mn = fmt.normalize(M)
        En = fmt.normalize(E)
        if single_spectrum:
            cmap, sid = self._class_single_pixel(Mn, En, threshold)
            sid_shape = (h, w)
        else:
            cmap, sid = cls.SID_classifier(Mn, En, threshold)
            sid_shape = (h, w, self.em_nbr)

        # Back to the spatial layout of the input cube.
        self.cmap = np.reshape(cmap, (h, w))
        self.sid = np.reshape(sid, sid_shape)
        return self.cmap

    def _class_single_pixel(self, M, E, threshold):
        """
        Compute the SID of every pixel of M (pixels x bands) against the
        single spectrum E.

        Returns: `tuple`
            (cmap, sid), sid is a (pixels x 1) array of SID values and
            cmap is sid multiplied by the result of _single_value_min.
        """
        import pysptools.distance as dst
        # The builtin float is used: the np.float alias was removed
        # from NumPy.
        sid = np.empty((M.shape[0], 1), dtype=float)
        for i in range(M.shape[0]):
            sid[i] = dst.SID(M[i], E)
        cmap = _single_value_min(sid, threshold) * sid
        return cmap, sid

    def get_SID_map(self):
        """
        Returns: `numpy array`
            The SID array (m x n x spectra number).
        """
        # Report this method's own name if classify() was not called first.
        self.check.cmap_exist(self.cmap, 'get_SID_map')
        return self.sid

    def get_single_map(self, lib_idx, constrained=True):
        """
        Get individual classified map. See plot_single_map for
        a description.

        Parameters:
            lib_idx: `int or string`
                A number between 1 and the number of spectra in the library.

            constrained: `boolean [default True]`
                See plot_single_map for a description.

        Returns: `numpy array`
            The individual map (m x n x 1) associated to the lib_idx endmember.
        """
        self.check.index(lib_idx, self.em_nbr)
        return self.output.get_single_map(lib_idx, self.em_nbr, self.sid, self.threshold, constrained, inverse_scale=False)

    def plot_single_map(self, path, lib_idx, constrained=True, suffix=None):
        """
        Plot individual classified map. One for each spectrum.
        Note that each individual map is constrained by the others.
        This function is useful to see the individual map that compose
        the final class map returned by the classify method. It helps
        to define the spectra library. See the constrained parameter below.

        Parameters:
            path: `string`
              The path where to put the plot.

            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            constrained: `boolean [default True]`
                * If constrained is True, print the individual maps as they compose the
                  final class map. Any potential intersection is removed in favor of
                  the lower value level for SAM and SID, or the nearest to 1 for NormXCorr. Use
                  this one to understand the final class map.
                * If constrained is False, print the individual maps without intersection
                  removed, as they are generated. Use this one to have the real match.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot_single_map')
        self.check.index(lib_idx, self.em_nbr)
        self.output.plot_single_map(path, self.cmap, self.sid, lib_idx, self.em_nbr, self.threshold, constrained, suffix)

    def display_single_map(self, lib_idx, constrained=True, suffix=None):
        """
        Display individual classified map to a IPython Notebook. One for each spectrum.
        Note that each individual map is constrained by the others.
        This function is useful to see the individual map that compose
        the final class map returned by the classify method. It helps
        to define the spectra library. See the constrained parameter below.

        Parameters:
            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            constrained: `boolean [default True]`
                * If constrained is True, print the individual maps as they compose the
                  final class map. Any potential intersection is removed in favor of
                  the lower value level for SAM and SID, or the nearest to 1 for NormXCorr. Use
                  this one to understand the final class map.
                * If constrained is False, print the individual maps without intersection
                  removed, as they are generated. Use this one to have the real match.

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        self.check.plot_input(self.cmap, suffix, 'display_single_map')
        self.check.index(lib_idx, self.em_nbr)
        self.output.display_single_map(self.cmap, self.sid, lib_idx, self.em_nbr, self.threshold, constrained, suffix)

    def plot(self, path, colorMap='Accent', suffix=None):
        """
        Plot the class map.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it falls back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot')
        if self.em_nbr == 1:
            # A single spectrum has its own dedicated plot routine.
            self.output.plot1(path, self.cmap, suffix=suffix)
        elif self.threshold is not None:
            # One label per library spectrum, numbered from 1; one extra
            # class is requested (presumably for the pixels rejected by
            # the threshold -- confirm against Output.plot_class_map).
            sigSet = list(range(1, self.em_nbr + 1))
            self.output.plot_class_map(path, self.cmap, self.em_nbr+1, colorMap, suffix, composite=sigSet)
        else:
            self.output.plot_class_map(path, self.cmap, self.em_nbr, colorMap, suffix, composite=None)

    def display(self, colorMap='Accent', suffix=None):
        """
        Display the class map to a IPython Notebook.

        Parameters:
            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it falls back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        self.check.plot_input(self.cmap, suffix, 'display')
        if self.em_nbr == 1:
            self.output.display1(self.cmap, suffix=suffix)
        elif self.threshold is not None:
            # Same labelling scheme as plot().
            sigSet = list(range(1, self.em_nbr + 1))
            self.output.display_class_map(self.cmap, self.em_nbr+1, colorMap, suffix, composite=sigSet)
        else:
            self.output.display_class_map(self.cmap, self.em_nbr, colorMap, suffix, composite=None)

    def plot_histo(self, path, suffix=None):
        """
        Plot the histogram.

        Parameters:
            path: `string`
              The path where to put the plot.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot_histo')
        self.output.plot_histo(path, self.cmap, self.em_nbr, suffix)


class NormXCorr(object):
    """Classify a HSI cube using the normalized cross correlation
    algorithm and a spectral library."""

    def __init__(self):
        self.output = Output('NormXCorr')
        self.check = Check('NormXCorr')
        # class map produced by the last call to classify()
        self.cmap = None
        # correlation values produced by the last call to classify()
        self.corr = None
        # endmembers number
        self.em_nbr = None
        self.threshold = None

    def classify(self, M, E, threshold=0.01):
        """
        Classify the HSI cube M with the spectral library E.

        Parameters:
            M: `numpy array`
              A HSI cube (m x n x p).

            E: `numpy array`
              A spectral library (N x p).

            threshold: `float [default 0.01] or list`
             * If float, threshold is applied on all the spectra.
             * If a list, individual threshold is applied on each
               spectrum, in this case the list must have the same
               number of threshold values than the number of spectra.
             * Threshold have values between 0.0 and 1.0.

        Returns: `numpy array`
              A class map (m x n).
        """
        self.check.classify_input(M, E, threshold)

        # A one dimensional E is a single spectrum rather than a library.
        single_spectrum = E.ndim == 1
        self.em_nbr = 1 if single_spectrum else E.shape[0]
        # The stored threshold is the complement (1 - threshold) of the
        # user value; the classifiers below still receive the user value.
        # isinstance also accepts float and list subclasses.
        if isinstance(threshold, float):
            self.threshold = 1 - threshold
        elif isinstance(threshold, list):
            self.threshold = [1 - x for x in threshold]

        h, w, numBands = M.shape
        # Flatten the cube to a list of pixels for the classifiers.
        M = np.reshape(M, (w*h, numBands))
        Mn = fmt.normalize(M)
        En = fmt.normalize(E)
        if single_spectrum:
            cmap, corr = self._class_single_pixel(Mn, En, threshold)
            corr_shape = (h, w)
        else:
            cmap, corr = cls.NormXCorr_classifier(Mn, En, threshold)
            corr_shape = (h, w, self.em_nbr)

        # Back to the spatial layout of the input cube.
        self.cmap = np.reshape(cmap, (h, w))
        self.corr = np.reshape(corr, corr_shape)
        return self.cmap

    def _class_single_pixel(self, M, E, threshold):
        """
        Compute the normalized cross correlation of every pixel of M
        (pixels x bands) against the single spectrum E.

        Returns: `tuple`
            (cmap, corr), corr is a (pixels x 1) array of correlation
            values and cmap is corr multiplied by the result of
            _single_value_max.
        """
        import pysptools.distance as dst
        # The builtin float is used: the np.float alias was removed
        # from NumPy.
        corr = np.empty((M.shape[0], 1), dtype=float)
        for i in range(M.shape[0]):
            corr[i] = dst.NormXCorr(M[i], E)
        cmap = _single_value_max(corr, threshold) * corr
        return cmap, corr

    def get_NormXCorr_map(self):
        """
        Returns: `numpy array`
            The NormXCorr array (m x n x spectra number).
        """
        # Report this method's own name if classify() was not called first.
        self.check.cmap_exist(self.cmap, 'get_NormXCorr_map')
        return self.corr

    def get_single_map(self, lib_idx, constrained=True):
        """
        Get individual classified map. See plot_single_map for
        a description.

        Parameters:
            lib_idx: `int or string`
                A number between 1 and the number of spectra in the library.

            constrained: `boolean [default True]`
                See plot_single_map for a description.

        Returns: `numpy array`
            The individual map (m x n x 1) associated to the lib_idx endmember.
        """
        self.check.index(lib_idx, self.em_nbr)
        return self.output.get_single_map(lib_idx, self.em_nbr, self.corr, self.threshold, constrained, inverse_scale=False)

    def plot_single_map(self, path, lib_idx, constrained=True, suffix=None):
        """
        Plot individual classified map. One for each spectrum.
        Note that each individual map is constrained by the others.
        This function is useful to see the individual map that compose
        the final class map returned by the classify method. It helps
        to define the spectra library. See the constrained parameter below.

        Parameters:
            path: `string`
              The path where to put the plot.

            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            constrained: `boolean [default True]`
                * If constrained is True, print the individual maps as they compose the
                  final class map. Any potential intersection is removed in favor of
                  the lower value level for SAM and SID, or the nearest to 1 for NormXCorr. Use
                  this one to understand the final class map.
                * If constrained is False, print the individual maps without intersection
                  removed, as they are generated. Use this one to have the real match.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot_single_map')
        self.check.index(lib_idx, self.em_nbr)
        self.output.plot_single_map(path, self.cmap, self.corr, lib_idx, self.em_nbr, self.threshold, constrained, suffix)

    def display_single_map(self, lib_idx, constrained=True, suffix=None):
        """
        Display individual classified map to a IPython Notebook. One for each spectrum.
        Note that each individual map is constrained by the others.
        This function is useful to see the individual map that compose
        the final class map returned by the classify method. It helps
        to define the spectra library. See the constrained parameter below.

        Parameters:
            lib_idx: `int or string`
                * A number between 1 and the number of spectra in the library.
                * 'all', plot all the individual maps.

            constrained: `boolean [default True]`
                * If constrained is True, print the individual maps as they compose the
                  final class map. Any potential intersection is removed in favor of
                  the lower value level for SAM and SID, or the nearest to 1 for NormXCorr. Use
                  this one to understand the final class map.
                * If constrained is False, print the individual maps without intersection
                  removed, as they are generated. Use this one to have the real match.

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        self.check.plot_input(self.cmap, suffix, 'display_single_map')
        self.check.index(lib_idx, self.em_nbr)
        self.output.display_single_map(self.cmap, self.corr, lib_idx, self.em_nbr, self.threshold, constrained, suffix)

    def plot(self, path, colorMap='Accent', suffix=None):
        """
        Plot the class map.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it falls back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot')
        if self.em_nbr == 1:
            # A single spectrum has its own dedicated plot routine.
            self.output.plot1(path, self.cmap, suffix=suffix)
        elif self.threshold is not None:
            # One label per library spectrum, numbered from 1; one extra
            # class is requested (presumably for the pixels rejected by
            # the threshold -- confirm against Output.plot_class_map).
            sigSet = list(range(1, self.em_nbr + 1))
            self.output.plot_class_map(path, self.cmap, self.em_nbr+1, colorMap, suffix, composite=sigSet)
        else:
            self.output.plot_class_map(path, self.cmap, self.em_nbr, colorMap, suffix, composite=None)

    def display(self, colorMap='Accent', suffix=None):
        """
        Display the class map to a IPython Notebook.

        Parameters:
            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it falls back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the title.
        """
        self.check.plot_input(self.cmap, suffix, 'display')
        if self.em_nbr == 1:
            self.output.display1(self.cmap, suffix=suffix)
        elif self.threshold is not None:
            # Same labelling scheme as plot().
            sigSet = list(range(1, self.em_nbr + 1))
            self.output.display_class_map(self.cmap, self.em_nbr+1, colorMap, suffix, composite=sigSet)
        else:
            self.output.display_class_map(self.cmap, self.em_nbr, colorMap, suffix, composite=None)

    def plot_histo(self, path, suffix=None):
        """
        Plot the histogram.

        Parameters:
            path: `string`
              The path where to put the plot.

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self.check.plot_input(self.cmap, suffix, 'plot_histo')
        self.output.plot_histo(path, self.cmap, self.em_nbr, suffix)
