#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2014, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# utests.py - This file is part of the PySptools package.
#

import unittest
import numpy as np
from pysptools.noise import SavitzkyGolay, Whiten, MNF
from types import *

# Naming convention for test inputs:
#   B suffix == bad (invalid input that must be rejected)
#   G suffix == good (valid input)

class TestSavitzkyGolayError(unittest.TestCase):
    """Checks that SavitzkyGolay raises the expected exceptions on bad arguments."""

    def setUp(self):
        # No shared fixture: every test builds its own inputs.
        pass

    def runTest(self):
        # unittest's default method name when a TestCase is instantiated
        # without one; runs both tests in sequence.
        print('==> runTest: TestSavitzkyGolayError')
        for step in (self.test_SavitzkyGolay, self.test_plot):
            step()


    def test_SavitzkyGolay(self):
        # Inputs: a 4-dimensional array and a plain list (both bad),
        # and a 3-dimensional array (good).
        cube4d_B = np.zeros((2,2,2,2))
        plain_list_B = [1,2]
        cube3d_G = np.zeros((2,2,2))

        sg = SavitzkyGolay()
        spectra = sg.denoise_spectra
        bands = sg.denoise_bands

        # Each row: (expected exception, method, positional arguments).
        # Rows are evaluated top to bottom; the first mismatch fails the test.
        cases = (
            # err1: the data argument is a list, not a numpy array
            (TypeError, spectra, (plain_list_B, 3, 1)),
            (TypeError, bands, (plain_list_B, 3, 1)),
            # err2: the data argument has 4 dimensions
            (RuntimeError, spectra, (cube4d_B, 3, 1)),
            (RuntimeError, bands, (cube4d_B, 3, 1)),
            # err3: a string passed in each remaining positional slot of
            # denoise_spectra
            (TypeError, spectra, (cube3d_G, 'string', 1)),
            (TypeError, spectra, (cube3d_G, 3, 'string')),
            (TypeError, spectra, (cube3d_G, 3, 0, 'string')),
            (TypeError, spectra, (cube3d_G, 3, 0, 1, 'string')),
            # err3: wrongly typed trailing arguments of denoise_bands
            (TypeError, bands, (cube3d_G, 'string', 1)),
            (TypeError, bands, (cube3d_G, 3, 'string')),
            (TypeError, bands, (cube3d_G, 3, 0, 33.3)),
        )
        for expected, method, args in cases:
            self.assertRaises(expected, method, *args)


    def test_plot(self):
        sg = SavitzkyGolay()
        # err10: plotting on a freshly constructed instance
        # (presumably rejected because nothing has been denoised yet)
        self.assertRaises(RuntimeError, sg.plot_bands_sample, '../results', 1)
        # err11: with result attributes assigned by hand, a list is not an
        # acceptable suffix
        sg.denoised = np.array([1,2])
        sg.dbands = np.array([1,2])
        self.assertRaises(TypeError, sg.plot_bands_sample,
                          '../results', 1, suffix=[1,2])


class TestWhitenError(unittest.TestCase):
    """Checks that Whiten.apply() rejects invalid input data."""

    def setUp(self):
        # No shared fixture: the test builds its own inputs.
        pass

    def runTest(self):
        # unittest's default method name when a TestCase is instantiated
        # without one.
        print('==> runTest: TestWhitenError')
        self.test_validate1()


    def test_validate1(self):
        # data_B1: a 4-dimensional array; data_B2: a plain list, not an array.
        data_B1 = np.zeros((2,2,2,2))
        data_B2 = [1,2]

        # Exercise Whiten itself. Calling SavitzkyGolay.denoise_spectra() with
        # only the data argument is not a usable check here: every other call
        # in this file also passes window_size and order.
        w = Whiten()
        # NOTE(review): the expected exception types mirror those asserted for
        # MNF.apply() in TestMNF; confirm that Whiten.apply() goes through the
        # same input validation.
        # err2: wrong number of dimensions
        with self.assertRaises(RuntimeError):
            w.apply(data_B1)
        # err1: not a numpy array
        with self.assertRaises(TypeError):
            w.apply(data_B2)


class TestMNF(unittest.TestCase):
    """Checks that MNF raises the expected exceptions on bad arguments."""

    def setUp(self):
        # No shared fixture: every test builds its own inputs.
        pass

    def runTest(self):
        # unittest's default method name when a TestCase is instantiated
        # without one; runs both tests in sequence.
        print('==> runTest: TestMNF')
        for step in (self.test_validate1, self.test_plot):
            step()


    def test_validate1(self):
        # Bad inputs: a 4-dimensional array and a plain list.
        cube4d_B = np.zeros((2,2,2,2))
        plain_list_B = [1,2]

        mnf = MNF()
        # err1
        expectations = (
            (RuntimeError, cube4d_B),
            (TypeError, plain_list_B),
        )
        for expected, bad_input in expectations:
            self.assertRaises(expected, mnf.apply, bad_input)


    def test_plot(self):
        mnf = MNF()
        # err10: plotting on a freshly constructed instance
        # (presumably rejected because apply() has not been run yet)
        self.assertRaises(RuntimeError, mnf.plot_components, '../results', 1)
        # err11: with the mnf attribute assigned by hand, a list is not an
        # acceptable suffix
        mnf.mnf = np.array([1,2])
        self.assertRaises(TypeError, mnf.plot_components,
                          '../results', 1, suffix=[1,2])


# When this file is executed as a script, let unittest discover and run the
# test cases defined in this module.
if __name__ == '__main__':
    unittest.main()

