#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# test_dnoise.py - This file is part of the PySptools package.
#

"""
The following functions are tested:
    SavitzkyGolay
    Whiten
    MNF
"""



import numpy as np
import os.path as osp
import pysptools.util as util
import pysptools.noise as ns
import pysptools.eea as eea
import pysptools.classifiers as cls


def do_SavitzkyGolay_spectra_filter(data):
    """Return *data* denoised with ``SavitzkyGolay.denoise_spectra``.

    The filter is called with the fixed positional arguments ``(3, 1)``
    (presumably window size and polynomial order -- TODO confirm against
    the pysptools.noise API).
    """
    print('Doing SavitzkyGolay spectra filter')
    return ns.SavitzkyGolay().denoise_spectra(data, 3, 1)


def do_SavitzkyGolay_bands_filter(data):
    """Return *data* denoised with ``SavitzkyGolay.denoise_bands``.

    The filter is called with the fixed positional arguments ``(3, 1)``
    (presumably window size and polynomial order -- TODO confirm).
    """
    print('Doing SavitzkyGolay bands filter')
    smoother = ns.SavitzkyGolay()
    filtered = smoother.denoise_bands(data, 3, 1)
    return filtered


def do_whiten(data):
    """Return the result of ``Whiten.apply`` on *data*."""
    print('Doing Whiten')
    whitener = ns.Whiten()
    whitened = whitener.apply(data)
    return whitened


def do_MNF(data, n_components):
    """Apply MNF to *data* and return its first *n_components* components.

    Parameters
    ----------
    data
        Cube handed to ``MNF.apply``.
    n_components : int
        Number of leading components requested from
        ``MNF.get_components``.
    """
    print('Doing MNF')
    transform = ns.MNF()
    # The value returned by apply() is not needed here; get_components()
    # presumably reads state stored on the MNF object by apply().
    transform.apply(data)
    leading = transform.get_components(n_components)
    return leading


def do_MNF_and_clean_noisy_bands_and_inverse(data, first_noisy_band=20,
                                             n_passes=2):
    """MNF-transform *data*, smooth its low-variance bands, then invert.

    Parameters
    ----------
    data
        Cube handed to ``MNF.apply``.
    first_noisy_band : int, optional
        Index (on the last axis of the MNF cube) of the first band treated
        as low variance; bands ``first_noisy_band:`` are smoothed.
        Defaults to 20.
    n_passes : int, optional
        Number of times ``SavitzkyGolay.denoise_bands`` (with arguments
        ``3, 1``) is applied to those bands. Defaults to 2.

    Returns
    -------
    The result of ``MNF.inverse_transform`` on the cleaned MNF cube.
    """
    print('Doing MNF and cleaning noisy bands')
    mnf = ns.MNF()
    tdata = mnf.apply(data)
    dn = ns.SavitzkyGolay()
    # Only the tail (low-variance) bands are smoothed; the leading MNF
    # bands are left untouched.
    for _ in range(n_passes):
        tdata[:, :, first_noisy_band:] = dn.denoise_bands(
            tdata[:, :, first_noisy_band:], 3, 1)
    # Inverting removes the PCA rotation, giving back a whitened cube whose
    # low-variance bands have been denoised.
    return mnf.inverse_transform(tdata)


def test_whiten(data, result_path, info):
    """Whiten *data*, extract 24 endmembers with ATGP, classify with SAM.

    ``result_path`` and ``info`` are passed to the ATGP and SAM ``plot``
    calls; the SAM plot is tagged with the suffix ``'whiten'``.
    """
    print('Testing whiten with ATGP and SAM')

    print('Denoising')
    whitened = do_whiten(data)

    print('Extracting')
    extractor = eea.ATGP()
    endmembers = extractor.extract(whitened, 24)
    extractor.plot(result_path, info)

    print('Classifying')
    classifier = cls.SAM()
    classifier.classify(whitened, endmembers, threshold=0.1)
    classifier.plot(result_path, colorMap='', suffix='whiten')


def _run_MNF_transform_case(data, result_path, info, n):
    # Extract n endmembers with NFINDR on MNF components, then classify with
    # SID both in the original space (U, data) and the MNF space (Ut, tdata).
    print('Testing MNF with NFINDR and SID')
    print('Denoising')
    # NFINDR needs n - 1 components
    tdata = do_MNF(data, n - 1)
    print('Extracting')
    ee = eea.NFINDR()
    # Extract using the transformed data.
    # U is the endmember set to classify or unmix together with data.
    U = ee.extract(data, n, transform=tdata, normalize=True, ATGP_init=True)
    # Ut is the endmember set to classify or unmix together with tdata.
    Ut = ee.get_endmembers_transform()
    # NOTE: info may need to be adjusted for the transformed data.
    ee.plot(result_path, info, suffix='MNF')
    # First classify with U and data.
    print('Classification with U and data')
    sid = cls.SID()
    sid.classify(data, U, threshold=0.1)
    sid.plot(result_path, colorMap='Paired', suffix='MNF_U_data')
    # Next classify with Ut and tdata; per the original author, this case
    # gives the best result for the SAMSON_part cube.
    print('Classification with Ut and tdata')
    sid = cls.SID()
    sid.classify(tdata, Ut, threshold=0.1)
    sid.plot(result_path, colorMap='Paired', suffix='MNF_Ut_tdata')


