import numpy as np
import scipy as sp
from scipy import stats as spstats
import pandas as pd
from sklearn.naive_bayes import GaussianNB
from six.moves import range

from numpy.testing import assert_array_equal, assert_array_almost_equal
import numpy.testing as npt
import nose.tools
from nose.tools import assert_equal, assert_almost_equal, raises

from .. import statistical as stat

# One seeded RandomState shared by the tests in this module, so the
# random inputs they draw are reproducible from run to run.
rs = np.random.RandomState(sum(ord(char) for char in "moss_stats"))

# 100 standard-normal draws, and the integers 0 through 100 inclusive.
a_norm = rs.randn(100)
a_range = np.arange(101)


def _random_dataset(x_shape):
    """Build one fake classification dataset as a dict.

    ``X`` is standard-normal noise with shape ``x_shape``, ``y`` holds 24
    Bernoulli(0.5) labels and ``runs`` marks the first 12 samples as run 0
    and the last 12 as run 1. No ``random_state`` is passed to ``rvs``, so
    scipy's default generator is used here rather than ``rs``.
    """
    features = spstats.norm(0, 1).rvs(x_shape)
    labels = spstats.bernoulli(.5).rvs(24)
    return dict(X=features, y=labels, runs=np.repeat([0, 1], 12))


# Three 2D datasets (24 samples x 12 features) ...
datasets = [_random_dataset((24, 12)) for _ in range(3)]

# ... and three with an extra leading dimension of length 4.
datasets_3d = [_random_dataset((4, 24, 12)) for _ in range(3)]


def test_bootstrap():
    """Test that bootstrapping a constant array returns that constant."""
    ones = np.ones(10)
    n_boot = 5
    expected = np.ones(n_boot)

    # Default aggregate function.
    assert_array_equal(stat.bootstrap(ones, n_boot=n_boot), expected)

    # Explicit median aggregate.
    median_dist = stat.bootstrap(ones, n_boot=n_boot, func=np.median)
    assert_array_equal(median_dist, expected)


def test_bootstrap_length():
    """Test that the bootstrap output has one entry per resample."""
    # No n_boot means 10000 resamples; otherwise n_boot resamples.
    cases = [({}, 10000), ({"n_boot": 100}, 100)]
    for boot_kwargs, wanted_len in cases:
        boot_dist = stat.bootstrap(a_norm, **boot_kwargs)
        assert_equal(len(boot_dist), wanted_len)


def test_bootstrap_range():
    """Test that bootstrapping a random array stays within the right range."""
    # Named lower/upper so the ``min``/``max`` builtins are not shadowed.
    lower, upper = a_norm.min(), a_norm.max()
    boot_dist = stat.bootstrap(a_norm)
    # An aggregate of values resampled from a_norm should not fall
    # outside the range of a_norm itself.
    nose.tools.assert_less(lower, boot_dist.min())
    nose.tools.assert_greater_equal(upper, boot_dist.max())


def test_bootstrap_multiarg():
    """Test that bootstrap accepts several input arrays at once."""
    x = np.tile([1, 10], (10, 1))
    y = np.tile([5, 5], (10, 1))

    def column_max(x_sample, y_sample):
        # Column-wise maximum over both resampled inputs stacked together.
        return np.vstack((x_sample, y_sample)).max(axis=0)

    boot_dist = stat.bootstrap(x, y, n_boot=2, func=column_max)
    assert_array_equal(boot_dist, np.array([[5, 10], [5, 10]]))


def test_bootstrap_axis():
    """Test the output shape of bootstrap with and without ``axis``."""
    data = rs.randn(10, 20)
    n_boot = 100

    # With no axis, every resample reduces to a single number.
    flat_dist = stat.bootstrap(data, n_boot=n_boot)
    assert_equal(flat_dist.shape, (n_boot,))

    # Along axis 0, every resample keeps one value per column.
    column_dist = stat.bootstrap(data, n_boot=n_boot, axis=0)
    assert_equal(column_dist.shape, (n_boot, data.shape[1]))


