"""
Plot smokestack effluents.
"""


import os
import os.path as osp
import pysptools.classification as cls
import matplotlib.pyplot as plt
import pysptools.util as util
import pysptools.eea as eea
import pysptools.abundance_maps as amp

import numpy as np


class Classify(object):
    """
    Small wrapper around a pysptools SAM classifier.

    The classification runs once, when the object is built. Single maps
    can then be fetched or plotted with the output path and file-name
    suffix given at construction.

    For this problem NormXCorr works as well as SAM; SID was not tested.
    """

    def __init__(self, data, E, path, threshold, suffix):
        print('Classify using SAM')
        classifier = cls.SAM()
        classifier.classify(data, E, threshold=threshold)
        self.sam = classifier
        # Output directory and file-name suffix reused by plot_single_map.
        self.path, self.suffix = path, suffix

    def get_single_map(self, idx):
        """Return the unconstrained SAM map for class `idx`."""
        sam = self.sam
        return sam.get_single_map(idx, constrained=False)

    def plot_single_map(self, idx):
        """
        Delegate to SAM.plot_single_map (unconstrained) with the stored
        path and suffix. `idx` is a class label, or 'all' as some callers pass.
        """
        self.sam.plot_single_map(self.path, idx,
                                 suffix=self.suffix, constrained=False)


def get_endmembers(data, header, q, path, mask, suffix, output=False):
    """
    Extract `q` endmembers from `data` with NFINDR.

    data:   the cube handed to NFINDR.extract.
    header: the ENVI header; only used here when plotting.
    q:      number of endmembers to extract.
    path:   output directory for the plots.
    mask:   optional binary mask restricting the extraction
            (None to use the whole cube).
    suffix: suffix appended to the plot file names.
    output: when truthy, plot the extracted endmembers.

    Returns whatever NFINDR.extract returns (the endmember set).
    """
    print('Endmembers extraction with NFINDR')
    ee = eea.NFINDR()
    # ATGP initialisation on normalized data; maxit=5 is presumably an
    # iteration cap for NFINDR.
    U = ee.extract(data, q, maxit=5, normalize=True, ATGP_init=True, mask=mask)
    if output:
        ee.plot(path, header, suffix=suffix)
    return U


def get_abundance_maps(data, U, umix_source, path, output=False):
    """
    Compute fully constrained (FCLS) abundance maps of `data` for the
    endmember set `U`.

    umix_source: suffix appended to the plot file names.
    path:        output directory for the plots.
    output:      when truthy, plot the abundance maps with the 'jet' colormap.

    Returns the abundance maps from FCLS.map. Callers index the result as
    amaps[:, :, k], so the last axis presumably runs over endmembers
    (confirm against pysptools).
    """
    print('Abundance maps with FCLS')
    fcls = amp.FCLS()
    amap = fcls.map(data, U, normalize=True)
    if output:
        fcls.plot(path, colorMap='jet', suffix=umix_source)
    return amap


def get_endmembers_sets(data, header, path):
    """
    Return an endmember set for the full cube and an endmember set for
    the region of interest (ROI).

    The ROI is a small region of the effluents leaving near the
    smokestack. It is the binary mask of the non-zero pixels of the
    unconstrained SAM map of endmember EM2 (label 2) from the full-cube
    set. Both sets have 8 NFINDR endmembers. The endmembers, the EM2 map
    and the binary mask are plotted under `path`.

    Returns (U_full_cube, U_masked).
    """
    # Endmember set for the whole cube.
    U_full_cube = get_endmembers(data, header, 8, path, None, 'full_cube', output=True)
    # A SAM threshold of 0.15 gives a good ROI.
    # Named 'classifier' so it does not shadow the module alias 'cls'
    # (pysptools.classification).
    classifier = Classify(data, U_full_cube, path, 0.15, 'full_cube')
    # Endmember EM2 is used to define the region of interest.
    classifier.plot_single_map(2)
    effluents = classifier.get_single_map(2)
    # Binary mask: every pixel where the EM2 map is non-zero.
    mask = effluents > 0
    plot(mask, 'gray', 'binary_mask', path)
    # Re-extract endmembers inside the mask, i.e. near the smokestack exit.
    U_masked = get_endmembers(data, header, 8, path, mask, 'masked', output=True)
    return U_full_cube, U_masked


