#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# detect_int.py - This file is part of the PySptools package.
#


"""
MatchedFilter, ACE, CEM, GLRT, OSP classes
"""


import numpy as np
from . import detect


class Check(object):
    """ Validate inputs for the detection classes.

    Each detector owns one ``Check`` instance built with the detector's
    name (``label``), which is interpolated into every error message so the
    user can tell which detector rejected the input.
    """

    err1 = 'in {0}.detect(), M is not a numpy.array'
    err2 = 'in {0}.detect(), t is not a numpy.array'
    err3 = 'in {0}.detect(), M have {1} dimension(s), expected 3 dimensions'
    err4 = 'in {0}.detect(), t have {1} dimension(s), expected 1 dimensions'
    err5 = 'in {0}.detect(), threshold have {1}, expected float'
    err7 = 'in {0}.detect(), threshold value is {1}, expected value between 0.0 and 1.0'
    err10 = 'in {0} class, call detect before calling {1}'
    err11 = 'in {0} class, suffix is not of str type'
    err12 = 'in {0}.detect(), the M signal length is different to the t signal length'

    def __init__(self, label):
        # Name of the owning detector, used in every error message.
        self.label = label

    def plot_input(self, target_map, suffix, method_name):
        """
        Validate the state and arguments of a plot()/display() call.

        Parameters:
            target_map: `numpy array` or None
              The map stored by a previous detect() call, or None when
              detect() has not been called yet.

            suffix: `string` or None
              Optional suffix for the output file name or title.

            method_name: `string`
              Name of the calling method, used in the error message.

        Raises:
            RuntimeError: target_map is None (detect() was not called).
            TypeError: suffix is neither a str nor None.
        """
        # Must be an identity test: once detect() has run, target_map is a
        # numpy array and `array == None` compares element-wise, which makes
        # the `if` raise "truth value of an array is ambiguous".
        if target_map is None:
            raise RuntimeError(self.err10.format(self.label, method_name))
        if suffix is not None and not isinstance(suffix, str):
            raise TypeError(self.err11.format(self.label))

    def detect_input(self, M, t, threshold):
        """
        Validate the arguments of a detect() call.

        Parameters:
            M: `numpy array`
              The HSI cube; must have 3 dimensions.

            t: `numpy array`
              The target spectrum; must have 1 dimension and as many
              samples as M has along its last axis.

            threshold: `float` or None
              Optional detection threshold; must lie in [0.0, 1.0].

        Raises:
            RuntimeError: M or t is not a numpy array, has the wrong number
              of dimensions, or their signal lengths differ.
            TypeError: threshold is neither a float nor None.
            ValueError: threshold is outside [0.0, 1.0] or is NaN.
        """
        if not isinstance(M, np.ndarray):
            raise RuntimeError(self.err1.format(self.label))
        if not isinstance(t, np.ndarray):
            raise RuntimeError(self.err2.format(self.label))
        if M.ndim != 3:
            raise RuntimeError(self.err3.format(self.label, M.ndim))
        if t.ndim != 1:
            raise RuntimeError(self.err4.format(self.label, t.ndim))
        if threshold is not None:
            # isinstance (not `type(...) is float`) so that numpy.float64,
            # which subclasses float, is accepted as well.
            if not isinstance(threshold, float):
                raise TypeError(self.err5.format(self.label, type(threshold)))
            # Chained comparison is False for NaN, so NaN is rejected too.
            if not 0.0 <= threshold <= 1.0:
                raise ValueError(self.err7.format(self.label, threshold))
        if M.shape[2] != t.shape[0]:
            raise RuntimeError(self.err12.format(self.label))


