#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# amaps_int.py - This file is part of the PySptools package.
#

"""
UCLS, NNLS and FCLS abundance map classes (unconstrained, non-negative
constrained and fully constrained least squares).
"""


import os.path as osp
import numpy as np

import pysptools.formatting as fmt
import amaps


class AmapsError(object):
    """ Validate inputs for the abundance maps generators """

    err1 = 'in {0}.map(), M is not a numpy.array'
    err2 = 'in {0}.map(), U is not a numpy.array'
    err3 = 'in {0}.map(), M have {1} dimension(s), expected 3 dimensions'
    err4 = 'in {0}.map(), U have {1} dimension(s), expected 2 dimensions'
    err10 = 'in {0} class, call map before calling {1}'
    err11 = 'in {0} class, suffix is not of str type'
    err12 = 'in {0}.map(), the M signal length is different to the U signal length'

    def __init__(self, label):
        # Name inserted into every error message (e.g. 'UCLS').
        self.label = label

    def validate_plot_input(self, suffix, method_name):
        """
        Check that an abundance map is available and that suffix is valid.

        Parameters:
            suffix: `string or None`
              Suffix supplied by the caller of the plotting method.

            method_name: `string`
              Name of the calling method, used in the error message.

        Raises:
            RuntimeError: if no abundance map has been computed (amap is None
              or was never set).
            TypeError: if suffix is neither a str nor None.
        """
        # `amap` is set by subclasses; fall back to None when it is absent.
        # Use `is None`: once amap is a numpy array, `amap == None` compares
        # elementwise and testing that array in an `if` raises ValueError.
        if getattr(self, 'amap', None) is None:
            raise RuntimeError(self.err10.format(self.label, method_name))
        if suffix is not None and not isinstance(suffix, str):
            raise TypeError(self.err11.format(self.label))

    def validate(self, M, E):
        """
        Check the inputs of a map() call.

        Parameters:
            M: `numpy array`
              Must be a 3-D numpy.ndarray (the HSI cube).

            E: `numpy array`
              Must be a 2-D numpy.ndarray whose second dimension equals the
              third dimension of M.

        Raises:
            RuntimeError: if any of the above conditions does not hold.
        """
        if type(M) is not np.ndarray:
            raise RuntimeError(self.err1.format(self.label))
        if type(E) is not np.ndarray:
            raise RuntimeError(self.err2.format(self.label))
        if M.ndim != 3:
            raise RuntimeError(self.err3.format(self.label, M.ndim))
        if E.ndim != 2:
            raise RuntimeError(self.err4.format(self.label, E.ndim))
        # Each endmember row of E must have as many samples as a pixel of M.
        if M.shape[2] != E.shape[1]:
            raise RuntimeError(self.err12.format(self.label))


def _plot_abundance_map(path, amap, map_type, suffix=None, cmap='jet'):
    """
    Save one PNG image per abundance map using matplotlib.

    For each index i along the last axis of amap, writes the file
    'amap_<map_type>__<i+1>.png' (or 'amap_<map_type>__<i+1>_<suffix>.png'
    when suffix is given) into the directory path.

    Parameters:
        path: `string`
          Directory where the images are written.

        amap: `numpy array`
          Abundance maps, read as amap[:, :, i].

        map_type: `string`
          Label used in the file names (e.g. 'UCLS').

        suffix: `string [default None]`
          Optional suffix appended to each file name.

        cmap: `string [default jet]`
          A matplotlib color map.

    Raises:
        IOError: if an image cannot be saved under path.
    """
    import matplotlib.pyplot as plt
    # Interactive mode off: this function only saves figures to disk.
    plt.ioff()
    # range (not xrange) so the module also runs under Python 3.
    for i in range(amap.shape[2]):
        img = plt.imshow(amap[:, :, i])
        img.set_cmap(cmap)
        plt.colorbar()
        if suffix is None:
            fname = 'amap_{0}__{1}.png'.format(map_type, i+1)
        else:
            fname = 'amap_{0}__{1}_{2}.png'.format(map_type, i+1, suffix)
        fout = osp.join(path, fname)
        try:
            plt.savefig(fout)
        except IOError:
            raise IOError('in _plot_abundance_map, no such file or directory: {0}'.format(path))
        finally:
            # Clear the figure even when saving fails, so later plots do not
            # inherit this image and colorbar.
            plt.clf()