def test_smooth_bootstrap():
    """Test that a smooth bootstrap yields values not present in the data."""
    sample = rs.randn(15)
    n_boot = 100
    plain_dist = stat.bootstrap(sample, n_boot=n_boot, func=np.median)
    smooth_dist = stat.bootstrap(sample, n_boot=n_boot,
                                 smooth=True, func=np.median)
    # The ordinary bootstrap of the median is expected to land on an
    # observed value; the smoothed one is expected not to.
    assert np.median(plain_dist) in sample
    assert np.median(smooth_dist) not in sample


def test_bootstrap_ols():
    """Test bootstrapping the weights of an ordinary least squares fit."""
    def ols_fit(X, y):
        # Closed-form OLS weights: (X'X)^-1 X' y
        solver = np.dot(np.linalg.inv(np.dot(X.T, X)), X.T)
        return np.dot(solver, y)

    n_obs = 50
    X = np.column_stack((rs.randn(n_obs, 4), np.ones(n_obs)))
    w = [2, 4, 0, 3, 5]
    y_clean = np.dot(X, w)
    y_noisy = y_clean + rs.randn(n_obs) * 20
    y_lownoise = y_clean + rs.randn(n_obs)

    n_boot = 500
    w_boot_noisy, w_boot_lownoise = [
        stat.bootstrap(X, target, n_boot=n_boot, func=ols_fit)
        for target in (y_noisy, y_lownoise)
    ]

    for w_boot in (w_boot_noisy, w_boot_lownoise):
        assert_equal(w_boot.shape, (n_boot, len(w)))
    # Noisier targets should give more variable weight estimates.
    nose.tools.assert_greater(w_boot_noisy.std(), w_boot_lownoise.std())


@raises(ValueError)
def test_bootstrap_arglength():
    """Test that input arrays of unequal length raise ValueError."""
    short_arg, long_arg = range(5), range(10)
    stat.bootstrap(short_arg, long_arg)


@raises(TypeError)
def test_bootstrap_noncallable():
    """Test that we get a TypeError with noncallable statfunc."""
    non_func = "mean"
    # n_boot and func must be passed by keyword. bootstrap takes its data
    # arrays positionally, so bootstrap(a_norm, 100, non_func) would treat
    # 100 and "mean" as extra data arrays. The TypeError would then come
    # from input handling rather than from the non-callable func.
    stat.bootstrap(a_norm, n_boot=100, func=non_func)


def test_percentiles():
    """Test percentiles with scalar, list and array arguments."""
    # On the integers 0..100, each percentile score equals the
    # percentile itself, so the request doubles as the expected answer.
    scalar_pct = 5
    assert_equal(stat.percentiles(a_range, scalar_pct), scalar_pct)

    list_pcts = [10, 20]
    assert_array_equal(stat.percentiles(a_range, list_pcts), list_pcts)

    array_pcts = rs.randint(0, 101, 5).astype(float)
    array_scores = stat.percentiles(a_range, array_pcts)
    assert_array_almost_equal(array_scores, array_pcts)


def test_percentiles_acc():
    """Test accuracy of calculation."""
    # First a basic case: the 50th percentile of 10, 20, 30 is 20.
    data = np.array([10, 20, 30])
    assert_equal(stat.percentiles(data, 50), 20)

    # Now test against scipy's scoreatpercentile. Use the ``spstats`` alias
    # imported at the top of the file for consistency: a bare
    # ``import scipy`` does not by itself guarantee ``sp.stats`` is loaded.
    pcts = rs.randint(0, 101, 10)
    scores = stat.percentiles(a_norm, pcts)
    for score, pct in zip(scores, pcts):
        assert_equal(score, spstats.scoreatpercentile(a_norm, pct))


