"""
Python Spectral Tools (pysptools) test script

The following classes are tested, individually and in combination
(endmember extraction followed by classification):
    SAM
    SID
    NormXCorr
    NFINDR
    FIPPI
    ATGP
"""



import os
import os.path as osp
import cProfile, pstats
import pysptools.classification as cls
import numpy as np
import pysptools.eea as eea


# Set to True to profile the main steps of each test with cProfile
# (see profile() and stat()).
_doProfile = False


def profile():
    """Start a cProfile profiler if profiling is enabled.

    Returns:
        An enabled ``cProfile.Profile`` when ``_doProfile`` is truthy,
        otherwise ``None``.  Pass the result to ``stat`` to stop the
        profiler and print its report.
    """
    if not _doProfile:
        return None
    pr = cProfile.Profile()
    pr.enable()
    return pr

def stat(pr):
    """Stop the profiler returned by ``profile`` and print its report.

    Args:
        pr: the value returned by ``profile()`` -- an enabled
            ``cProfile.Profile``, or ``None`` when profiling is off.

    The report is printed to standard output, with directory prefixes
    stripped from file names and entries sorted by internal time.
    """
    # Checking the profiler itself (rather than re-reading ``_doProfile``)
    # keeps this safe if the flag is toggled between profile() and stat():
    # no .disable() on None, and no profiler left running.
    if pr is None:
        return
    pr.disable()
    ps = pstats.Stats(pr)
    ps.strip_dirs()
    ps.sort_stats('time')
    ps.print_stats()


def test_SID(data, E, result_path):
    """Exercise ``cls.SID`` classification and plotting on ``data``.

    Runs four passes, each saving its plots to ``result_path`` under a
    distinct suffix:

    * t1 -- classification without a threshold (profiled when
      ``_doProfile`` is set);
    * t2 -- a single scalar threshold of 0.05;
    * t3 -- a threshold list, plus the single-map / SID-map accessors
      and a histogram plot;
    * t4 -- a fresh classifier given the single spectrum ``E[1,:]``.

    Args:
        data: the hyperspectral cube to classify.
        E: the reference spectra, indexed as a 2-D array (``E[1,:]``).
            NOTE(review): the t3 threshold list has five entries, which
            presumably assumes E holds five spectra -- confirm.
        result_path: directory where the plots are written.
    """
    print('Testing SID')
    sid = cls.SID()
    pr = profile()
    sid.classify(data, E)
    stat(pr)
    sid.plot_single_map(result_path, 'all', suffix='t1')
    sid.plot(result_path, colorMap='Set3', suffix='t1')

    sid.classify(data, E, threshold=0.05)
    sid.plot_single_map(result_path, 'all', suffix='t2')
    sid.plot(result_path, colorMap='Set3', suffix='t2')

    sid.classify(data, E, threshold=[0.1, 0.1, 0.05, 0.1, 0.1])
    # Accessor smoke checks; the returned maps are not inspected here.
    sid.get_single_map(1)
    sid.get_SID_map()
    sid.plot_single_map(result_path, 'all', suffix='t3')
    sid.plot(result_path, colorMap='Set3', suffix='t3')
    sid.plot_histo(result_path)

    # A fresh classifier given one spectrum instead of the whole library.
    sid = cls.SID()
    sid.classify(data, E[1,:])
    sid.plot(result_path, colorMap='Set3', suffix='t4')

    sid = cls.SID()
    cmap = sid.classify(data, E[1,:])
    sid.plot(result_path, colorMap='Set3', suffix='t4')


def test_SAM(data, E, result_path):
    """Exercise ``cls.SAM`` classification and plotting on ``data``.

    Runs four passes, each saving its plots to ``result_path`` under a
    distinct suffix:

    * t1 -- classification without a threshold (profiled when
      ``_doProfile`` is set);
    * t2 -- a single scalar threshold of 0.25;
    * t3 -- a threshold list, plus the single-map / angles-map /
      angles-stats accessors and a histogram plot;
    * t4 -- a fresh classifier given the single spectrum ``E[1,:]``.

    Args:
        data: the hyperspectral cube to classify.
        E: the reference spectra, indexed as a 2-D array (``E[1,:]``).
            NOTE(review): the t3 threshold list has five entries, which
            presumably assumes E holds five spectra -- confirm.
        result_path: directory where the plots are written.
    """
    print('Testing SAM')
    sam = cls.SAM()
    pr = profile()
    sam.classify(data, E)
    stat(pr)
    sam.plot_single_map(result_path, 'all', suffix='t1')
    sam.plot(result_path, colorMap='Set3', suffix='t1')

    sam.classify(data, E, threshold=0.25)
    sam.plot_single_map(result_path, 'all', suffix='t2')
    sam.plot(result_path, colorMap='Set3', suffix='t2')

    sam.classify(data, E, threshold=[0.1, 0.1, 0.05, 0.1, 0.1])
    # Accessor smoke checks; the returned values are not inspected here.
    sam.get_single_map(1)
    sam.get_angles_map()
    sam.get_angles_stats()
    sam.plot_single_map(result_path, 'all', suffix='t3')
    sam.plot(result_path, colorMap='Set3', suffix='t3')
    sam.plot_histo(result_path)

    # A fresh classifier given one spectrum instead of the whole library.
    sam = cls.SAM()
    sam.classify(data, E[1,:])
    sam.plot(result_path, colorMap='Set3', suffix='t4')


