"""
Plot a quartz class map for a drill core HSI cube.
"""

import os
import os.path as osp
import json
import matplotlib.pyplot as plt
import numpy as np

import pysptools.util as util
import pysptools.eea as eea
import pysptools.abundance_maps as amp


def get_endmembers(data, header, q, path):
    """Run N-FINDR endmember extraction on `data` and plot the result.

    Parameters
    ----------
    data
        Hyperspectral cube handed unchanged to ``NFINDR.extract``.
    header
        Cube header, forwarded to ``NFINDR.plot``.
    q
        Second positional argument of ``NFINDR.extract``; presumably the
        number of endmembers to extract -- confirm against pysptools docs.
    path
        Destination forwarded to ``NFINDR.plot`` (presumably the directory
        where the plot is written -- confirm).

    Returns
    -------
    The value returned by ``NFINDR.extract``.
    """
    print('Endmembers extraction with NFINDR')
    extractor = eea.NFINDR()
    endmembers = extractor.extract(
        data,
        q,
        maxit=5,
        normalize=True,
        ATGP_init=True,
    )
    # Plotting relies on the state left in the extractor by extract().
    extractor.plot(path, header)
    return endmembers


def gen_abundance_maps(data, U, result_path):
    """Compute abundance maps for `data` with FCLS and plot them.

    Parameters
    ----------
    data
        Hyperspectral cube handed unchanged to ``FCLS.map``.
    U
        Endmembers handed unchanged to ``FCLS.map`` (in this script, the
        value returned by ``get_endmembers``).
    result_path
        Destination forwarded to ``FCLS.plot`` (presumably the directory
        where the maps are written -- confirm).

    Returns
    -------
    The value returned by ``FCLS.map``.
    """
    print('Abundance maps with FCLS')
    unmixer = amp.FCLS()
    abundances = unmixer.map(data, U, normalize=True)
    # Plotting relies on the state left in the unmixer by map().
    unmixer.plot(result_path, colorMap='jet')
    return abundances


def plot(image, colormap, desc, path):
    """Render `image` with a colorbar and save it as ``plot_<desc>.png``.

    Parameters
    ----------
    image
        Array-like passed to ``plt.imshow``.
    colormap
        Colormap applied to the rendered image via ``set_cmap``.
    desc
        Label inserted into the output file name.
    path
        Directory in which the PNG file is written.

    The function draws on pyplot's current figure and clears it afterwards.
    """
    # Turn interactive mode off before drawing.
    plt.ioff()
    rendered = plt.imshow(image)
    rendered.set_cmap(colormap)
    plt.colorbar()
    plt.savefig(osp.join(path, 'plot_{name}.png'.format(name=desc)))
    # Wipe the current figure so the next call does not draw over this one.
    plt.clf()


if __name__ == '__main__':
    # Script entry point: load the 'hematite' cube, extract 4 endmembers,
    # compute abundance maps, then plot a quartz map and a core mask and
    # print pixel-count statistics.
    data_path = '../data1'
    project_path = '../'
    result_path = osp.join(project_path, 'results')
    sample = 'hematite'

    if not osp.exists(result_path):
        os.makedirs(result_path)

    # Load the cube and its header; both are stored as JSON documents.
    data_file = osp.join(data_path, sample + '.jdata')
    with open(data_file, 'r') as content_file:
        data = np.array(json.load(content_file))
    info_file = osp.join(data_path, sample + '.jhead')
    with open(info_file, 'r') as content_file:
        header = json.load(content_file)

    # Telops cubes are flipped left-right.
    # Flipping them again restores the orientation.
    data = np.fliplr(data)

    U = get_endmembers(data, header, 4, result_path)
    amaps = gen_abundance_maps(data, U, result_path)

    # Quartz abundance map.
    # NOTE(review): the original comment said "EM3 == quartz" but the code
    # reads index 1 of the last axis -- confirm which endmember is quartz
    # against the NFINDR plot.
    quartz = amaps[:, :, 1]
    # 'spectral' was removed from matplotlib; 'nipy_spectral' is the same
    # colormap under its current name.
    plot(quartz, 'nipy_spectral', 'quartz', result_path)

    # The background endmember is used to isolate the drill core and define
    # the mask: a pixel belongs to the core when its background abundance is
    # below 0.2.
    # NOTE(review): the original comment said "EM2 == background" but the
    # code reads index 3 of the last axis -- confirm the endmember order.
    mask = (amaps[:, :, 3] < 0.2)
    plot(mask, 'nipy_spectral', 'mask', result_path)

    # Pixel counts.
    # NOTE(review): the quartz count is taken over the whole image, not only
    # inside the mask, so the hematite figure assumes no quartz pixel lies
    # outside the core -- confirm this is intended.
    rock_surface = np.sum(mask)
    quartz_surface = np.sum(quartz > 0.16)
    print('Some statistics')
    print('  Drill core surface (mask) in pixels:', rock_surface)
    print('  Quartz surface in pixels:', quartz_surface)
    print('  Hematite surface in pixels:', rock_surface - quartz_surface)
