#!/usr/bin/env python

"""Do regression testing.
"""

from tempfile import mktemp
import os
import unittest

import argparse
import numpy as np

from hmmus import demo
from hmmus import hmm
from hmmus import estimation

class TempFileManager:
    """
    Keep track of the temporary file names used by the disk based tests.
    Every name is available as an fn_* attribute, and the names are also
    grouped into the input_filenames, output_filenames and all_filenames sets.
    """

    def __init__(self):
        self.redefine_filenames()

    def redefine_filenames(self):
        """
        Choose a new name for every file and rebuild the sets of names.
        Only names are chosen here; mktemp does not create the files.
        """
        def fresh(stem):
            return mktemp(prefix=stem + '.', suffix='.bin')
        # input file names
        self.fn_distn = fresh('distn')
        self.fn_trans = fresh('trans')
        self.fn_observation = fresh('observation')
        self.fn_likelihood = fresh('likelihood')
        # output file names
        self.fn_forward = fresh('forward')
        self.fn_scaling = fresh('scaling')
        self.fn_backward = fresh('backward')
        self.fn_posterior = fresh('posterior')
        # group the names so that they can be unlinked together
        self.input_filenames = {
                self.fn_distn, self.fn_trans,
                self.fn_observation, self.fn_likelihood}
        self.output_filenames = {
                self.fn_forward, self.fn_scaling,
                self.fn_backward, self.fn_posterior}
        self.all_filenames = self.input_filenames.union(
                self.output_filenames)

    def unlink_all(self):
        """
        Delete each input and output file that currently exists.
        """
        self._unlink_existing(self.all_filenames)

    def unlink_output(self):
        """
        Delete each output file that currently exists.
        """
        self._unlink_existing(self.output_filenames)

    @staticmethod
    def _unlink_existing(filenames):
        """
        @param filenames: names of files to delete if they exist
        """
        for path in filenames:
            # a name whose file was never written is skipped
            if os.path.isfile(path):
                os.unlink(path)

def check_shape(shape, arr, name):
    """
    Make sure that an array can be reshaped to the given shape.
    The array itself is not modified and nothing is returned.
    @param shape: the target shape
    @param arr: an array that provides reshape() and a shape attribute
    @param name: a label for the array, used as a prefix of the error message
    @raise Exception: if reshaping the array raises a ValueError
    """
    try:
        arr.reshape(shape)
    except ValueError:
        # this form of the except clause is valid in python 2.6+ and python 3
        msg_a = '%s error: ' % name
        msg_b = 'the shape %s cannot be reshaped to %s' % (arr.shape, shape)
        raise Exception(msg_a + msg_b)

