#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# clu_int.py - This file is part of the PySptools package.
#

"""
KMeans class
"""



import os.path as osp
import numpy as np
import sklearn.cluster as cluster


class Check(object):
    """
    Validate inputs for the classifiers classes.

    Each instance is bound to the name of the class it validates for
    (`label`); that name is inserted in the error messages.
    """

    # Message templates: {0} is the owner class label, {1} depends on
    # the message (number of dimensions or calling method name).
    err1 = 'in {0}.predict(), M is not a numpy.array'
    err3 = 'in {0}.predict(), M have {1} dimension(s), expected 3 dimensions'
    err10 = 'in {0} class, call predict before calling {1}'
    err11 = 'in {0} class, suffix is not of str type'
    err12 = 'in {0}.classify(), the M spectrum length is different to the E spectrum length'

    def __init__(self, label):
        """
        Parameters:
            label: `string`
              Name of the class on whose behalf the checks are made.
        """
        self.label = label

    def plot_input(self, cluster, suffix, method_name):
        """
        Validate the state and arguments of a plotting method.

        Parameters:
            cluster: `numpy array or None`
              The cluster map kept by the caller; None means that
              predict() was not called yet.

            suffix: `string or None`
              The suffix given to the plotting method.

            method_name: `string`
              Name of the calling method, used in the error message.

        Raises:
            RuntimeError: if `cluster` is None.
            TypeError: if `suffix` is neither None nor a str.
        """
        # Identity test on purpose: `cluster == None` is evaluated
        # element-wise when cluster is a numpy array, and the resulting
        # array cannot be used as an `if` condition.
        if cluster is None:
            raise RuntimeError(self.err10.format(self.label, method_name))
        if suffix is not None and not isinstance(suffix, str):
            raise TypeError(self.err11.format(self.label))

    def predict_input(self, M):
        """
        Validate the data cube given to predict().

        Parameters:
            M: `numpy array`
              The candidate HSI cube.

        Raises:
            RuntimeError: if `M` is not a numpy array or does not have
            exactly 3 dimensions.
        """
        if not isinstance(M, np.ndarray):
            raise RuntimeError(self.err1.format(self.label))
        if M.ndim != 3:
            raise RuntimeError(self.err3.format(self.label, M.ndim))