def classification_analysis(data, header, path, E_full_cube, E_masked):
    """
    Classify the cube with SAM using the masked endmember set and plot
    the results.

    All class maps are plotted. Then the pixel-wise mean of the
    unconstrained maps for labels 1..7 is plotted as 'mean_SAM'; the
    eighth endmember is treated as background and skipped.

    `header` and `E_full_cube` are unused. They are kept so that the
    analysis functions share one signature.
    """
    n_gas = 7
    # Named 'classifier' so it does not shadow the module alias 'cls'.
    classifier = Classify(data, E_masked, path, 0.15, 'masked')
    classifier.plot_single_map('all')
    # Labels 1..7 (endmember indices 0..6); label 8 is background.
    # Each map is added exactly once. The old loop added label 1 twice.
    # Dividing by a float avoids Python 2 floor division if the maps are
    # integer-typed.
    total = sum(classifier.get_single_map(label) for label in range(1, n_gas + 1))
    gas = total / float(n_gas)
    plot(gas, 'spectral', 'mean_SAM', path)


def unmixing_analysis(data, header, path, E_full_cube, E_masked):
    """
    Unmix the cube with FCLS once for each masked gas endmember and plot
    the results.

    For each of the first 7 masked endmembers (the eighth is background),
    row 1 of a copy of the full-cube endmember set is replaced by that
    endmember and FCLS abundances are computed. Abundance map 1 (the
    substituted endmember) is zeroed where it is <= 0.3 and plotted as
    'masqued_<n>'. The mean of the 7 maps, thresholded the same way, is
    plotted as 'mean_FCLS'.

    Row 1 is presumably EM2, the effluent endmember used to build the ROI
    in get_endmembers_sets. Confirm this.

    `E_full_cube` is not modified; the substitution is done on a copy.
    `header` is unused.
    """
    n_gas = 7
    threshold = 0.3
    # Work on a copy so the caller's full-cube endmember set is not overwritten.
    E = E_full_cube.copy()
    total = None
    for i in range(n_gas):
        E[1, :] = E_masked[i, :]
        desc = 'masqued_{0}'.format(i + 1)
        amaps = get_abundance_maps(data, E, desc, path, output=False)
        abundance = amaps[:, :, 1]
        # '+' allocates a new array, so amaps is never modified.
        total = abundance if total is None else total + abundance
        plot((abundance > threshold) * abundance, 'spectral', desc, path)
    mean = total / float(n_gas)
    plot((mean > threshold) * mean, 'spectral', 'mean_FCLS', path)


def plot(image, colormap, desc, path):
    """
    Save `image` with colormap `colormap` and a colorbar to
    <path>/plot_<desc>.png.

    Interactive mode is turned off first. The current figure is cleared
    afterwards, so the next call starts from an empty figure.
    """
    plt.ioff()
    target = osp.join(path, 'plot_{0}.png'.format(desc))
    mappable = plt.imshow(image)
    mappable.set_cmap(colormap)
    plt.colorbar()
    plt.savefig(target)
    plt.clf()


if __name__ == '__main__':
    # Run the full analysis on the Smokestack1 cube. Plots are written to
    # ../results. The paths are relative to the current working directory.
    plt.ioff()
    data_path = '../data1'
    project_path = '../'
    result_path = osp.join(project_path, 'results')
    sample = 'Smokestack1.hdr'

    if not osp.exists(result_path):
        os.makedirs(result_path)

    data_file = osp.join(data_path, sample)
    data, header = util.load_ENVI_file(data_file)
    # Telops cubes are flipped left-right;
    # flipping them again restores the orientation.
    data = np.fliplr(data)

    U_full_cube, U_masked = get_endmembers_sets(data, header, result_path)
    classification_analysis(data, header, result_path, U_full_cube, U_masked)
    unmixing_analysis(data, header, result_path, U_full_cube, U_masked)
