#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# test_vd.py - This file is part of the PySptools package.
#

"""
The following material-count estimators are tested:
    HfcVd (with and without noise whitening, i.e. NWHFC)
    HySime
"""



import os
import os.path as osp
import numpy as np
import scipy

import pysptools.util as util
import pysptools.material_count as cnt
try:
    # can be imported only with Python 2.7
    import pysptools.spectro as spectro
except ImportError:
    pass
import pysptools.formatting as fmt


def get_random_n_endmembers(lib_name, n, n_bands=224):
    """
    Return n randomly chosen signatures from a USGS spectral library.

    Parameters:
        lib_name: `string`
            Path to the USGS library ENVI header file.

        n: `int`
            Number of signatures to draw (sampled without replacement).

        n_bands: `int` [default 224]
            Number of bands of each library signature.

    Returns: `numpy array`
        A (n_bands x n) matrix, one signature per column, passed
        through fmt.normalize.
    """
    import random
    lib = spectro.USGS06SpecLib(lib_name)
    # random.sample accepts a range directly (a list on 2.7, a Sequence on 3.x)
    idx = random.sample(range(lib.get_dim()), n)
    # np.float was removed in NumPy 1.24; use the explicit 64-bit float dtype
    U = np.zeros((n_bands, n), dtype=np.float64)
    for col, j in enumerate(idx):
        U[:, col] = lib.get(j)
    # the USGS library sometimes has very small numbers that create numeric
    # instability; normalizing gets rid of them
    return fmt.normalize(U)


def dirichlet_rnd(A, dim):
    """
    Returns a matrix of random numbers chosen
    from the dirichlet distribution with parameters vector A.

    Each row is built from independent Gamma(A[i], scale=1) draws that are
    then divided by their row sum, so every row is non-negative and sums
    to 1.

    Parameters:
        A: `numpy array`
            A vector of shape parameters, one per column (length N).

        dim: `int`
            Number of samples (rows) to draw.

    Returns: `numpy array`
        A (dim x N) matrix of random numbers; each row sums to 1.
    """
    # scipy.stats is a subpackage: a bare `import scipy` does not guarantee
    # that it is loaded, so import it explicitly here
    from scipy import stats

    N = A.shape[0]
    # np.float was removed in NumPy 1.24; use the explicit 64-bit float dtype
    x = np.zeros((dim, N), dtype=np.float64)
    for i in range(N):
        x[:, i] = stats.gamma.rvs(A[i], scale=1, size=dim)

    # normalize every row by its sum in one broadcast division
    x /= np.sum(x, axis=1)[:, np.newaxis]
    return x


def generate_hyperspectral_data(U, p, N):
    """
    Generate a simulated hyperspectral data set.

    Abundances are drawn with dirichlet_rnd using the symmetric parameter
    vector (1/p, ..., 1/p) and mixed linearly with the endmembers in U.
    No noise is added.

    Parameters:
        U: `numpy array`
            USGS library subset, one endmember per column (bands x p).

        p: `int`
            Number of endmembers; must equal the number of columns of U.

        N: `int`
            Number of pixels.

    Returns: (`numpy array`, `numpy array`)
        * x is the signal (endmembers linear mixture), one pixel per row
          (N x bands)
        * s abundance fractions (N x p)

    Raises:
        ValueError: if U is not 2-D or its number of columns differs from p.
    """
    _, n_members = U.shape
    if n_members != p:
        # fail early with a clear message instead of a shape error in np.dot
        raise ValueError('U has %d columns but p = %d' % (n_members, p))
    s = dirichlet_rnd(np.ones(p) / p, N)
    # linear mixture: (bands x p) . (p x N) -> (bands x N), returned as N x bands
    x = np.dot(U, s.T)
    return x.T, s