def test_percentiles_axis():
    """Test use of the axis argument to percentiles."""
    data = rs.randn(10, 10)

    # The 50th percentile of the whole array is its median.
    assert_array_almost_equal(np.median(data), stat.percentiles(data, 50))

    # Along each axis in turn, the 50th percentile matches np.median.
    for ax in (0, 1):
        assert_array_almost_equal(np.median(data, axis=ax),
                                  stat.percentiles(data, 50, axis=ax))

    # Several percentiles give one result per requested percentile.
    multi = stat.percentiles(data, [50, 95], axis=0)
    assert_array_almost_equal(np.median(data, axis=0), multi[0])
    assert_equal(2, len(multi))


def test_ci():
    """Test that a 95% ci equals the 2.5th and 97.5th percentiles."""
    sample = rs.randn(100)
    wanted = stat.percentiles(sample, [2.5, 97.5])
    assert_array_equal(wanted, stat.ci(sample, 95))


def test_vector_reject():
    """Test that vector_reject(x, y) is orthogonal to y."""
    x = rs.randn(30)
    y = x + rs.randn(30) / 2
    rejection = stat.vector_reject(x, y)
    assert_almost_equal(np.dot(rejection, y), 0)


def test_add_constant():
    """Test that add_constant appends a trailing column of ones."""
    design = rs.randn(10, 5)
    expected = np.column_stack((design, np.ones(len(design))))
    assert_array_equal(expected, stat.add_constant(design))


def test_randomize_onesample():
    """Test performance of randomize_onesample."""
    # Data centered on zero should not reach significance.
    null_sample = rs.normal(0, 1, 50)
    _, p_null = stat.randomize_onesample(null_sample)
    nose.tools.assert_greater(p_null, 0.05)

    # Data centered far from zero should be significant.
    shifted = rs.normal(5, 1, 50)
    t_shifted, p_shifted = stat.randomize_onesample(shifted)
    nose.tools.assert_greater(0.05, p_shifted)

    # The t statistic should agree with scipy's one-sample t test.
    t_reference, _ = spstats.ttest_1samp(shifted, 0)
    nose.tools.assert_almost_equal(t_reference, t_shifted)


def test_randomize_onesample_range():
    """Make sure that output is bounded between 0 and 1."""
    for _ in range(100):
        # Draw location then scale, in that order, before the sample.
        loc = rs.randint(-10, 10)
        scale = rs.uniform(.5, 3)
        sample = rs.normal(loc, scale, 100)
        _, p = stat.randomize_onesample(sample, 100)
        nose.tools.assert_greater_equal(1, p)
        nose.tools.assert_greater_equal(p, 0)


def test_randomize_onesample_getdist():
    """Test that return_dist=True adds a third element to the output."""
    sample = rs.normal(0, 1, 20)
    result = stat.randomize_onesample(sample, return_dist=True)
    assert_equal(len(result), 3)


def test_randomize_onesample_iters():
    """Make sure we get the right number of samples."""
    a = rs.normal(0, 1, 20)
    # Without an explicit count, 10000 null samples are returned.
    _, _, samples = stat.randomize_onesample(a, return_dist=True)
    assert_equal(len(samples), 10000)
    # randint documents integer bounds, so pass 10000 rather than the
    # float 1e4 and do not rely on an implicit float-to-int conversion.
    for n in rs.randint(5, 10000, 5):
        _, _, samples = stat.randomize_onesample(a, n, return_dist=True)
        assert_equal(len(samples), n)


def test_randomize_onesample_seed():
    """Test that we can seed the random state and get the same distribution."""
    a = rs.normal(0, 1, 20)
    seed = 42
    t_a, p_a, samples_a = stat.randomize_onesample(a, 1000,
                                                   random_seed=seed,
                                                   return_dist=True)
    # The second value is the p value. The original code unpacked it into
    # t_b twice and discarded it.
    t_b, p_b, samples_b = stat.randomize_onesample(a, 1000,
                                                   random_seed=seed,
                                                   return_dist=True)
    assert_array_equal(samples_a, samples_b)
    # Identical data and null samples must yield identical statistics.
    assert_equal(t_a, t_b)
    assert_equal(p_a, p_b)