def _run_MNF_clean_inverse_case(data, result_path, info, n):
    # MNF-transform, smooth the noisy bands, invert, then run NFINDR + SID
    # on the reconstructed cube.
    print('Testing MNF+clean noisy bands+inverse with NFINDR and SID')
    print('Denoising')
    idata = do_MNF_and_clean_noisy_bands_and_inverse(data)
    print('Extracting')
    ee = eea.NFINDR()
    U = ee.extract(idata, n, transform=None, normalize=True, ATGP_init=True)
    ee.plot(result_path, info, suffix='MNF_clean_inverse')
    print('Classification with U and idata')
    sid = cls.SID()
    sid.classify(idata, U, threshold=0.1)
    sid.plot(result_path, colorMap='Paired', suffix='MNF_clean_inverse_U_idata')


def test_MNF(data, result_path, info):
    """Exercise MNF with NFINDR extraction and SID classification.

    Runs two scenarios in order:

    1. NFINDR on MNF components, classified with SID in both the original
       and the MNF space (suffixes ``MNF``, ``MNF_U_data``,
       ``MNF_Ut_tdata``).
    2. MNF + noisy-band cleaning + inverse transform, then NFINDR + SID on
       the reconstructed cube (suffixes ``MNF_clean_inverse*``).

    ``result_path`` and ``info`` are forwarded to every ``plot`` call.
    """
    # number of endmembers asked
    n = 24
    _run_MNF_transform_case(data, result_path, info, n)
    _run_MNF_clean_inverse_case(data, result_path, info, n)


def _run_SG_ATGP_SAM(label, denoise, data, result_path, info, suffix):
    # Denoise data with `denoise`, extract 24 endmembers with ATGP, classify
    # with SAM; both plots are tagged with `suffix`.
    print('Testing SavitzkyGolay %s filter with ATGP and SAM' % label)
    print('Denoising')
    tdata = denoise(data)
    print('Extracting')
    ee = eea.ATGP()
    U = ee.extract(tdata, 24, normalize=True)
    ee.plot(result_path, info, suffix=suffix)
    print('Classification')
    sam = cls.SAM()
    sam.classify(tdata, U, threshold=0.1)
    sam.plot(result_path, colorMap='', suffix=suffix)


def test_SavitzkyGolay(data, result_path, info):
    """Exercise both Savitzky-Golay filters with ATGP and SAM.

    The bands filter runs first (plot suffix ``'SGbands'``), then the
    spectra filter (plot suffix ``'SGspectra'``). ``result_path`` and
    ``info`` are forwarded to every ``plot`` call.
    """
    _run_SG_ATGP_SAM('bands', do_SavitzkyGolay_bands_filter,
                     data, result_path, info, 'SGbands')
    _run_SG_ATGP_SAM('spectra', do_SavitzkyGolay_spectra_filter,
                     data, result_path, info, 'SGspectra')


def tests_P2_7():
    """Run all denoising tests on the ENVI sample cube (Python 2.7 path).

    Paths are relative to the current working directory: the cube is read
    from ``../data1/SAMSON_mini.hdr`` and ``../results`` is passed to the
    tests as the result directory.
    """
    # Build paths with osp.join rather than hard-coded backslashes so they
    # resolve on POSIX as well as Windows (identical result on Windows).
    data_path = osp.join('..', 'data1')
    result_path = osp.join('..', 'results')
    sample = 'SAMSON_mini.hdr'

    data_file = osp.join(data_path, sample)
    data, info = util.load_ENVI_file(data_file)

    test_MNF(data, result_path, info)
    test_whiten(data, result_path, info)
    test_SavitzkyGolay(data, result_path, info)


def tests_P3_3():
    """Run all denoising tests on the JSON sample cube (Python 3 path).

    Reads the cube from ``../data1/SAMSON_mini.jdata`` (converted to a
    NumPy array) and its header from ``../data1/SAMSON_mini.jhead``, both
    relative to the current working directory; ``../results`` is passed to
    the tests as the result directory.
    """
    import json
    # Build paths with osp.join rather than hard-coded backslashes so they
    # resolve on POSIX as well as Windows (identical result on Windows).
    data_path = osp.join('..', 'data1')
    result_path = osp.join('..', 'results')
    sample = 'SAMSON_mini'

    # load the cube
    data_file = osp.join(data_path, sample + '.jdata')
    with open(data_file, 'r', encoding='utf-8') as content_file:
        data = np.array(json.load(content_file))
    # load the header
    info_file = osp.join(data_path, sample + '.jhead')
    with open(info_file, 'r', encoding='utf-8') as content_file:
        info = json.load(content_file)

    test_MNF(data, result_path, info)
    test_whiten(data, result_path, info)
    test_SavitzkyGolay(data, result_path, info)


def tests():
    """Run the test set matching the running Python version.

    Python 2.7 runs ``tests_P2_7``; Python 3.3 and later run
    ``tests_P3_3``. Any other version prints a notice instead of silently
    doing nothing.
    """
    import sys
    version = sys.version_info[:2]
    if version == (2, 7):
        tests_P2_7()
    elif version >= (3, 3):
        # tests_P3_3 relies only on json/numpy, nothing specific to 3.3.
        tests_P3_3()
    else:
        print('No test set for Python %d.%d; skipping' % version)


if __name__ == '__main__':
    import sys
    print(sys.version_info)
    # Delegate to tests() so the version dispatch is defined in one place.
    tests()