def test_synthetic_hypercube():
    """
    Test a synthetic hypercube made with p endmembers taken from the USGS
    library. Maybe the USGS library is not a good source of endmembers: the
    same mineral species can have very similar signatures, so picking a
    random subset can return some nearly identical signatures.

    There is no noise added.

    In general, the results are good for both HySime and HfcVd for small
    values of p.
    """
    print('Testing synthetic hypercube')
    data_path = '../usgs'

    # USGS library
    hdr_name = 's06av95a_envi.hdr'
    lib_name = os.path.join(data_path, hdr_name)

    # number of endmembers
    p = 4
    # get a library of endmembers, one per column
    U = get_random_n_endmembers(lib_name, p)
    # cube dimension
    x_coord = 100
    y_coord = 100
    # number of pixels
    N = x_coord * y_coord

    # y is a matrix of pixels (one per row), yr is the equivalent cube
    y, s = generate_hyperspectral_data(U, p, N)
    # take the band count from the data instead of hard-coding it
    n_bands = y.shape[1]
    yr = np.reshape(y, (x_coord, y_coord, n_bands))
    # calculate kf
    hy = cnt.HySime()
    kf, _ = hy.count(yr)
    print('  HySime kf:', kf)
    # calculate vd
    hfcvd = cnt.HfcVd()
    vd = hfcvd.count(yr)
    print('  HfcVd vd:', vd)


def test_hysime(data, wvl, path):
    """
    Run HySime on a cube and print the estimated dimensionality kf.

    Parameters:
        data: `numpy array`
            The hyperspectral cube passed to HySime.count.

        wvl:
            Unused; kept so existing callers keep working.

        path: `string`
            Unused; kept so existing callers keep working.
    """
    # print the section header first (as test_HfcVd does) so any failure
    # inside count() is reported under the right heading
    print('Testing HySime')
    hy = cnt.HySime()
    kf, _ = hy.count(data)
    print('  Virtual dimensionality is: k =', kf)


def test_HfcVd(data):
    """
    Print the HfcVd virtual dimensionality of a cube, first without noise
    whitening (reported as HfcVd), then with it (reported as NWHFC).

    Parameters:
        data: `numpy array`
            The hyperspectral cube passed to HfcVd.count.
    """
    estimator = cnt.HfcVd()
    runs = (
        ('Testing HfcVd', {}),
        ('Testing NWHFC', {'noise_whitening': True}),
    )
    for header, options in runs:
        print(header)
        print('  Virtual dimensionality:', estimator.count(data, **options))


def tests_P2_7():
    """
    Python 2.7 test driver: load an ENVI cube, run the HySime and HfcVd
    tests on it, then run the synthetic-hypercube test.
    """
    data_path = '../data1'
    project_path = '../'
    result_path = osp.join(project_path, 'results')
    sample = 'SAMSON_mini.hdr'
    # another usable sample: '92AV3C.hdr'

    # os.makedirs has no exist_ok on Python 2.7, hence the explicit check
    if not osp.exists(result_path):
        os.makedirs(result_path)

    data_file = osp.join(data_path, sample)
    data, wvl = util.load_ENVI_file(data_file)
    test_hysime(data, wvl, result_path)
    test_HfcVd(data)
    test_synthetic_hypercube()


def tests_P3_3():
    """
    Python 3.3 test driver: load a cube and its header from JSON files and
    run the HySime and HfcVd tests on it.

    The synthetic-hypercube test is not run here because pysptools.spectro
    can only be imported with Python 2.7.
    """
    import json
    data_path = '../data1'
    project_path = '../'
    result_path = osp.join(project_path, 'results')
    sample = 'SAMSON_mini'
    # another usable sample: '92AV3C'

    if not osp.exists(result_path):
        os.makedirs(result_path)

    # load the cube
    data_file = osp.join(data_path, sample + '.jdata')
    with open(data_file, 'r') as content_file:
        data = np.array(json.load(content_file))
    info_file = osp.join(data_path, sample + '.jhead')
    with open(info_file, 'r') as content_file:
        info = json.load(content_file)

    test_hysime(data, info, result_path)
    test_HfcVd(data)


def tests():
    """
    Run the test driver that matches the running interpreter: Python 2.7
    or Python 3.3. Any other version runs nothing.
    """
    import sys
    major_minor = sys.version_info[:2]
    if major_minor == (2, 7):
        tests_P2_7()
    elif major_minor == (3, 3):
        tests_P3_3()


if __name__ == '__main__':
    import sys
    print(sys.version_info)
    # reuse the version dispatch in tests() instead of duplicating it here
    tests()