def test_NormXCorr(data, E, result_path):
    """Exercise ``cls.NormXCorr`` classification and plotting on ``data``.

    Runs four passes, each saving its plots to ``result_path`` under a
    distinct suffix:

    * t1 -- classification without a threshold (profiled when
      ``_doProfile`` is set);
    * t2 -- a single scalar threshold of 0.15;
    * t3 -- a threshold list, plus the single-map / NormXCorr-map
      accessors and a histogram plot;
    * t4 -- a fresh classifier given the single spectrum ``E[1,:]``.

    Args:
        data: the hyperspectral cube to classify.
        E: the reference spectra, indexed as a 2-D array (``E[1,:]``).
            NOTE(review): the t3 threshold list has five entries, which
            presumably assumes E holds five spectra -- confirm.
        result_path: directory where the plots are written.
    """
    print('Testing NormXCorr')
    xc = cls.NormXCorr()
    pr = profile()
    xc.classify(data, E)
    stat(pr)
    xc.plot_single_map(result_path, 'all', suffix='t1')
    xc.plot(result_path, colorMap='Set3', suffix='t1')

    xc.classify(data, E, threshold=0.15)
    xc.plot_single_map(result_path, 'all', suffix='t2')
    xc.plot(result_path, colorMap='Set3', suffix='t2')

    xc.classify(data, E, threshold=[0.1, 0.1, 0.05, 0.1, 0.1])
    # Accessor smoke checks; the returned maps are not inspected here.
    xc.get_single_map(1)
    xc.get_NormXCorr_map()
    xc.plot_single_map(result_path, 'all', suffix='t3')
    xc.plot(result_path, colorMap='Set3', suffix='t3')
    xc.plot_histo(result_path)

    # A fresh classifier given one spectrum instead of the whole library.
    xc = cls.NormXCorr()
    xc.classify(data, E[1,:])
    xc.plot(result_path, colorMap='Set3', suffix='t4')


def test_one_spectrum(data, E, result_path):
    """Classify ``data`` against each spectrum of ``E`` individually.

    For each of SAM, SID and NormXCorr (in that order), a fresh
    classifier is built per spectrum and its map is plotted to
    ``result_path`` with suffix ``t5_<n>``, where ``n`` is the 1-based
    spectrum number.

    Args:
        data: the hyperspectral cube to classify.
        E: the reference spectra; iterating it must yield one spectrum
            per row (as a 2-D numpy array does).
        result_path: directory where the plots are written.
    """
    classifiers = (
        ('SAM', cls.SAM),
        ('SID', cls.SID),
        ('NormXCorr', cls.NormXCorr),
    )
    for label, make_classifier in classifiers:
        print('Testing {0} one spectrum'.format(label))
        # Each pass classifies against a single reference spectrum, not a
        # single pixel.
        for number, spectrum in enumerate(E, start=1):
            cl = make_classifier()
            cl.classify(data, spectrum)
            cl.plot(result_path, suffix='t5_{0}'.format(number))


def test_SAM_and_NFINDR(data, result_path, info):
    """Extract endmembers with NFINDR, then classify the cube with SAM.

    Gives an idea of how the extracted endmembers are distributed over
    the HSI cube.  NFINDR is asked for 18 endmembers
    (``maxit=10, normalize=True, ATGP_init=True``) and only that
    extraction is profiled.  SAM then classifies ``data`` against them
    without an explicit threshold; maps and histogram use suffix ``t4``.

    Args:
        data: the hyperspectral (HSI) cube.
        result_path: directory where the plots are written.
        info: header information forwarded to ``findr.plot``.
    """
    print('Testing NFINDR and SAM')
    findr = eea.NFINDR()
    pr = profile()
    U = findr.extract(data, 18, maxit=10, normalize=True, ATGP_init=True)
    stat(pr)
    findr.plot(result_path, info)
    # A single formatted string prints the same on Python 2 and 3
    # (a two-argument print would show a tuple on Python 2).
    print('  Iterations: {0}'.format(findr.get_iterations()))
    sam = cls.SAM()
    #sam.classify(data, U, threshold=0.1)
    # adjusted for SAMSON_part, 8 endmembers
    #sam.classify(data, U, threshold=[0.25,0.4,0.1,0.15,0.2,0.1,0.08,0.1])
    sam.classify(data, U)
    sam.plot_single_map(result_path, 'all', constrained=True, suffix='t4')
    sam.plot(result_path, colorMap='Paired', suffix='t4')
    sam.plot_histo(result_path, suffix='t4')