class KMeans(object):
    """
    KMeans clustering algorithm adapted to hyperspectral imaging.

    `predict` clusters the pixels of a HSI cube with the scikit-learn
    KMeans and keeps the resulting cluster map; `plot` and `display`
    render that map with matplotlib.
    """

    # Color map names accepted by plot() and display(); any other name
    # falls back on 'jet'.
    _COLOR_MAPS = ('Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2',
                   'Set1', 'Set2', 'Set3')

    def __init__(self):
        self.check = Check('KMeans')
        # Cluster map (m x n) produced by the last successful predict().
        self.cluster = None
        # Number of clusters requested by the last successful predict().
        self.n_clusters = None

    def predict(self, M, n_clusters=5, n_jobs=1, init='k-means++'):
        """
        KMeans clustering algorithm adapted to hyperspectral imaging.
        It is a simple wrapper to the scikit-learn version.

        Parameters:
            M: `numpy array`
              A HSI cube (m x n x p).

            n_clusters: `int [default 5]`
                The number of clusters to generate.

            n_jobs: `int [default 1]`
                Taken from scikit-learn doc:
                The number of jobs to use for the computation. This works by breaking down the pairwise matrix into n_jobs even slices and computing them in parallel.
                If -1 all CPUs are used. If 1 is given, no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used.
                The value is ignored when the installed scikit-learn
                KMeans does not accept a `n_jobs` argument.

            init: `string or array [default 'k-means++']`
                Taken from scikit-learn doc: Method for initialization, defaults to `k-means++`:
                `k-means++` : selects initial cluster centers for k-mean clustering in a smart way to speed up convergence. See section Notes in k_init for more details.
                `random`: choose k observations (rows) at random from data for the initial centroids.
                If an ndarray is passed, it should be of shape (n_clusters, n_features) and gives the initial centers.

        Returns: `numpy array`
              A cluster map (m x n); each element is the cluster label
              assigned to the corresponding pixel.

        Raises:
            RuntimeError: if M is not a 3 dimensions numpy array.
        """
        self.check.predict_input(M)
        h, w, numBands = M.shape
        # One row per pixel, one column per band.
        X = np.reshape(M, (h * w, numBands))
        try:
            clf = cluster.KMeans(n_clusters=n_clusters, n_jobs=n_jobs, init=init)
        except TypeError:
            # The `n_jobs` argument was removed from recent scikit-learn
            # releases; retry without it instead of failing.
            clf = cluster.KMeans(n_clusters=n_clusters, init=init)
        labels = clf.fit_predict(X)
        # Update the state only once the clustering succeeded so that
        # n_clusters and cluster always describe the same run.
        self.n_clusters = n_clusters
        self.cluster = np.reshape(labels, (h, w))
        return self.cluster

    def plot(self, path, colorMap='Accent', suffix=None):
        """
        Plot the cluster map and save it as a PNG file.

        The file is named 'kmeans.png', or 'kmeans_<suffix>.png' when a
        suffix is given.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it fall back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the file name.

        Raises:
            RuntimeError: if predict() was not called before.
            TypeError: if suffix is neither None nor a str.
            IOError: if the file cannot be written.
        """
        self.check.plot_input(self.cluster, suffix, 'plot')
        import matplotlib.pyplot as plt
        plt.ioff()
        self._draw_cluster_map(colorMap)

        if suffix is None:
            fout = osp.join(path, 'kmeans.png')
        else:
            fout = osp.join(path, 'kmeans_{0}.png'.format(suffix))
        try:
            plt.savefig(fout)
        except IOError:
            raise IOError('in cluster.KMeans, no such file or directory: {0}'.format(path))
        finally:
            # Always clear the figure, otherwise a failed save would
            # leave the image and colorbar behind for the next plot.
            plt.clf()

    def display(self, colorMap='Accent', suffix=None):
        """
        Display the cluster map.

        Parameters:
            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'],
              "Accent" is the default and it fall back on "Jet".

            suffix: `string [default None]`
              Add a suffix to the title.

        Raises:
            RuntimeError: if predict() was not called before.
            TypeError: if suffix is neither None nor a str.
        """
        self.check.plot_input(self.cluster, suffix, 'display')
        import matplotlib.pyplot as plt
        self._draw_cluster_map(colorMap)

        if suffix is None:
            plt.title('K-Means')
        else:
            plt.title('K-Means - {0}'.format(suffix))
        plt.show()
        plt.clf()

    def _draw_cluster_map(self, colorMap):
        """
        Draw the cluster map and its colorbar on the current figure.

        Shared by plot() and display().

        Parameters:
            colorMap: `string`
              The requested color map name; a name outside the
              supported list is replaced by 'jet'.
        """
        import matplotlib.pyplot as plt
        from matplotlib import colors

        # fallback on jet colormap
        if colorMap not in self._COLOR_MAPS:
            colorMap = 'jet'

        # One bin [k, k+1) per cluster label k.
        bounds = list(range(self.n_clusters + 1))
        color, dummy = self._custom_listed_color_map(colorMap, len(bounds))
        norm = colors.BoundaryNorm(bounds, color.N)
        img = plt.imshow(self.cluster, cmap=color, interpolation=None, norm=norm)
        # The colorbar takes its cmap and norm from the image itself.
        # Ticks sit in the middle of each bin and are labelled 1..n.
        cbar = plt.colorbar(img, boundaries=bounds,
                            ticks=[x + 0.5 for x in range(self.n_clusters)])
        cbar.set_ticklabels(list(range(1, self.n_clusters + 1)))

        # `Artist.axes` replaces the removed `Artist.get_axes()`.
        axes = img.axes
        axes.set_ylabel('class #', rotation=270, labelpad=70)
        axes.yaxis.set_label_position("right")

    def _custom_listed_color_map(self, name, N):
        """
        Build a listed color map of N-1 colors taken from a named
        matplotlib color map.

        Parameters:
            name: `string`
              Name of a registered matplotlib color map.

            N: `int`
              The number of boundaries; the result holds N-1 colors.

        Returns: `tuple`
              (matplotlib ListedColormap, list of (r, g, b) tuples).
        """
        import matplotlib.pyplot as plt
        from matplotlib import colors

        n_colors = N - 1
        base = plt.get_cmap(name)
        if isinstance(base, colors.ListedColormap):
            # Qualitative maps may be stored as a short list of colors.
            # Interpolate between them so that asking for more colors
            # than the list holds does not yield duplicates.
            base = colors.LinearSegmentedColormap.from_list(name, base.colors)
        # Sample the map at n_colors evenly spaced positions, ends included.
        rgba = base(np.linspace(0.0, 1.0, n_colors))
        rgb = [tuple(row) for row in rgba[:, :3].tolist()]
        return colors.ListedColormap(rgb), rgb