class TestHmm(unittest.TestCase):
    """
    Regression tests of the forward-backward routines of the hmm module.
    Each demo is run through the alldisk, somedisk and nodisk variants
    and the results are compared to the expected values stored in the demo.
    """

    def help_test_posterior_expectations(self, d, tfm):
        """
        Do some alldisk tests of posterior expectation.
        Test the expected amount of time spent in each state.
        Test the expected count of each transition.
        Test the expectation of each emission symbol from each state.
        The observation, likelihood, forward, backward and posterior files
        named by the temp file manager are read, so they must already exist.
        @param d: a demo instance
        @param tfm: temp file manager
        """
        # initialize some counts
        nstates = d.get_nstates()
        nalpha = d.get_nalpha()
        # test the state expectations
        e_state_expectations = d.get_expected_state_expectations()
        o_state_expectations = np.zeros(nstates)
        hmm.state_expectations(o_state_expectations, tfm.fn_posterior)
        ok = np.allclose(e_state_expectations, o_state_expectations)
        self.assertTrue(ok, 'state expectations regression testing error')
        # test the transition expectations
        e_transition_expectations = d.get_expected_transition_expectations()
        o_transition_expectations = np.zeros((nstates, nstates))
        hmm.transition_expectations(
                d.get_transitions(), o_transition_expectations,
                tfm.fn_likelihood, tfm.fn_forward, tfm.fn_backward)
        ok = np.allclose(e_transition_expectations, o_transition_expectations)
        self.assertTrue(ok, 'transition expectations regression testing error')
        # test the emission expectations
        e_emission_expectations = d.get_expected_emission_expectations()
        o_emission_expectations = np.zeros((nstates, nalpha))
        hmm.emission_expectations(o_emission_expectations,
                tfm.fn_observation, tfm.fn_posterior)
        ok = np.allclose(e_emission_expectations, o_emission_expectations)
        self.assertTrue(ok, 'emission expectations regression testing error')
        # the state expectations can be derived from emission expectations
        recomputed_state_expectations = o_emission_expectations.sum(axis=1)
        ok = np.allclose(recomputed_state_expectations, e_state_expectations)
        self.assertTrue(ok, 'compatibility regression testing error')

    def help_test_fwdbwd_alldisk(self, d):
        """
        Run the alldisk forward-backward routine and check its output files.
        @param d: a demo instance
        """
        tfm = TempFileManager()
        # unlink the temp files even if something below fails
        self.addCleanup(tfm.unlink_all)
        # create the file of observations
        with open(tfm.fn_observation, 'wb') as fout:
            d.get_int8_observations().tofile(fout)
        # create the file of likelihoods
        with open(tfm.fn_likelihood, 'wb') as fout:
            d.get_likelihoods().tofile(fout)
        # run the hmm to create the output files
        hmm.fwdbwd_alldisk(
            d.get_distribution(), d.get_transitions(),
            tfm.fn_likelihood,
            tfm.fn_forward, tfm.fn_scaling, tfm.fn_backward, tfm.fn_posterior)
        # pair each expected array with the file that should reproduce it
        outputs = (
                ('forward', d.get_expected_forward(), tfm.fn_forward),
                ('scaling', d.get_expected_scaling(), tfm.fn_scaling),
                ('backward', d.get_expected_backward(), tfm.fn_backward),
                ('posterior', d.get_expected_posterior(), tfm.fn_posterior))
        for name, expected, filename in outputs:
            # the file is read as a flat array of doubles
            observed = np.fromfile(filename, dtype='d')
            # report a shape incompatibility using the name of the array
            check_shape(expected.shape, observed, name)
            observed = observed.reshape(expected.shape)
            # compare expected and observed
            self.assertTrue(np.allclose(expected, observed),
                    name + ' array mismatch')
        # do more tests while the files still exist
        self.help_test_posterior_expectations(d, tfm)

    def help_test_fwdbwd_somedisk(self, d):
        """
        Run the somedisk forward-backward routine and check the posterior file.
        @param d: a demo instance
        """
        tfm = TempFileManager()
        # unlink the temp files even if something below fails
        self.addCleanup(tfm.unlink_all)
        # create the file of likelihoods
        with open(tfm.fn_likelihood, 'wb') as fout:
            d.get_likelihoods().tofile(fout)
        # get the initial distribution and the transition matrix
        np_distn = d.get_distribution()
        np_trans = d.get_transitions()
        # write the posterior distribution as a file
        hmm.fwdbwd_somedisk(np_distn, np_trans,
                tfm.fn_likelihood, tfm.fn_posterior)
        # get the expected posterior data matrix and its shape
        ex_post = d.get_expected_posterior()
        sh_post = ex_post.shape
        # get the observed posterior data as a flat array of doubles
        ob_post = np.fromfile(tfm.fn_posterior, dtype='d')
        # check the compatibility of the shape of the matrix
        check_shape(sh_post, ob_post, 'posterior')
        # reshape the computed posterior matrix
        ob_post = ob_post.reshape(sh_post)
        # compare expected and observed matrices
        self.assertTrue(np.allclose(ex_post, ob_post),
                'posterior array mismatch')

    def help_test_fwdbwd_nodisk(self, d):
        """
        Run the nodisk forward-backward routine and check the posterior.
        @param d: a demo instance
        """
        # compute the posterior
        ob_post = hmm.fwdbwd_nodisk(d.get_distribution(), d.get_transitions(),
                d.get_likelihoods())
        # get the expected posterior
        ex_post = d.get_expected_posterior()
        # compare the computed posterior to the expected posterior
        self.assertTrue(np.allclose(ex_post, ob_post),
                'posterior array mismatch')


    def test_smith_alldisk(self):
        self.help_test_fwdbwd_alldisk(demo.SmithDemo())

    def test_smith_somedisk(self):
        self.help_test_fwdbwd_somedisk(demo.SmithDemo())

    def test_smith_nodisk(self):
        self.help_test_fwdbwd_nodisk(demo.SmithDemo())


    def test_eddy_alldisk(self):
        self.help_test_fwdbwd_alldisk(demo.EddyDemo())

    def test_eddy_somedisk(self):
        self.help_test_fwdbwd_somedisk(demo.EddyDemo())

    def test_eddy_nodisk(self):
        self.help_test_fwdbwd_nodisk(demo.EddyDemo())


