#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2015, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# svc.py - This file is part of the PySptools package.
#

from __future__ import print_function

import os.path as osp
import numpy as np
from sklearn import svm
from sklearn import preprocessing
import matplotlib.pyplot as plt


class SVC(object):
    """
    Support Vector Supervised Classification (SVC) of a HSI cube with the
    use of regions of interest (ROIs).

    This class is largely a wrapper to the scikit-learn SVC class. The goal is
    to ease the use of the scikit-learn SVM implementation when applied
    to hyperspectral cubes.

    The ROIs classifiers can be rectangles or polygons. They must be VALID, no check is made
    upon the validity of these geometric figures.
    """

    # Colormap names accepted by the plot/display methods; any other name
    # falls back on 'jet'.
    _COLOR_MAPS = ('Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2',
                   'Set1', 'Set2', 'Set3')

    def __init__(self):
        self.clf = None         # fitted sklearn.svm.SVC, set by fit()
        self.cmap = None        # (m x n) class map, set by classify()
        self.mask = None        # (m x n) int ROI mask, set by fit()
        self.n_clusters = None  # number of ROI classes, set by fit()

    def fit(self, M, ROIs, class_weight=None, cache_size=200, coef0=0.0, degree=3,
            gamma=0.0, kernel='rbf', max_iter=-1, probability=False, random_state=None,
            shrinking=True, tol=0.001, verbose=False):
        """
        Fit the HS cube M with the use of ROIs. The parameters following 'M' and 'ROIs' are the
        ones defined by the scikit-learn sklearn.svm.SVC class and are forwarded to it.

        Parameters:
            M: `numpy array`
              A HSI cube (m x n x p).

            ROIs: `ROIs type`
                Regions of interest instance.

            Others parameters: `see the sklearn.svm.SVC class parameters`
                Note: the C parameter is set to 1, the result of this setting is that
                the class_weight is relative to C and that the first value of
                class_weight is the background.
                An example: you wish to fit two classes "1" and "2" with the help
                of one ROI for each, you declare class_weight like this:
                class_weight={0:1,1:10,2:10}
                0: is always the background and is set to 1, 1: is the first class,
                2: is the second. A value of 10 for both classes give good results.
                gamma=0.0 (the default) is forwarded as 'auto' (1 / n_features),
                the meaning older scikit-learn releases gave to gamma=0.0.

        Returns: `class`
            The fitted sklearn.svm.SVC instance.
        """
        self.n_clusters = ROIs.n_clusters()
        # mask: 0 marks background pixels, 1 the first ROI set, 2 the next, ...
        self.mask = np.zeros((M.shape[0], M.shape[1]), dtype=int)
        for label, (_, rois) in enumerate(ROIs.get_next(), start=1):
            self._post_to_mask(rois, label)

        # Class 0: every pixel not covered by any ROI.
        background = self._get_X(M)
        X_parts = [background]
        y_parts = [np.zeros(background.shape[0], dtype=int)]
        for label, (_, rois) in enumerate(ROIs.get_next(), start=1):
            for r in rois:
                X_roi = self._crop(M, label, r)
                X_parts.append(X_roi)
                y_parts.append(np.full(X_roi.shape[0], label, dtype=int))
        # Concatenate once instead of growing the arrays shape by shape.
        X = np.concatenate(X_parts)
        y = np.concatenate(y_parts)

        if gamma == 0.0:
            # Newer scikit-learn would pass 0.0 verbatim to libsvm (a constant
            # RBF kernel); 'auto' restores the intended 1 / n_features.
            gamma = 'auto'

        X_scaled = preprocessing.scale(X)
        self.clf = svm.SVC(C=1.0, class_weight=class_weight, cache_size=cache_size,
                           coef0=coef0, degree=degree, gamma=gamma, kernel=kernel,
                           max_iter=max_iter, probability=probability,
                           random_state=random_state, shrinking=shrinking,
                           tol=tol, verbose=verbose)
        self.clf.fit(X_scaled, y)
        return self.clf

    def classify(self, M):
        """
        Classify a hyperspectral cube using the ROIs defined clusters.

        Parameters:
            M: `numpy array`
              A HSI cube (m x n x p).

        Returns: `numpy array`
              A class map (m x n); 0 is the background, 1.. the ROI classes.

        Raises:
            RuntimeError: if fit() has not been called first.
        """
        if self.clf is None:
            raise RuntimeError('in classification.SVC, call fit() before classify()')
        img = self._convert2D(M)
        # NOTE(review): M is standardized with its own statistics rather than
        # those of the fit() training data; confirm this per-cube scaling is
        # intended (e.g. to classify a cube other than the fitted one).
        image_scaled = preprocessing.scale(img)
        cls = self.clf.predict(image_scaled)
        self.cmap = self._convert3d(cls, M.shape[0], M.shape[1])
        return self.cmap

    def get_class_map(self):
        """Return the class map computed by the last classify() call (None before)."""
        return self.cmap

    def _crop(self, M, roi_id, r):
        """
        Return the training pixels (k x p) of one ROI shape.

        A 'rec' entry (x1, y1, x2, y2) selects M[x1:x2, y1:y2, :] as-is. A
        'poly' entry selects, as floats, the pixels inside that polygon which
        still carry roi_id in self.mask (a later ROI may have overwritten some).

        Raises:
            ValueError: if r has neither a 'rec' nor a 'poly' key.
        """
        if 'rec' in r:
            bbox = r['rec']
            return self._convert2D(M[bbox[0]:bbox[2], bbox[1]:bbox[3], :])
        if 'poly' in r:
            # Restrict to this polygon: taking every mask == roi_id pixel would
            # add the same pixels again for each shape of a multi-shape ROI.
            selected = self._poly_mask(r['poly']) & (self.mask == roi_id)
            return M[selected].astype(float)
        raise ValueError('in classification.SVC, a ROI needs a "rec" or a "poly" key: {0}'.format(r))

    def _convert2D(self, M):
        """Flatten an (h x w x bands) cube into an (h*w x bands) pixel matrix."""
        h, w, numBands = M.shape
        return np.reshape(M, (w*h, numBands))

    def _convert3d(self, M, h, w):
        """Reshape a flat vector of h*w per-pixel values into an (h x w) map."""
        return np.reshape(M, (h, w))

    def _get_X(self, M):
        """
        Return, as floats, the pixels of M where self.mask is 0 (background),
        in row-major order (k x p).
        """
        return M[self.mask == 0].astype(float)

    def _poly_mask(self, vertices):
        """
        Return a boolean array shaped like self.mask, True for every pixel
        (row, col) lying inside the polygon defined by `vertices`.
        """
        import matplotlib.patches as patches
        path = patches.Polygon(vertices, closed=True).get_path()
        rows, cols = np.indices(self.mask.shape)
        points = np.column_stack((rows.ravel(), cols.ravel()))
        return path.contains_points(points).reshape(self.mask.shape)

    def _post_to_mask(self, rois, roi_id):
        """
        Write roi_id into self.mask for every pixel covered by the shapes of `rois`.

        A 'rec' (x1, y1, x2, y2) covers rows x1 <= x < x2 and columns y1 <= y < y2;
        a 'poly' covers the pixels whose (row, col) lies inside the polygon.
        Later calls overwrite earlier labels.
        """
        for r in rois:
            if 'rec' in r:
                x1, y1, x2, y2 = r['rec']
                # Clip at 0 so negative bounds clamp instead of wrapping around
                # like Python slices do.
                self.mask[max(x1, 0):max(x2, 0), max(y1, 0):max(y2, 0)] = roi_id
            if 'poly' in r:
                self.mask[self._poly_mask(r['poly'])] = roi_id

    def _custom_listed_color_map(self, name, N, firstBlack=False):
        """
        Build a ListedColormap from the matplotlib colormap `name`.

        The map is sampled at N-1 evenly spaced points; with firstBlack=True a
        black entry is put in front. N-1 is given to ListedColormap as its size.
        Returns (ListedColormap, list of RGB tuples).
        """
        import matplotlib.cm as cm
        from matplotlib import colors
        mp = cm.datad[name]
        if isinstance(mp, dict) and 'listed' in mp:
            # Some matplotlib versions store qualitative maps as a plain colour
            # list; interpolate linearly between the listed colours.
            listed = np.asarray(mp['listed'], dtype=float)[:, :3]
            src = np.linspace(0.0, 1.0, listed.shape[0])
            dst = np.linspace(0.0, 1.0, N - 1)
            channels = [np.interp(dst, src, listed[:, k]) for k in range(3)]
        else:
            channels = [colors.makeMappingArray(N - 1, mp[c])
                        for c in ('red', 'green', 'blue')]
        new_mp2 = list(zip(*channels))
        if firstBlack:
            new_mp2 = [(0, 0, 0)] + new_mp2  # the black color
        return colors.ListedColormap(new_mp2, N=N-1), new_mp2

    def _out(self, path, labels, fname, colorMap, img, suffix):
        """
        Draw `img` as a class map with a labelled colorbar, then save it as
        '<path>/<fname>[_<suffix>].png' when path is given, or show it otherwise.
        """
        import matplotlib.pyplot as plt
        from matplotlib import colors
        if path is not None:
            plt.ioff()
        if colorMap not in self._COLOR_MAPS:
            colorMap = 'jet'

        bounds = list(range(self.n_clusters + 2))
        color, _ = self._custom_listed_color_map(colorMap, len(bounds) + 1, firstBlack=True)
        norm = colors.BoundaryNorm(bounds, color.N)
        image = plt.imshow(img, cmap=color, interpolation=None, norm=norm)
        cbar = plt.colorbar(image, cmap=color, norm=norm, boundaries=bounds,
                            ticks=[x + 0.5 for x in range(self.n_clusters + 1)])

        # First tick is the background; the others are the labels or class numbers.
        lbls = ['None']
        if labels is None:
            lbls.extend(range(1, self.n_clusters + 1))
        else:
            lbls.extend(labels)
        cbar.set_ticklabels(lbls)
        # `.axes` replaces the get_axes() accessor dropped by newer matplotlib.
        ax = image.axes
        if labels is None:
            ax.set_ylabel('class #', rotation=270, labelpad=70)
        ax.yaxis.set_label_position("right")

        if path is not None:
            if suffix is None:
                fout = osp.join(path, '{0}.png'.format(fname))
            else:
                fout = osp.join(path, '{0}_{1}.png'.format(fname, suffix))
            try:
                plt.savefig(fout)
            except IOError:
                raise IOError('in classification.SVC, no such file or directory: {0}'.format(path))
        else:
            if suffix is None:
                plt.title('{0}'.format(fname))
            else:
                plt.title('{0} - {1}'.format(fname, suffix))
            plt.show()
        plt.clf()

    def plot(self, path, labels=None, colorMap='Accent', suffix=None):
        """
        Plot the class map to the file '<path>/SVC[_<suffix>].png'.

        Parameters:
            path: `string`
              The path where to put the plot.

            labels: `string list`
              A labels list.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'];
              any other name falls back on "jet".

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self._out(path, labels, 'SVC', colorMap, self.cmap, suffix)

    def plot_ROIs(self, path, labels=None, colorMap='Accent', suffix=None):
        """
        Plot the ROIs mask to the file '<path>/ROIs[_<suffix>].png'.

        Parameters:
            path: `string`
              The path where to put the plot.

            labels: `string list`
              A labels list.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'];
              any other name falls back on "jet".

            suffix: `string [default None]`
              Add a suffix to the file name.
        """
        self._out(path, labels, 'ROIs', colorMap, self.mask, suffix)

    def display(self, labels=None, colorMap='Accent', suffix=None):
        """
        Display the class map on screen.

        Parameters:
            labels: `string list`
              A labels list.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'];
              any other name falls back on "jet".

            suffix: `string [default None]`
              Add a suffix to the plot title.
        """
        self._out(None, labels, 'SVC', colorMap, self.cmap, suffix)

    def display_ROIs(self, labels=None, colorMap='Accent', suffix=None):
        """
        Display the ROIs mask on screen.

        Parameters:
            labels: `string list`
              A labels list.

            colorMap: `string [default 'Accent']`
              A color map element of
              ['Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'];
              any other name falls back on "jet".

            suffix: `string [default None]`
              Add a suffix to the plot title.
        """
        self._out(None, labels, 'Mask', colorMap, self.mask, suffix)


class ROIs(object):
    """
    Container for regions of interest (ROIs), grouped by class (cluster) name.
    """

    def __init__(self):
        self._rois = []        # (class name, tuple of ROI dicts) pairs, in insertion order
        self._n_clusters = 0   # incremented once per add() call

    def add(self, id, *rois):
        """
        Register one class together with the shapes that sample it.

        Parameters:
            id: `string`
              The class (or cluster) name.

            *rois: `dictionary list`
              One dictionary per shape, either a rectangle or a polygon.
              Rectangle: {'rec': (upper_left_x, upper_left_y, lower_right_x, lower_right_y)}
              Polygon:   {'poly': ((x1,y1),(x2,y2), ...)}, it does not need to be closed.
              A class may combine any number of rectangles and polygons.
              The shapes are not validated here; they must be VALID.
        """
        entry = (id, rois)
        self._rois.append(entry)
        self._n_clusters = self._n_clusters + 1

    def n_clusters(self):
        """Return how many classes have been registered with add()."""
        return self._n_clusters

    def get_next(self):
        """
        Generator over the registered classes, in insertion order.

        Return: `tuple`
            Cluster name, ROI list (the shapes given to add()).
        """
        for name, shapes in self._rois:
            yield name, shapes

    def get_labels(self):
        """
        Return the class names, in insertion order.

        Return: `list`
            A labels list.
        """
        return [name for name, _ in self._rois]