def _display_abundance_map(amap, map_type, suffix, cmap='jet'):
    """
    Show each abundance map in turn with matplotlib (e.g. in a notebook).

    Each index i along the last axis of amap is shown with a colorbar and the
    title '<map_type> Inversion - EM<i+1>' (followed by ' - <suffix>' when
    suffix is not None).

    Parameters:
        amap: `numpy array`
          Abundance maps, read as amap[:, :, i].

        map_type: `string`
          Label used in the titles (e.g. 'UCLS').

        suffix: `string or None`
          Optional text appended to each title.

        cmap: `string [default jet]`
          A matplotlib color map.
    """
    import matplotlib.pyplot as plt
    # range (not xrange) so the module also runs under Python 3.
    for i in range(amap.shape[2]):
        img = plt.imshow(amap[:, :, i])
        img.set_cmap(cmap)
        plt.colorbar()
        if suffix is None:
            plt.title('{0} Inversion - EM{1}'.format(map_type, i+1))
        else:
            plt.title('{0} Inversion - EM{1} - {2}'.format(map_type, i+1, suffix))
        plt.show()
        # Reset the figure so the next map starts from a blank canvas.
        plt.clf()


class UCLS(AmapsError):
    """
    Performs unconstrained least squares abundance estimation.
    """

    def __init__(self):
        AmapsError.__init__(self, 'UCLS')
        # Result of the last map() call; None until map() has run.
        self.amap = None

    def map(self, M, U, normalize=False):
        """
        Performs unconstrained least squares abundance estimation on
        the HSI cube M using the signals library U.

        Parameters:
          M: `numpy array`
             A HSI cube (m x n x p).

          U: `numpy array`
             A spectral library of endmembers (q x p).

          normalize: `boolean [default False]`
             If True, M and U are normalized before doing the signals mapping.

        Returns: `numpy array`
              An abundance maps (m x n x q).

        Raises:
            RuntimeError: if M or U fail validation.
        """
        self.validate(M, U)
        h, w, numBands = M.shape
        if normalize:
            M = fmt.normalize(M)
            U = fmt.normalize(U)
        # One pixel spectrum per row: (h*w) x p.
        Mr = np.reshape(M, (w*h, numBands))
        amap2D = amaps.UCLS(Mr, U)
        # Assumes amaps.UCLS returns one row of q abundances per pixel.
        self.amap = np.reshape(amap2D, (h, w, U.shape[0]))
        return self.amap

    def plot(self, path, colorMap='jet', suffix=None):
        """
        Plot the abundance maps.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default jet]`
              A matplotlib color map.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: if map() has not been called.
            TypeError: if suffix is neither a str nor None.
        """
        self.validate_plot_input(suffix, 'plot')
        _plot_abundance_map(path, self.amap, 'UCLS', suffix, colorMap)

    def display(self, colorMap='jet', suffix=None):
        """
        Plot the abundance maps on a IPython Notebook.

        Parameters:
            colorMap: `string [default jet]`
              A matplotlib color map.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: if map() has not been called.
        """
        # Report the missing map() call clearly instead of failing with an
        # AttributeError on None.shape inside the display helper.
        if self.amap is None:
            raise RuntimeError(self.err10.format(self.label, 'display'))
        _display_abundance_map(self.amap, 'UCLS', suffix, cmap=colorMap)