class TestHmmMissingData(unittest.TestCase):
    """
    Check the log likelihood of a finite model whose observations are all 127.
    """

    def test_missing(self):
        # NOTE(review): 127 presumably marks a missing observation;
        # confirm against hmmus.estimation
        transition_matrix = np.array([[0.9, 0.1], [0.1, 0.9]])
        emission_matrix = np.array([[0.5, 0.5], [0.7, 0.3]])
        observations = np.array([127] * 4, dtype=np.int8)
        finite_model = estimation.FiniteModel(
                transition_matrix, emission_matrix, observations)
        # the log likelihood is expected to be exactly zero
        self.assertTrue(finite_model.get_log_likelihood() == 0)


class TestHmmNoDisk(unittest.TestCase):
    """
    Tests of the in-memory hmm routines against the expected values of a demo.
    Subclasses override get_demo; with the default of None every test
    returns without checking anything.
    """

    def get_demo(self):
        """
        @return: a demo instance, or None to skip the checks
        """
        return None

    def _run_forward(self, d):
        """
        Run the forward algorithm on the data of the demo.
        @param d: a demo instance
        @return: distn, trans, likelihood, forward, scaling
        """
        distn = d.get_distribution()
        trans = d.get_transitions()
        likelihood = d.get_likelihoods()
        # the output arrays are filled in by the hmm call
        forward = np.zeros_like(likelihood)
        scaling = np.zeros(likelihood.shape[0])
        hmm.forward_nodisk(distn, trans, likelihood, forward, scaling)
        return distn, trans, likelihood, forward, scaling

    def _run_forward_backward(self, d):
        """
        Run the forward algorithm and then the backward algorithm.
        @param d: a demo instance
        @return: trans, likelihood, forward, scaling, backward
        """
        distn, trans, likelihood, forward, scaling = self._run_forward(d)
        backward = np.zeros_like(likelihood)
        hmm.backward_nodisk(distn, trans, likelihood, scaling, backward)
        return trans, likelihood, forward, scaling, backward

    def _run_posterior(self, d):
        """
        Run the forward and backward algorithms and compute the posterior.
        @param d: a demo instance
        @return: the posterior array
        """
        fwdbwd = self._run_forward_backward(d)
        trans, likelihood, forward, scaling, backward = fwdbwd
        posterior = np.zeros_like(likelihood)
        hmm.posterior_nodisk(forward, scaling, backward, posterior)
        return posterior

    def test_forward_nodisk(self):
        d = self.get_demo()
        if not d:
            return
        forward, scaling = self._run_forward(d)[-2:]
        self.assertTrue(np.allclose(forward, d.get_expected_forward()))
        self.assertTrue(np.allclose(scaling, d.get_expected_scaling()))

    def test_backward_nodisk(self):
        d = self.get_demo()
        if not d:
            return
        backward = self._run_forward_backward(d)[-1]
        self.assertTrue(np.allclose(backward, d.get_expected_backward()))

    def test_posterior_nodisk(self):
        d = self.get_demo()
        if not d:
            return
        posterior = self._run_posterior(d)
        self.assertTrue(np.allclose(posterior, d.get_expected_posterior()))

    def test_transition_expectations_nodisk(self):
        d = self.get_demo()
        if not d:
            return
        fwdbwd = self._run_forward_backward(d)
        trans, likelihood, forward, scaling, backward = fwdbwd
        # the expectations are accumulated into this array by the hmm call
        observed = np.zeros_like(trans)
        hmm.transition_expectations_nodisk(trans, observed,
                likelihood, forward, backward)
        expected = d.get_expected_transition_expectations()
        self.assertTrue(np.allclose(observed, expected))

    def test_emission_expectations_nodisk(self):
        d = self.get_demo()
        if not d:
            return
        posterior = self._run_posterior(d)
        # the expectations are accumulated into this array by the hmm call
        observed = np.zeros((d.get_nstates(), d.get_nalpha()))
        hmm.emission_expectations_nodisk(
                observed, d.get_int8_observations(), posterior)
        expected = d.get_expected_emission_expectations()
        self.assertTrue(np.allclose(observed, expected))