def test_randomize_onesample_multitest():
    """Test that randomizing over multiple tests works."""
    n_tests = 5
    data = rs.normal(0, 1, (20, n_tests))

    t, p = stat.randomize_onesample(data, 1000)
    for values in (t, p):
        assert_equal(len(values), n_tests)

    _, _, null_dist = stat.randomize_onesample(data, 1000, return_dist=True)
    assert_equal(null_dist.shape, (n_tests, 1000))


def test_randomize_onesample_correction():
    """Test that maximum based correction (seems to) work."""
    data = rs.normal(0, 1, (100, 10))
    t_raw, p_raw = stat.randomize_onesample(data, 1000, corrected=False)
    t_fwe, p_fwe = stat.randomize_onesample(data, 1000, corrected=True)
    # Correction leaves the t statistics alone ...
    assert_array_equal(t_raw, t_fwe)
    # ... and makes every p value strictly larger.
    npt.assert_array_less(p_raw, p_fwe)


def test_randomize_onesample_h0():
    """Test that we can supply a null hypothesis for the group mean."""
    sample = rs.normal(4, 1, 100)
    # Against a null mean of 0 the effect should be clearly significant ...
    _, p_vs_zero = stat.randomize_onesample(sample, 1000, h_0=0)
    assert p_vs_zero < 0.01
    # ... but not against the generating mean of 4.
    _, p_vs_four = stat.randomize_onesample(sample, 1000, h_0=4)
    assert p_vs_four > 0.01


def test_randomize_onesample_scalar():
    """Single values returned from randomize_onesample should be scalars."""
    # A 1D input gives scalars; a 2D input gives per-column results.
    for shape, want_scalar in [((40,), True), ((40, 3), False)]:
        t, p = stat.randomize_onesample(rs.randn(*shape))
        assert np.isscalar(t) == want_scalar
        assert np.isscalar(p) == want_scalar


def test_randomize_corrmat():
    """Test the correctness of the correlation matrix p values."""
    base = rs.randn(30)
    related = base + rs.rand(30) * 3
    unrelated = rs.randn(30)
    variables = [base, related, unrelated]

    p_mat, null_dist = stat.randomize_corrmat(variables, tail="upper",
                                              corrected=False,
                                              return_dist=True)
    # The variable built from base should be more significant than noise.
    nose.tools.assert_greater(p_mat[2, 0], p_mat[1, 0])

    # The p value should equal the upper-tail percentile of the observed
    # correlation within its null distribution.
    observed_r = np.corrcoef(variables)[2, 1]
    upper_pct = 100 - spstats.percentileofscore(null_dist[2, 1], observed_r)
    nose.tools.assert_almost_equal(p_mat[2, 1] * 100, upper_pct)

    # A strong negative relationship should be significant with the
    # default arguments.
    variables[1] = -base + rs.rand(30)
    default_p = stat.randomize_corrmat(variables)
    nose.tools.assert_greater(0.05, default_p[1, 0])


def test_randomize_corrmat_dist():
    """Test that the distribution looks right."""
    a = rs.randn(3, 20)
    # The last axis of the null distribution has one entry per iteration.
    for n_i in [5, 10]:
        p_mat, dist = stat.randomize_corrmat(a, n_iter=n_i, return_dist=True)
        assert_equal(n_i, dist.shape[-1])

    p_mat, dist = stat.randomize_corrmat(a, n_iter=10000, return_dist=True)

    # A variable correlates perfectly with itself. Compare with a tolerance
    # rather than exactly: computed correlation-matrix diagonals can differ
    # from 1 by floating point rounding (the np.corrcoef docs note this).
    diag_mean = dist[0, 0].mean()
    assert_almost_equal(diag_mean, 1)

    # Permuted off-diagonal correlations should be small on average.
    off_diag_mean = dist[0, 1].mean()
    nose.tools.assert_greater(0.05, off_diag_mean)


