"""
This package contains methods for evaluating the quality of predictive machine
learning models.
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.data_structures.sarray import SArray as _SArray

def max_error(targets, predictions):
    r"""
    Compute the maximum absolute deviation between two SArrays.

    Parameters
    ----------
    targets : SArray[float or int]
        An SArray of ground truth target values.

    predictions : SArray[float or int]
        The prediction that corresponds to each target value.
        This vector must have the same length as ``targets``.

    Returns
    -------
    out : float
        The maximum absolute deviation error between the two SArrays.

    Raises
    ------
    AssertionError
        If either input is not an SArray, or if the two inputs do not have
        the same size.

    See Also
    --------
    rmse

    Notes
    -----
    The maximum absolute deviation between two vectors, x and y, is defined as:

    .. math::

        \textrm{max error} = \max_{i \in 1,\ldots,N} \left| x_i - y_i \right|

    Examples
    --------
    >>> targets = graphlab.SArray([3.14, 0.1, 50, -2.5])
    >>> predictions = graphlab.SArray([3.1, 0.5, 50.3, -5])
    >>> graphlab.evaluation.max_error(targets, predictions)
    2.5
    """

    _mt._get_metric_tracker().track('evaluation.max_error')

    # Validate with explicit raises instead of `assert` statements, which are
    # removed when Python runs with -O. AssertionError is kept as the
    # exception type so callers that already catch it keep working.
    if not isinstance(targets, _SArray):
        raise AssertionError('Input targets must be an SArray')
    if not isinstance(predictions, _SArray):
        raise AssertionError('Input predictions must be an SArray')
    if targets.size() != predictions.size():
        raise AssertionError(
            'Input targets and predictions must have the same size')

    # Element-wise absolute difference, then the largest value.
    abs_deviation = (targets - predictions).apply(lambda x: abs(x))
    return abs_deviation.max()

def rmse(targets, predictions):
    r"""
    Compute the root mean squared error between two SArrays.

    Parameters
    ----------
    targets : SArray[float or int]
        An SArray of ground truth target values.

    predictions : SArray[float or int]
        The prediction that corresponds to each target value.
        This vector must have the same length as ``targets``.

    Returns
    -------
    out : float
        The RMSE between the two SArrays.

    Raises
    ------
    AssertionError
        If either input is not an SArray, or if the two inputs do not have
        the same size.

    See Also
    --------
    max_error

    Notes
    -----
    The root mean squared error between two vectors, x and y, is defined as:

    .. math::

        RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (x_i - y_i)^2}

    References
    ----------
    - `Wikipedia - root-mean-square deviation
      <http://en.wikipedia.org/wiki/Root-mean-square_deviation>`_

    Examples
    --------
    >>> targets = graphlab.SArray([3.14, 0.1, 50, -2.5])
    >>> predictions = graphlab.SArray([3.1, 0.5, 50.3, -5])
    >>> graphlab.evaluation.rmse(targets, predictions)
    1.2749117616525465
    """

    _mt._get_metric_tracker().track('evaluation.rmse')

    # Validate with explicit raises instead of `assert` statements, which are
    # removed when Python runs with -O. AssertionError is kept as the
    # exception type so callers that already catch it keep working.
    if not isinstance(targets, _SArray):
        raise AssertionError('Input targets must be an SArray')
    if not isinstance(predictions, _SArray):
        raise AssertionError('Input predictions must be an SArray')
    if targets.size() != predictions.size():
        raise AssertionError(
            'Input targets and predictions must have the same size')

    # The computation is delegated to the "evaluation_rmse" toolkit call; the
    # result is read from the "rmse" key of its response.
    opts = {'targets': targets,
            'predictions': predictions}
    response = _graphlab.toolkits.main.run("evaluation_rmse", opts)
    return response["rmse"]

def confusion_matrix(targets, predictions, threshold=0.5):
    r"""
    Compute the confusion matrix for classification predictions.
    The matrix contains the following counts:

    - true positive: target is 1 and prediction is greater than threshold
    - false positive: target is 0 and prediction is greater than threshold
    - true negative: target is 0 and prediction is less than or equal to threshold
    - false negative: target is 1 and prediction is less than or equal to threshold

    Parameters
    ----------
    targets : SArray[int]
        Ground truth class labels. Must contain only 0s and 1s.

    predictions : SArray[float]
        The prediction that corresponds to each target value.
        This vector must have the same length as ``targets``.

    threshold : float, optional
        The classification threshold for prediction. Predictions greater than
        this value are counted as class 1. Predictions less than or equal to
        this value are counted as class 0.

    Returns
    -------
    out : dict
        A dictionary containing counts for 'true_positive', 'false_positive',
        'true_negative', 'false_negative'.

    Raises
    ------
    AssertionError
        If either input is not an SArray, or if the two inputs do not have
        the same size.

    See Also
    --------
    accuracy

    Examples
    --------
    >>> targets = graphlab.SArray([0, 1, 1, 0])
    >>> predictions = graphlab.SArray([0.1, 0.35, 0.7, 0.99])
    >>> graphlab.evaluation.confusion_matrix(targets, predictions, threshold=0.7)
    {'false_negative': 2,
     'false_positive': 1,
     'true_negative': 1,
     'true_positive': 0}
    """

    _mt._get_metric_tracker().track('evaluation.confusion_matrix')

    # Validate with explicit raises instead of `assert` statements, which are
    # removed when Python runs with -O. AssertionError is kept as the
    # exception type so callers that already catch it keep working.
    if not isinstance(targets, _SArray):
        raise AssertionError('Input targets must be an SArray')
    if not isinstance(predictions, _SArray):
        raise AssertionError('Input predictions must be an SArray')
    if targets.size() != predictions.size():
        raise AssertionError(
            'Input targets and predictions must have the same size')

    # The counting is delegated to the "evaluation_confusion_matrix" toolkit
    # call; its response is returned to the caller unchanged.
    opts = {'targets': targets,
            'predictions': predictions,
            'threshold': threshold}
    response = _graphlab.toolkits.main.run("evaluation_confusion_matrix", opts)
    return response

def accuracy(targets, predictions, threshold=0.5):
    r"""
    Compute the fraction of predictions that are classified correctly.

    The counts come from ``confusion_matrix``: the number of true positives
    plus the number of true negatives is divided by the size of ``targets``.
    A prediction exactly equal to ``threshold`` is counted as class 0.

    Parameters
    ----------
    targets : SArray[int]
        Ground truth class labels. Must contain only 0s and 1s.

    predictions : SArray[float]
        The prediction that corresponds to each target value.
        This vector must have the same length as ``targets``.

    threshold : float, optional
        The classification threshold for prediction. Predictions greater than
        this value are counted as class 1. Predictions less than or equal to
        this value are counted as class 0.

    Returns
    -------
    out : float
        The number of correct classifications divided by the total number of
        data points.

    See Also
    --------
    confusion_matrix

    Examples
    --------
    >>> targets = graphlab.SArray([0, 1, 1, 0])
    >>> predictions = graphlab.SArray([0.1, 0.35, 0.7, 0.99])
    >>> graphlab.evaluation.accuracy(targets, predictions, threshold=0.7)
    0.25
    """

    _mt._get_metric_tracker().track('evaluation.accuracy')

    counts = confusion_matrix(targets, predictions, threshold)
    num_correct = counts['true_positive'] + counts['true_negative']

    # Convert to float before dividing so the ratio is not truncated by
    # integer division.
    return float(num_correct) / targets.size()