def _plot_target_map(path, tmap, map_type, whiteOnBlack, suffix=None):
    """ Save a target map as a PNG image using matplotlib.

    The image is written to ``<path>/tmap_<map_type>.png``, or to
    ``<path>/tmap_<map_type>_<suffix>.png`` when a suffix is given.

    Parameters:
        path: `string`
          Directory in which the image is written.

        tmap: `numpy array`
          The target map to render.

        map_type: `string`
          Detector name, used in the file name.

        whiteOnBlack: `boolean`
          True selects the 'Greys_r' colormap, False selects 'Greys';
          any other value falls back to 'Blues'.

        suffix: `string [default None]`
          Optional suffix appended to the file name.

    Raises:
        IOError: the image could not be written under ``path``.
    """
    import matplotlib.pyplot as plt
    import os.path as osp
    # This function only saves to a file; turn interactive mode off.
    plt.ioff()
    img = plt.imshow(tmap)
    if whiteOnBlack == True:
        img.set_cmap('Greys_r')
    elif whiteOnBlack == False:
        img.set_cmap('Greys')
    else:
        # NOTE(review): non-boolean values silently fall back to 'Blues';
        # consider raising instead.
        img.set_cmap('Blues')
    if suffix is None:
        fout = osp.join(path, 'tmap_{0}.png'.format(map_type))
    else:
        fout = osp.join(path, 'tmap_{0}_{1}.png'.format(map_type, suffix))
    try:
        plt.savefig(fout)
    except IOError:
        raise IOError('in _plot_target_map, no such file or directory: {0}'.format(path))
    finally:
        # Clear the figure even when saving fails, so the image drawn above
        # does not leak into the next plot.
        plt.clf()


def _display(tmap, map_type, whiteOnBlack, suffix):
    """ Display a target map using matplotlib
        for the IPython Notebook.

    Parameters:
        tmap: `numpy array`
          The target map to render.

        map_type: `string`
          Detector name, shown in the title.

        whiteOnBlack: `boolean`
          True selects the 'Greys_r' colormap, False selects 'Greys';
          any other value falls back to 'Blues'.

        suffix: `string` or None
          Optional text appended to the title.
    """
    import matplotlib.pyplot as plt
    img = plt.imshow(tmap)
    if whiteOnBlack == True:
        img.set_cmap('Greys_r')
    elif whiteOnBlack == False:
        img.set_cmap('Greys')
    else:
        # NOTE(review): non-boolean values silently fall back to 'Blues';
        # consider raising instead.
        img.set_cmap('Blues')
    if suffix is None:
        title = '{0} Target Map'.format(map_type)
    else:
        title = '{0} Target Map - {1}'.format(map_type, suffix)
    plt.title(title)
    plt.show()
    # Reset the current figure so the next display starts from a blank canvas.
    plt.clf()


class MatchedFilter(object):
    """
    Performs the matched filter algorithm for target detection.
    """

    def __init__(self):
        self.check = Check('MatchedFilter')
        # Result of the last detect() call; None until detect() is called.
        self.target_map = None

    def detect(self, M, t, threshold=None):
        """
        Parameters:
          M: `numpy array`
            A HSI cube (m x n x p).

          t: `numpy array`
            A target pixel (p).

          threshold: `float [default None]`
            Optional threshold in [0.0, 1.0]. When given, the result is
            the boolean mask ``output > threshold`` instead of the raw
            detector output.

        Returns: `numpy array`
            Detector output laid out on the image grid (m x n), or a
            boolean map of the same shape when threshold is given. The
            result is also stored in ``self.target_map`` for plot() and
            display().

        References:
            X Jin, S Paswater, H Cline.  "A Comparative Study of Target Detection
            Algorithms for Hyperspectral Imagery."  SPIE Algorithms and Technologies
            for Multispectral, Hyperspectral, and Ultraspectral Imagery XV.  Vol
            7334.  2009.
        """
        self.check.detect_input(M, t, threshold)
        h, w, numBands = M.shape
        # Flatten the cube to one spectrum per row: (m*n) x p.
        Mr = np.reshape(M, (w*h, numBands))
        target = detect.MatchedFilter(Mr, t)
        self.target_map = np.reshape(target, (h, w))
        if threshold is not None:
            self.target_map = self.target_map > threshold
        return self.target_map

    def plot(self, path, whiteOnBlack=True, suffix=None):
        """
        Plot the target map to a PNG file named
        ``tmap_MatchedFilter.png`` (or ``tmap_MatchedFilter_<suffix>.png``).

        Parameters:
            path: `string`
              The path where to put the plot.

            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'plot')
        _plot_target_map(path, self.target_map, 'MatchedFilter', whiteOnBlack, suffix)

    def display(self, whiteOnBlack=True, suffix=None):
        """
        Display the target map in an IPython Notebook.

        Parameters:
            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'display')
        _display(self.target_map, 'MatchedFilter', whiteOnBlack, suffix)