def test_randomize_corrmat_correction():
    """Test that FWE correction works."""
    data = rs.randn(3, 20)
    p_raw, p_fwe = [stat.randomize_corrmat(data, "upper", corrected)
                    for corrected in (False, True)]
    # Only the strictly-upper triangle holds distinct variable pairs.
    upper_tri = np.triu_indices(len(data), 1)
    npt.assert_array_less(p_raw[upper_tri], p_fwe[upper_tri])


def test_randimoize_corrmat_tails():
    """Test that the tail argument works."""
    x = rs.randn(30)
    y = x + rs.rand(30) * 8
    z = rs.randn(30)
    data = [x, y, z]

    # Same seed for every tail so the null distributions match.
    p_by_tail = {tail: stat.randomize_corrmat(data, tail, False,
                                              random_seed=0)
                 for tail in ("both", "upper", "lower")}
    upper_p = p_by_tail["upper"][0, 1]
    assert_equal(p_by_tail["both"][0, 1], upper_p * 2)
    assert_equal(p_by_tail["lower"][0, 1], 1 - upper_p)


def test_randomise_corrmat_seed():
    """Test that we can seed the corrmat randomization."""
    data = rs.randn(3, 20)
    null_dists = [
        stat.randomize_corrmat(data, random_seed=0, return_dist=True)[1]
        for _ in range(2)
    ]
    assert_array_equal(*null_dists)


@raises(ValueError)
def test_randomize_corrmat_tail_error():
    """Test that an unrecognized tail parameter raises ValueError."""
    bad_tail = "hello"
    stat.randomize_corrmat(rs.randn(3, 30), bad_tail)