class TestHmmNoDiskSmith(TestHmmNoDisk):
    """
    Run the inherited in-memory hmm tests on the Smith demo.
    """

    def get_demo(self):
        # a new demo instance is built for every call
        return demo.SmithDemo()

class TestHmmNoDiskEddy(TestHmmNoDisk):
    """
    Run the inherited in-memory hmm tests on the Eddy demo.
    """

    def get_demo(self):
        # a new demo instance is built for every call
        return demo.EddyDemo()


class TestHmmFiniteModel(unittest.TestCase):
    """
    Tests of estimation.FiniteModel against the expected values of a demo.
    Subclasses override get_demo; with the default of None every test
    returns without checking anything.
    """

    def get_demo(self):
        """
        @return: a demo instance, or None to skip the checks
        """
        return None

    def get_model(self):
        """
        Build a finite model from the data of the demo.
        @return: a FiniteModel, or None when there is no demo
        """
        d = self.get_demo()
        if not d:
            return None
        distn = d.get_distribution()
        trans = d.get_transitions()
        obs = d.get_int8_observations()
        # NOTE(review): elsewhere in this file FiniteModel is constructed
        # as (trans, emiss, obs); confirm that (distn, trans, obs) is the
        # intended argument order here.
        return estimation.FiniteModel(distn, trans, obs)

    def test_forward(self):
        d = self.get_demo()
        if not d:
            return
        m = self.get_model()
        e_forward = d.get_expected_forward()
        o_forward = m.get_forward()
        self.assertTrue(np.allclose(e_forward, o_forward))

    def test_backward(self):
        d = self.get_demo()
        if not d:
            return
        m = self.get_model()
        e_backward = d.get_expected_backward()
        o_backward = m.get_backward()
        # compare the backward arrays, not the forward arrays
        self.assertTrue(np.allclose(e_backward, o_backward))

    def test_posterior(self):
        d = self.get_demo()
        if not d:
            return
        m = self.get_model()
        e_posterior = d.get_expected_posterior()
        o_posterior = m.get_posterior()
        # compare the posterior arrays, not the forward arrays
        self.assertTrue(np.allclose(e_posterior, o_posterior))

    def test_transition_expectations(self):
        d = self.get_demo()
        if not d:
            return
        m = self.get_model()
        e_trans_expect = d.get_expected_transition_expectations()
        o_trans_expect = m.get_trans_expect()
        self.assertTrue(np.allclose(e_trans_expect, o_trans_expect))

    def test_emission_expectations(self):
        d = self.get_demo()
        if not d:
            return
        m = self.get_model()
        e_emiss_expect = d.get_expected_emission_expectations()
        o_emiss_expect = m.get_emiss_expect()
        self.assertTrue(np.allclose(e_emiss_expect, o_emiss_expect))


# These subclasses must derive from TestHmmFiniteModel;
# deriving from TestHmmNoDisk would rerun the nodisk tests
# and would never run the finite model tests on a demo.

class TestHmmFiniteModelSmith(TestHmmFiniteModel):
    """
    Run the inherited finite model tests on the Smith demo.
    """
    def get_demo(self):
        return demo.SmithDemo()

class TestHmmFiniteModelEddy(TestHmmFiniteModel):
    """
    Run the inherited finite model tests on the Eddy demo.
    """
    def get_demo(self):
        return demo.EddyDemo()



if __name__ == '__main__':
    import sys
    # the parser defines no options of its own;
    # it provides --help with the module docstring as the description
    parser = argparse.ArgumentParser(description=__doc__)
    # arguments that the parser does not recognize (test names, -v, ...)
    # are passed on to unittest instead of being rejected as errors
    _, unittest_args = parser.parse_known_args()
    unittest.main(argv=[sys.argv[0]] + unittest_args)