def test_SID_and_NFINDR(data, result_path, info):
    """Extract endmembers with NFINDR, then classify the cube with SID.

    NFINDR is asked for 12 endmembers
    (``maxit=5, normalize=True, ATGP_init=True``) and only that
    extraction is profiled.  SID then classifies ``data`` against them
    with ``threshold=0.1``; maps and histogram use suffix ``t4``.

    Args:
        data: the hyperspectral (HSI) cube.
        result_path: directory where the plots are written.
        info: header information forwarded to ``findr.plot``.
    """
    print('Testing NFINDR and SID')
    findr = eea.NFINDR()
    pr = profile()
    U = findr.extract(data, 12, maxit=5, normalize=True, ATGP_init=True)
    stat(pr)
    findr.plot(result_path, info)
    # A single formatted string prints the same on Python 2 and 3
    # (a two-argument print would show a tuple on Python 2).
    print('  Iterations: {0}'.format(findr.get_iterations()))
    sid = cls.SID()
    #sid.classify(data, U, threshold=None)
    sid.classify(data, U, threshold=0.1)
    sid.plot_single_map(result_path, 'all', suffix='t4')
    sid.plot(result_path, colorMap='', suffix='t4')
    sid.plot_histo(result_path, suffix='t4')


def test_NormXCorr_and_NFINDR(data, result_path, info):
    """Extract endmembers with NFINDR, then classify with NormXCorr.

    NFINDR is asked for 12 endmembers
    (``maxit=10, normalize=True, ATGP_init=True``) and only that
    extraction is profiled.  NormXCorr then classifies ``data`` against
    them with ``threshold=0.2``; histogram and maps use suffix ``t4``.

    Args:
        data: the hyperspectral (HSI) cube.
        result_path: directory where the plots are written.
        info: header information forwarded to ``findr.plot``.
    """
    print('Testing NFINDR and NormXCorr')
    findr = eea.NFINDR()
    pr = profile()
    U = findr.extract(data, 12, maxit=10, normalize=True, ATGP_init=True)
    stat(pr)
    findr.plot(result_path, info)
    # A single formatted string prints the same on Python 2 and 3
    # (a two-argument print would show a tuple on Python 2).
    print('  Iterations: {0}'.format(findr.get_iterations()))
    corr = cls.NormXCorr()
    corr.classify(data, U, threshold=0.2)
    corr.plot_histo(result_path, suffix='t4')
    corr.plot_single_map(result_path, 'all', suffix='t4')
    corr.plot(result_path, colorMap='', suffix='t4')


def test_SAM_and_FIPPI(data, result_path, info):
    """Extract endmembers with FIPPI, then classify the cube with SAM.

    Calls ``fippi.extract(data, 12, 6, normalize=True)``; only that
    extraction is profiled, matching the other extraction tests.  SAM
    then classifies ``data`` against the result without an explicit
    threshold; maps and histogram use suffix ``FIPPI``.

    Args:
        data: the hyperspectral (HSI) cube.
        result_path: directory where the plots are written.
        info: header information forwarded to ``fippi.plot``.
    """
    print('Testing SAM and FIPPI')
    fippi = eea.FIPPI()
    pr = profile()
    U = fippi.extract(data, 12, 6, normalize=True)
    # Stop profiling here so plotting and classification do not pollute
    # the extraction profile.
    stat(pr)
    fippi.plot(result_path, info)
    sam = cls.SAM()
    sam.classify(data, U)
    sam.plot_single_map(result_path, 'all', suffix='FIPPI')
    sam.plot(result_path, suffix='FIPPI')
    sam.plot_histo(result_path, suffix='FIPPI')


