import numpy as np
from numpy.testing import assert_allclose
from menpofit.transform import (DifferentiableR2LogRRBF,
                                DifferentiableR2LogR2RBF)


# Shared fixtures for the tests in this file.
# ``centers`` holds the four corners of the square [-1, 1] x [-1, 1] and is
# handed to the RBF constructors.
centers = np.array([
    [-1.0, -1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [1.0, 1.0],
])
# ``points`` holds the four 2D locations passed to ``d_dl``.
points = np.array([
    [-0.4, -1.5],
    [-0.1, 1.1],
    [0.1, -2.0],
    [2.3, 0.3],
])


def test_rbf_r2logr2_d_dl():
    """Compare ``DifferentiableR2LogR2RBF.d_dl`` with stored reference values.

    The RBF is built from the module-level ``centers`` and ``d_dl`` is
    evaluated at the module-level ``points``.  The reference array has shape
    (4, 4, 2).
    """
    # NOTE(review): axes are presumably (n_points, n_centers, n_dims) --
    # confirm against the menpofit ``d_dl`` documentation.
    reference = np.array([
        [[0.60684441, -0.50570368],
         [3.46630038, -14.44291827],
         [-5.02037904, -1.79299252],
         [-8.69498819, -15.52676462]],

        [[4.77449532, 11.14048909],
         [1.44278831, 0.16030981],
         [-5.99792966, 11.45059299],
         [-2.63747189, 0.23977017]],

        [[3.94458353, -3.58598503],
         [7.31140879, -19.94020579],
         [-2.86798832, -3.18665369],
         [-5.91012409, -19.70041364]],

        [[23.31191446, 9.18348145],
         [22.65025903, -4.8046004],
         [5.76647684, 5.76647684],
         [4.62624468, -2.49105483]],
    ])

    rbf = DifferentiableR2LogR2RBF(centers)
    jacobian = rbf.d_dl(points)

    # Uses assert_allclose's default tolerances.
    assert_allclose(jacobian, reference)


def test_rbf_r2logr_d_dl():
    """Compare ``DifferentiableR2LogRRBF.d_dl`` with stored reference values.

    The RBF is built from the module-level ``centers`` and ``d_dl`` is
    evaluated at the module-level ``points``.  The reference array has shape
    (4, 4, 2).
    """
    # NOTE(review): axes are presumably (n_points, n_centers, n_dims) --
    # confirm against the menpofit ``d_dl`` documentation.
    # NOTE(review): these values look like half of the ones used in
    # test_rbf_r2logr2_d_dl, which would match r^2 log r^2 == 2 r^2 log r --
    # confirm before relying on that relationship.
    reference = np.array([
        [[0.30342221, -0.25285184],
         [1.73315019, -7.22145913],
         [-2.51018952, -0.89649626],
         [-4.34749409, -7.76338231]],

        [[2.38724766, 5.57024454],
         [0.72139416, 0.08015491],
         [-2.99896483, 5.72529649],
         [-1.31873594, 0.11988509]],

        [[1.97229177, -1.79299252],
         [3.6557044, -9.9701029],
         [-1.43399416, -1.59332685],
         [-2.95506205, -9.85020682]],

        [[11.65595723, 4.59174073],
         [11.32512951, -2.4023002],
         [2.88323842, 2.88323842],
         [2.31312234, -1.24552741]],
    ])

    rbf = DifferentiableR2LogRRBF(centers)
    jacobian = rbf.d_dl(points)

    # Uses assert_allclose's default tolerances.
    assert_allclose(jacobian, reference)
