#!/usr/bin/env python

"""Do regression testing.
"""

from tempfile import mktemp
import os
import unittest

import argparse
import numpy as np

from hmmus import demo
from hmmus import hmm

class TempFileManager:
    """
    Keep track of temporary file names used as hmm inputs and outputs.
    The names are stored as fn_* attributes and are also grouped into
    the sets input_filenames, output_filenames, and all_filenames.
    """

    def __init__(self):
        self.redefine_filenames()

    def redefine_filenames(self):
        """
        Assign a new temporary name to every fn_* attribute
        and rebuild the three sets of file names.
        """
        def fresh_name(label):
            # each name looks like <tmpdir>/<label>.<random>.bin
            return mktemp(prefix=label + '.', suffix='.bin')
        # names of the input files
        self.fn_distn = fresh_name('distn')
        self.fn_trans = fresh_name('trans')
        self.fn_likelihood = fresh_name('likelihood')
        # names of the output files
        self.fn_forward = fresh_name('forward')
        self.fn_scaling = fresh_name('scaling')
        self.fn_backward = fresh_name('backward')
        self.fn_posterior = fresh_name('posterior')
        # group the names
        self.input_filenames = {
                self.fn_distn, self.fn_trans, self.fn_likelihood}
        self.output_filenames = {
                self.fn_forward, self.fn_scaling,
                self.fn_backward, self.fn_posterior}
        self.all_filenames = self.input_filenames.union(
                self.output_filenames)

    def unlink_all(self):
        """
        Remove each input and output file that currently exists.
        """
        for path in filter(os.path.isfile, self.all_filenames):
            os.unlink(path)

    def unlink_output(self):
        """
        Remove each output file that currently exists.
        """
        for path in filter(os.path.isfile, self.output_filenames):
            os.unlink(path)

def check_shape(shape, arr, name):
    """
    Check that an array can be reshaped to the given shape.
    The reshaped array is discarded; only the compatibility is tested.
    @param shape: the target shape
    @param arr: an array that provides reshape() and a shape attribute
    @param name: a label that prefixes the error message
    @raise Exception: if arr.reshape(shape) raises a ValueError
    """
    try:
        arr.reshape(shape)
    except ValueError:
        # this except form is valid in both python 2 and python 3
        msg_a = '%s error: ' % name
        msg_b = 'the shape %s cannot be reshaped to %s' % (arr.shape, shape)
        raise Exception(msg_a + msg_b)

class TestHmm(unittest.TestCase):
    """
    Compare the output of the hmm forward-backward functions
    to the expected arrays provided by the demo instances.
    """

    def _read_observed(self, filename, shape, name):
        """
        Read a binary file of doubles and give it the expected shape.
        @param filename: the path of a file to read with np.fromfile
        @param shape: the shape of the corresponding expected array
        @param name: a label used in the shape error message
        @return: the observed array reshaped to the given shape
        """
        observed = np.fromfile(filename, dtype='d')
        # give a readable error if the number of elements is incompatible
        check_shape(shape, observed, name)
        return observed.reshape(shape)

    def help_test_fwdbwd_alldisk(self, d):
        """
        Check every array that hmm.fwdbwd_alldisk writes to disk.
        @param d: a demo instance
        """
        tfm = TempFileManager()
        # the finally clause removes the temporary files
        # even if the hmm call raises or an assertion fails
        try:
            # create the file of likelihoods
            with open(tfm.fn_likelihood, 'wb') as fout:
                d.get_likelihoods().tofile(fout)
            # run the hmm to create the output files
            hmm.fwdbwd_alldisk(
                d.get_distribution(), d.get_transitions(),
                tfm.fn_likelihood,
                tfm.fn_forward, tfm.fn_scaling,
                tfm.fn_backward, tfm.fn_posterior)
            # pair each expected array with its output file
            comparisons = (
                    (d.get_expected_forward(), tfm.fn_forward, 'forward'),
                    (d.get_expected_scaling(), tfm.fn_scaling, 'scaling'),
                    (d.get_expected_backward(), tfm.fn_backward, 'backward'),
                    (d.get_expected_posterior(), tfm.fn_posterior,
                        'posterior'))
            # compare expected and observed
            for expected, filename, name in comparisons:
                observed = self._read_observed(
                        filename, expected.shape, name)
                self.assertTrue(np.allclose(expected, observed),
                        name + ' array mismatch')
        finally:
            tfm.unlink_all()

    def help_test_fwdbwd_somedisk(self, d):
        """
        Check the posterior file written by hmm.fwdbwd_somedisk.
        @param d: a demo instance
        """
        tfm = TempFileManager()
        # the finally clause removes the temporary files
        # even if the hmm call raises or an assertion fails
        try:
            # create the file of likelihoods
            with open(tfm.fn_likelihood, 'wb') as fout:
                d.get_likelihoods().tofile(fout)
            # get the initial distribution and the transition matrix
            np_distn = d.get_distribution()
            np_trans = d.get_transitions()
            # write the posterior distribution as a file
            hmm.fwdbwd_somedisk(np_distn, np_trans,
                    tfm.fn_likelihood, tfm.fn_posterior)
            # get the expected and the observed posterior data matrices
            ex_post = d.get_expected_posterior()
            ob_post = self._read_observed(
                    tfm.fn_posterior, ex_post.shape, 'posterior')
            # compare expected and observed matrices
            self.assertTrue(np.allclose(ex_post, ob_post),
                    'posterior array mismatch')
        finally:
            tfm.unlink_all()

    def help_test_fwdbwd_nodisk(self, d):
        """
        Check the posterior array returned by hmm.fwdbwd_nodisk.
        @param d: a demo instance
        """
        # compute the posterior
        ob_post = hmm.fwdbwd_nodisk(d.get_distribution(), d.get_transitions(),
                d.get_likelihoods())
        # get the expected posterior
        ex_post = d.get_expected_posterior()
        # compare the computed posterior to the expected posterior
        self.assertTrue(np.allclose(ex_post, ob_post),
                'posterior array mismatch')


    def test_smith_alldisk(self):
        self.help_test_fwdbwd_alldisk(demo.SmithDemo())

    def test_smith_somedisk(self):
        self.help_test_fwdbwd_somedisk(demo.SmithDemo())

    def test_smith_nodisk(self):
        self.help_test_fwdbwd_nodisk(demo.SmithDemo())


    def test_eddy_alldisk(self):
        self.help_test_fwdbwd_alldisk(demo.EddyDemo())

    def test_eddy_somedisk(self):
        self.help_test_fwdbwd_somedisk(demo.EddyDemo())

    def test_eddy_nodisk(self):
        self.help_test_fwdbwd_nodisk(demo.EddyDemo())


if __name__ == '__main__':
    import sys
    # argparse provides -h/--help using the module docstring;
    # all other arguments are forwarded to unittest so that
    # its own options (for example -v or a test name) are not rejected
    parser = argparse.ArgumentParser(description=__doc__)
    unittest_args = parser.parse_known_args()[1]
    unittest.main(argv=sys.argv[:1] + unittest_args)