def test_SAM_and_ATGP(data, result_path, info):
    """Extract targets with ATGP, then classify the cube with SAM.

    Calls ``atgp.extract(data, 16, normalize=True)``; only that
    extraction is profiled.  SAM then classifies ``data`` against the
    result with ``threshold=0.1``; maps and histogram use suffix ``ATGP``.

    Args:
        data: the hyperspectral (HSI) cube.
        result_path: directory where the plots are written.
        info: header information forwarded to ``atgp.plot``.
    """
    print('Testing ATGP and SAM')
    atgp = eea.ATGP()
    pr = profile()
    U = atgp.extract(data, 16, normalize=True)
    stat(pr)
    atgp.plot(result_path, info)
    sam = cls.SAM()
    sam.classify(data, U, threshold=0.1)
    #sam.classify(data, U)
    sam.plot_single_map(result_path, 'all', suffix='ATGP')
    sam.plot(result_path, colorMap='', suffix='ATGP')
    sam.plot_histo(result_path, suffix='ATGP')


def tests_P2_7():
    """Run every test on ENVI files loaded through ``pysptools.util``.

    Selected by ``tests`` on Python 2.7.  Reads the cube
    ``../data1/SAMSON_part.hdr`` and the spectral library
    ``../data/speclib1.hdr``, and writes plots to ``../results``, creating
    that directory if it does not exist.  Paths are relative to the
    current working directory.
    """
    import pysptools.util as util
    data_path = '../data1'
    project_path = '../'
    result_path = osp.join(project_path, 'results')
    sample = 'SAMSON_part.hdr'
    #sample = '92AV3C.hdr'
    spec_lib_path = '../data'
    spec_lib_hdr = 'speclib1.hdr'
    # os.makedirs has no exist_ok on Python 2, so check first.
    if not osp.exists(result_path):
        os.makedirs(result_path)

    data_file = osp.join(data_path, sample)
    data, info = util.load_ENVI_file(data_file)
    lib_file = osp.join(spec_lib_path, spec_lib_hdr)
    E, _ = util.load_ENVI_spec_lib(lib_file)

    test_SID(data, E, result_path)
    test_SAM(data, E, result_path)
    test_NormXCorr(data, E, result_path)
    test_SAM_and_NFINDR(data, result_path, info)
    test_SID_and_NFINDR(data, result_path, info)
    test_NormXCorr_and_NFINDR(data, result_path, info)
    test_SAM_and_ATGP(data, result_path, info)
    test_SAM_and_FIPPI(data, result_path, info)


def tests_P3_3():
    """Run every test on JSON copies of the data files.

    Selected by ``tests`` on Python 3.  Reads the cube from
    ``../data1/SAMSON_part.jdata``, its header information from
    ``../data1/SAMSON_part.jhead`` and the spectral library from
    ``../data/speclib1.jdata``.  Plots are written to ``../results``,
    which is created if it does not exist.  Paths are relative to the
    current working directory.
    """
    import json

    def load_json(path):
        # Parse one JSON file, closing it even if parsing fails.
        with open(path, 'r') as content_file:
            return json.load(content_file)

    data_path = '../data1'
    project_path = '../'
    result_path = osp.join(project_path, 'results')
    sample = 'SAMSON_part'
    #sample = '92AV3C.hdr'
    spec_lib_path = '../data'
    spec_lib_hdr = 'speclib1'
    if not osp.exists(result_path):
        os.makedirs(result_path)

    data = np.array(load_json(osp.join(data_path, sample + '.jdata')))
    info = load_json(osp.join(data_path, sample + '.jhead'))
    E = np.array(load_json(osp.join(spec_lib_path, spec_lib_hdr + '.jdata')))

    test_SID(data, E, result_path)
    test_SAM(data, E, result_path)
    test_NormXCorr(data, E, result_path)
    test_SAM_and_NFINDR(data, result_path, info)
    test_SID_and_NFINDR(data, result_path, info)
    test_NormXCorr_and_NFINDR(data, result_path, info)
    test_SAM_and_ATGP(data, result_path, info)
    test_SAM_and_FIPPI(data, result_path, info)


def tests():
    """Run the test suite matching the running Python version.

    Python 2.7 runs ``tests_P2_7`` (ENVI files via ``pysptools.util``).
    Python 3.3 and later run ``tests_P3_3`` (JSON copies of the data),
    which needs only the standard library and numpy.  Any other version
    runs nothing, and a message says so instead of exiting silently.
    """
    import sys
    version = tuple(sys.version_info[:2])
    if version == (2, 7):
        tests_P2_7()
    elif version >= (3, 3):
        tests_P3_3()
    else:
        print('No test suite for Python {0}.{1}'.format(*version))

if __name__ == '__main__':
    import sys
    print(sys.version_info)
    # Version-based suite selection lives in tests(); do not duplicate it.
    tests()