class ACE(object):
    """
    Performs the adaptive cosine/coherent estimator algorithm for target
    detection.
    """

    def __init__(self):
        self.check = Check('ACE')
        # Result of the last detect() call; None until detect() is called.
        self.target_map = None

    def detect(self, M, t, threshold=None):
        """
        Parameters:
          M: `numpy array`
            A HSI cube (m x n x p).

          t: `numpy array`
            A target pixel (p).

          threshold: `float [default None]`
            Optional threshold in [0.0, 1.0]. When given, the result is
            the boolean mask ``output > threshold`` instead of the raw
            detector output.

        Returns: `numpy array`
            Detector output laid out on the image grid (m x n), or a
            boolean map of the same shape when threshold is given. The
            result is also stored in ``self.target_map`` for plot() and
            display().

        References:
          X Jin, S Paswater, H Cline.  "A Comparative Study of Target Detection
          Algorithms for Hyperspectral Imagery."  SPIE Algorithms and Technologies
          for Multispectral, Hyperspectral, and Ultraspectral Imagery XV.  Vol
          7334.  2009.
        """
        self.check.detect_input(M, t, threshold)
        h, w, numBands = M.shape
        # Flatten the cube to one spectrum per row: (m*n) x p.
        Mr = np.reshape(M, (w*h, numBands))
        target = detect.ACE(Mr, t)
        self.target_map = np.reshape(target, (h, w))
        if threshold is not None:
            self.target_map = self.target_map > threshold
        return self.target_map

    def plot(self, path, whiteOnBlack=True, suffix=None):
        """
        Plot the target map to a PNG file named
        ``tmap_ACE.png`` (or ``tmap_ACE_<suffix>.png``).

        Parameters:
            path: `string`
              The path where to put the plot.

            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'plot')
        _plot_target_map(path, self.target_map, 'ACE', whiteOnBlack, suffix)

    def display(self, whiteOnBlack=True, suffix=None):
        """
        Display the target map in an IPython Notebook.

        Parameters:
            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'display')
        _display(self.target_map, 'ACE', whiteOnBlack, suffix)


class CEM(object):
    """
    Performs the constrained energy minimization algorithm for target
    detection.
    """

    def __init__(self):
        self.check = Check('CEM')
        # Result of the last detect() call; None until detect() is called.
        self.target_map = None

    def detect(self, M, t, threshold=None):
        """
        Parameters:
          M: `numpy array`
            A HSI cube (m x n x p).

          t: `numpy array`
            A target pixel (p).

          threshold: `float [default None]`
            Optional threshold in [0.0, 1.0]. When given, the result is
            the boolean mask ``output > threshold`` instead of the raw
            detector output.

        Returns: `numpy array`
            Detector output laid out on the image grid (m x n), or a
            boolean map of the same shape when threshold is given. The
            result is also stored in ``self.target_map`` for plot() and
            display().

        References:
            Qian Du, Hsuan Ren, and Chein-I Chang. A Comparative Study of
            Orthogonal Subspace Projection and Constrained Energy Minimization.
            IEEE TGRS. Volume 41. Number 6. June 2003.
        """
        self.check.detect_input(M, t, threshold)
        h, w, numBands = M.shape
        # Flatten the cube to one spectrum per row: (m*n) x p.
        Mr = np.reshape(M, (w*h, numBands))
        target = detect.CEM(Mr, t)
        self.target_map = np.reshape(target, (h, w))
        if threshold is not None:
            self.target_map = self.target_map > threshold
        return self.target_map

    def plot(self, path, whiteOnBlack=True, suffix=None):
        """
        Plot the target map to a PNG file named
        ``tmap_CEM.png`` (or ``tmap_CEM_<suffix>.png``).

        Parameters:
            path: `string`
              The path where to put the plot.

            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'plot')
        _plot_target_map(path, self.target_map, 'CEM', whiteOnBlack, suffix)

    def display(self, whiteOnBlack=True, suffix=None):
        """
        Display the target map in an IPython Notebook.

        Parameters:
            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'display')
        _display(self.target_map, 'CEM', whiteOnBlack, suffix)