def test_randomize_classifier():
    """Test basic functions of randomize_classifier."""
    n_samples = 100
    data = dict(X=spstats.norm(0, 1).rvs((n_samples, 12)),
                y=spstats.bernoulli(.5).rvs(n_samples),
                runs=np.repeat([0, 1], n_samples // 2))
    p_vals, perm_vals = stat.randomize_classifier(data, GaussianNB(),
                                                  return_dist=True)

    # p values must lie in [0, 1].
    nose.tools.assert_greater_equal(1, p_vals.max())
    nose.tools.assert_greater_equal(p_vals.min(), 0)

    # The permutation mean should be near chance (probabilistic check).
    nose.tools.assert_greater(.1, np.abs(perm_vals.mean() - 0.5))

    # The permutation distribution should look normal (probabilistic).
    _, p_normal = spstats.normaltest(perm_vals)
    nose.tools.assert_greater(p_normal, 0.001)


def test_randomize_classifier_dimension():
    """Test that we can have a time dimension and it's where we expect."""
    data = datasets_3d[0]
    n_perm = 30
    p_vals, perm_vals = stat.randomize_classifier(data, GaussianNB(), n_perm,
                                                  return_dist=True)
    n_time = len(data["X"])
    # One p value per leading slice of X ...
    nose.tools.assert_equal(len(p_vals), n_time)
    # ... and permutations along the first axis of the distribution.
    nose.tools.assert_equal(perm_vals.shape, (n_perm, n_time))


def test_randomize_classifier_seed():
    """Test that we can give a particular random seed to the permuter."""
    data = datasets[0]
    classifier = GaussianNB()
    # Both runs share one model instance and one seed.
    first, second = [stat.randomize_classifier(data, classifier, random_seed=1)
                     for _ in range(2)]
    assert_array_equal(first, second)


def test_randomize_classifier_number():
    """Test size of randomize_classifier vectors."""
    data = datasets[0]
    classifier = GaussianNB()
    iteration_counts = rs.randint(10, 250, 5)
    for count in iteration_counts:
        _, perm_dist = stat.randomize_classifier(data, classifier, count,
                                                 return_dist=True)
        nose.tools.assert_equal(len(perm_dist), count)


def test_transition_probabilities():
    """Test transition_probabilities on numeric and string schedules."""
    # Strict alternation always switches state. In 0, 0, 1, 1, state 0
    # goes to 0 or 1 equally often and state 1 only goes to itself.
    cases = [([0, 1, 0, 1], [[0, 1], [1, 0]]),
             ([0, 0, 1, 1], [[.5, .5], [0, 1]])]
    for schedule, probs in cases:
        npt.assert_array_equal(pd.DataFrame(probs),
                               stat.transition_probabilities(schedule))

    # String states label both axes, in the order ["bar", "foo"].
    labels = np.where(rs.rand(100) < .5, "foo", "bar")
    trans = stat.transition_probabilities(labels)
    npt.assert_equal(trans.columns.tolist(), ["bar", "foo"])
    npt.assert_equal(trans.columns, trans.index)


def test_gamma_hrf_fit_direct():
    """Very basic test of HRF fitting."""
    true_shape, true_scale = 6, .9
    hrf = stat.GammaHRF()
    x = np.arange(24)
    hrf.fit(x, spstats.gamma(true_shape, 0, true_scale).pdf(x))
    # A noiseless gamma curve should be recovered almost exactly.
    expectations = [(hrf.shape_, true_shape),
                    (hrf.scale_, true_scale),
                    (hrf.baseline_, 0)]
    for fitted, wanted in expectations:
        npt.assert_allclose(fitted, wanted, atol=1e-6)


def test_gamma_hrf_predict():
    """Test predictions of HRF model."""
    hrf = stat.GammaHRF()
    timepoints = np.arange(24)
    clean = spstats.gamma(6, 0, .9).pdf(timepoints)
    observed = clean + rs.normal(0, .01, 24)
    fitted = hrf.fit(timepoints, observed)
    npt.assert_allclose(observed, fitted.predict(timepoints), atol=.1)


def test_gamma_hrf_peak():
    """Test calculation of HRF peak time."""
    shape, scale = 6, .9
    hrf = stat.GammaHRF()
    x = np.arange(24)
    hrf.fit(x, spstats.gamma(shape, 0, scale).pdf(x))
    # The mode of a gamma(k, theta) density is (k - 1) * theta.
    npt.assert_allclose((shape - 1) * scale, hrf.peak_time_, atol=.2)


def test_gamma_r2():
    """Test that R2 is better with less noise."""
    hrf = stat.GammaHRF()
    x = np.arange(24)
    y = spstats.gamma(6, 0, .9).pdf(x)
    noisy_versions = [y + rs.normal(0, sd, 24) for sd in (.001, .05)]
    # Fit the noiseless curve, then score it against each noisy version.
    r2_low, r2_high = [hrf.fit(x, y).r2_score(x, y_obs)
                       for y_obs in noisy_versions]
    nose.tools.assert_less(r2_high, r2_low)


def test_gamma_hrf_bounds():
    """Test that we can supply bounds to the HRF optimization."""
    bounds = dict(shape=(3, 5.75), scale=(1, 1.5))
    hrf = stat.GammaHRF(shape=5.5, scale=1.1, bounds=bounds)
    x = np.arange(24)
    y = spstats.gamma(6, 0, .9).pdf(x)
    y += rs.normal(0, .01, 24)
    hrf.fit(x, y)
    # The generating parameters (shape 6, scale 0.9) lie outside the
    # bounds, so a bound-respecting fit may legitimately sit on a boundary.
    # Bounds are inclusive, so check every limit with <= instead of
    # mixing strict and non-strict comparisons.
    shape_lo, shape_hi = bounds["shape"]
    scale_lo, scale_hi = bounds["scale"]
    nose.tools.assert_less_equal(shape_lo, hrf.shape_)
    nose.tools.assert_less_equal(hrf.shape_, shape_hi)
    nose.tools.assert_less_equal(scale_lo, hrf.scale_)
    nose.tools.assert_less_equal(hrf.scale_, scale_hi)