class NNLS(AmapsError):
    """
    NNLS performs non-negative constrained least
    squares with the abundance nonnegative constraint (ANC).
    Utilizes the method of Bro.
    """

    def __init__(self):
        AmapsError.__init__(self, 'NNLS')
        # Result of the last map() call; None until map() has run.
        self.amap = None

    def map(self, M, U, normalize=False):
        """
        NNLS performs non-negative constrained least squares of each pixel
        in M using the endmember signatures of U.

        Parameters:
          M: `numpy array`
             A HSI cube (m x n x p).

          U: `numpy array`
             A spectral library of endmembers (q x p).

          normalize: `boolean [default False]`
             If True, M and U are normalized before doing the signals mapping.

        Returns: `numpy array`
              An abundance maps (m x n x q).

        Raises:
            RuntimeError: if M or U fail validation.
        """
        self.validate(M, U)
        h, w, numBands = M.shape
        if normalize:
            M = fmt.normalize(M)
            U = fmt.normalize(U)
        # One pixel spectrum per row: (h*w) x p.
        Mr = np.reshape(M, (w*h, numBands))
        amap2D = amaps.NNLS(Mr, U)
        # Assumes amaps.NNLS returns one row of q abundances per pixel.
        self.amap = np.reshape(amap2D, (h, w, U.shape[0]))
        return self.amap

    def plot(self, path, colorMap='jet', suffix=None):
        """
        Plot the abundance maps.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default jet]`
              A matplotlib color map.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: if map() has not been called.
            TypeError: if suffix is neither a str nor None.
        """
        self.validate_plot_input(suffix, 'plot')
        _plot_abundance_map(path, self.amap, 'NNLS', suffix, colorMap)

    def display(self, colorMap='jet', suffix=None):
        """
        Plot the abundance maps on a IPython Notebook.

        Parameters:
            colorMap: `string [default jet]`
              A matplotlib color map.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: if map() has not been called.
        """
        # Report the missing map() call clearly instead of failing with an
        # AttributeError on None.shape inside the display helper.
        if self.amap is None:
            raise RuntimeError(self.err10.format(self.label, 'display'))
        _display_abundance_map(self.amap, 'NNLS', suffix, cmap=colorMap)


class FCLS(AmapsError):
    """
    Performs fully constrained least squares. Fully constrained least squares
    is least squares with the abundance sum-to-one constraint (ASC) and the
    abundance nonnegative constraint (ANC).
    """

    def __init__(self):
        AmapsError.__init__(self, 'FCLS')
        # Result of the last map() call; None until map() has run.
        self.amap = None

    def map(self, M, U, normalize=False):
        """
        Performs fully constrained least squares of each pixel in M
        using the endmember signatures of U.

        Parameters:
          M: `numpy array`
             A HSI cube (m x n x p).

          U: `numpy array`
             A spectral library of endmembers (q x p).

          normalize: `boolean [default False]`
             If True, M and U are normalized before doing the signals mapping.

        Returns: `numpy array`
              An abundance maps (m x n x q).

        Raises:
            RuntimeError: if M or U fail validation.
        """
        self.validate(M, U)
        h, w, numBands = M.shape
        if normalize:
            M = fmt.normalize(M)
            U = fmt.normalize(U)
        # One pixel spectrum per row: (h*w) x p.
        Mr = np.reshape(M, (w*h, numBands))
        amap2D = amaps.FCLS(Mr, U)
        # Assumes amaps.FCLS returns one row of q abundances per pixel.
        self.amap = np.reshape(amap2D, (h, w, U.shape[0]))
        return self.amap

    def plot(self, path, colorMap='jet', suffix=None):
        """
        Plot the abundance maps.

        Parameters:
            path: `string`
              The path where to put the plot.

            colorMap: `string [default jet]`
              A matplotlib color map.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: if map() has not been called.
            TypeError: if suffix is neither a str nor None.
        """
        self.validate_plot_input(suffix, 'plot')
        _plot_abundance_map(path, self.amap, 'FCLS', suffix, colorMap)

    def display(self, colorMap='jet', suffix=None):
        """
        Plot the abundance maps on a IPython Notebook.

        Parameters:
            colorMap: `string [default jet]`
              A matplotlib color map.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: if map() has not been called.
        """
        # Report the missing map() call clearly instead of failing with an
        # AttributeError on None.shape inside the display helper.
        if self.amap is None:
            raise RuntimeError(self.err10.format(self.label, 'display'))
        _display_abundance_map(self.amap, 'FCLS', suffix, cmap=colorMap)