class GLRT(object):
    """
    Performs the generalized likelihood ratio test algorithm for target
    detection.
    """

    def __init__(self):
        self.check = Check('GLRT')
        # Result of the last detect() call; None until detect() is called.
        self.target_map = None

    def detect(self, M, t, threshold=None):
        """
        Parameters:
          M: `numpy array`
            A HSI cube (m x n x p).

          t: `numpy array`
            A target pixel (p).

          threshold: `float [default None]`
            Optional threshold in [0.0, 1.0]. When given, the result is
            the boolean mask ``output > threshold`` instead of the raw
            detector output.

        Returns: `numpy array`
            Detector output laid out on the image grid (m x n), or a
            boolean map of the same shape when threshold is given. The
            result is also stored in ``self.target_map`` for plot() and
            display().

        References:
            T. F. Ayoub, "Modified GLRT Signal Detection Algorithm," IEEE
            Transactions on Aerospace and Electronic Systems, Vol 36, No 3, July
            2000.
        """
        self.check.detect_input(M, t, threshold)
        h, w, numBands = M.shape
        # Flatten the cube to one spectrum per row: (m*n) x p.
        Mr = np.reshape(M, (w*h, numBands))
        target = detect.GLRT(Mr, t)
        self.target_map = np.reshape(target, (h, w))
        if threshold is not None:
            self.target_map = self.target_map > threshold
        return self.target_map

    def plot(self, path, whiteOnBlack=True, suffix=None):
        """
        Plot the target map to a PNG file named
        ``tmap_GLRT.png`` (or ``tmap_GLRT_<suffix>.png``).

        Parameters:
            path: `string`
              The path where to put the plot.

            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'plot')
        _plot_target_map(path, self.target_map, 'GLRT', whiteOnBlack, suffix)

    def display(self, whiteOnBlack=True, suffix=None):
        """
        Display the target map in an IPython Notebook.

        Parameters:
            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'display')
        _display(self.target_map, 'GLRT', whiteOnBlack, suffix)


class OSP(object):
    """
    Performs the orthogonal subspace projection algorithm for target
    detection.
    """

    def __init__(self):
        self.check = Check('OSP')
        # Result of the last detect() call; None until detect() is called.
        self.target_map = None

    def detect(self, M, E, t, threshold=None):
        """
        Parameters:
          M: `numpy array`
            A HSI cube (m x n x p).

          E: `numpy array`
            Background pixels (n x p).
            NOTE(review): E is passed unchanged to detect.OSP and is not
            validated here; confirm the expected layout against detect.OSP.

          t: `numpy array`
            A target pixel (p).

          threshold: `float [default None]`
            Optional threshold in [0.0, 1.0]. When given, the result is
            the boolean mask ``output > threshold`` instead of the raw
            detector output.

        Returns: `numpy array`
            Detector output laid out on the image grid (m x n), or a
            boolean map of the same shape when threshold is given. The
            result is also stored in ``self.target_map`` for plot() and
            display().

        References:
            Qian Du, Hsuan Ren, and Chein-I Chang. "A Comparative Study of
            Orthogonal Subspace Projection and Constrained Energy Minimization."
            IEEE TGRS. Volume 41. Number 6. June 2003.
        """
        self.check.detect_input(M, t, threshold)
        h, w, numBands = M.shape
        # Flatten the cube to one spectrum per row: (m*n) x p.
        Mr = np.reshape(M, (w*h, numBands))
        target = detect.OSP(Mr, E, t)
        self.target_map = np.reshape(target, (h, w))
        if threshold is not None:
            self.target_map = self.target_map > threshold
        return self.target_map

    def plot(self, path, whiteOnBlack=True, suffix=None):
        """
        Plot the target map to a PNG file named
        ``tmap_OSP.png`` (or ``tmap_OSP_<suffix>.png``).

        Parameters:
            path: `string`
              The path where to put the plot.

            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the file name.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'plot')
        _plot_target_map(path, self.target_map, 'OSP', whiteOnBlack, suffix)

    def display(self, whiteOnBlack=True, suffix=None):
        """
        Display the target map in an IPython Notebook.

        Parameters:
            whiteOnBlack: `boolean [default True]`
              By default, whiteOnBlack=True, the detected signal
              is white on a black background. You can invert this with
              whiteOnBlack=False.

            suffix: `string [default None]`
              Suffix to add to the title.

        Raises:
            RuntimeError: detect() has not been called yet.
            TypeError: suffix is neither a str nor None.
        """
        self.check.plot_input(self.target_map, suffix, 'display')
        _display(self.target_map, 'OSP', whiteOnBlack, suffix)
