"""
This package provides the ability to fit models using `Vowpal Wabbit
<https://github.com/JohnLangford/vowpal_wabbit>`_, an open source project
meant for large-scale online learning. This includes implementations of a
variety of models -- linear and logistic regression and others.
Importantly, it is straightforward to flexibly use a large number of
unique features, e.g. words in a document, through `hashing
<http://github.com/JohnLangford/vowpal_wabbit/wiki/Feature-Hashing-and-Extraction>`_.

.. sourcecode:: python

    >>> from graphlab import vowpal_wabbit as vw

    # Given an SFrame sf, create a linear model for predicting the 'rating':
    >>> m = vw.create(sf, 'rating')

    # If a column contains text, each space-separated word is used as a
    # unique feature. It is often useful to also include bigrams as
    # features, which can be done with the ``bigram`` argument:
    >>> m = vw.create(sf, 'rating', bigram=True)

    # To add quadratic terms between 'user' and 'movie' columns:
    >>> m = vw.create(sf, 'rating', quadratic=[('user', 'movie')])


While there are a variety of Python wrappers for Vowpal Wabbit, this one
is directly integrated with our disk-backed :py:class:`~graphlab.SFrame`;
this can make it easier to interactively create new features for Vowpal Wabbit models.
"""

import graphlab as _graphlab
import graphlab.connect as _mt
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.data_structures.sarray import SArray as _SArray
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS
import time

class VWModel(Model):
    """
    Wrapper around a Vowpal Wabbit model.

    Instances are returned by :func:`~graphlab.vowpal_wabbit.create`. The
    object holds a model proxy (``__proxy__``) and forwards every query to the
    toolkit through ``graphlab.toolkits.main.run``.
    """

    def __init__(self, model_proxy):
        self.__proxy__ = model_proxy

    def _get_wrapper(self):
        """Return a callable that wraps a raw model proxy in a VWModel."""
        def model_wrapper(model_proxy):
            return VWModel(model_proxy)
        return model_wrapper

    def _get_model_description(self):
        """Fetch the mapping of all model fields from the toolkit."""
        opts = {'model': self.__proxy__}
        return _graphlab.toolkits.main.run("vw_get_model_description", opts)

    def list_fields(self):
        """
        List the names of the fields stored in the model.

        Returns
        -------
        out : list
            Field names that can be passed to :meth:`get`.
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.list_fields')
        return self._get_model_description().keys()

    def get(self, field):
        """
        Return the value of a single model field.

        Parameters
        ----------
        field : str
            Name of the field; see :meth:`list_fields` for valid names.

        Returns
        -------
        out
            The value stored under ``field`` in the model description.
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.get_type')
        return self._get_model_description()[field]

    def _set(self, field, value):
        """
        Store ``value`` under ``field`` in the model description.

        Returns a new VWModel wrapping the model proxy returned by the
        toolkit. Callers in this module use the returned object rather than
        ``self``.
        """
        opts = {'model': self.__proxy__,
                'key': field,
                'value': value}
        response = _graphlab.toolkits.main.run("vw_set_model_description", opts)
        return VWModel(response['model'])

    def __str__(self):
        """Return a human-readable summary of settings and training stats."""
        fields = self._get_model_description()

        # (label, field name) pairs; labels are padded so the values line up.
        settings = [
            ("   target column:  ", 'target_column'),
            ("   loss function:  ", 'loss_function'),
            ("   step size:      ", 'step_size'),
            ("   L1 penalty:     ", 'l1_penalty'),
            ("   L2 penalty:     ", 'l2_penalty'),
            ("   verbose:        ", 'verbose'),
            ("   bits:           ", 'num_bits'),
            ("   max iterations: ", 'max_iterations'),
        ]

        lines = ["Vowpal Wabbit Model: ", ""]
        lines.extend(label + str(fields[key]) for label, key in settings)
        lines.append("")

        # Training statistics are only present when create() recorded them.
        for k in ['train_rmse', 'train_accuracy', 'train_stats_elapsed_time']:
            if k in fields:
                lines.append("   {k}: {v}".format(k=k, v=fields[k]))

        if 'confusion_table' in fields:
            # Reuse the description already fetched instead of another get().
            cf = fields['confusion_table']
            lines.append("     # true positives:  {a}".format(a=cf['true_positive']))
            lines.append("     # false negatives: {a}".format(a=cf['false_negative']))
            lines.append("     # true negatives:  {a}".format(a=cf['true_negative']))
            lines.append("     # false positives: {a}".format(a=cf['false_positive']))

        return "\n".join(lines)

    def __repr__(self):
        return self.__str__()

    def summary(self):
        """Print the model summary produced by ``str(model)``."""
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.summary')
        print(self.__str__())

    def training_stats(self):
        """
        Get information about model creation, e.g. time elapsed during model
        fitting, data loading, and more.

        Returns
        -------
        out : dict
            Statistics about model training, e.g. runtime.

        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.training_stats')
        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run("vw_get_training_stats", opts)
        return response

    def predict(self, dataset):
        """
        Use the trained :class:`~graphlab.vowpal_wabbit.VWModel` to make
        predictions about the target column that was provided during
        :func:`~graphlab.vowpal_wabbit.create`.

        Parameters
        ----------
        dataset : SFrame
            A data set that has the same columns that were used during training.
            If the target column exists in ``dataset`` it will be ignored while making
            predictions.

        Returns
        -------
        out : SArray
            Predicted target value for each example (i.e. row) in the dataset.
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.predict')

        opts = {'model': self.__proxy__,
                'data': dataset}
        response = _graphlab.toolkits.main.run("vw_predict", opts)

        # Convert predictions to an SArray
        return _SArray(None, _proxy=response['predictions'])

    def evaluate(self, dataset):
        r"""
        Evaluate the model by making predictions of target values and comparing
        these to actual values. Currently, this method only supports vw models
        trained with ``squared`` or ``logistic`` loss.

        If the model is trained with ``squared`` loss, the evaluation metrics are
        root-mean-squared error (RMSE) and the maximum absolute error between
        the actual and predicted values.

        Let :math:`y` and :math:`\hat{y}` denote vectors of length :math:`N`
        (number of examples) with actual and predicted values. The RMSE is
        defined as:

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2}.

        The max-error is defined as

        .. math::

            max-error = \max_{i=1}^N |\widehat{y}_i - y_i| .

        If the model is trained with ``logistic`` loss, then the model is evaluated
        as a classifier with a decision threshold of 0. The metrics are classification
        accuracy and confusion table.  Classification accuracy is the fraction of
        examples whose predicted and actual classes match. The confusion table contains
        the cross-tabulation of actual and predicted classes for the target variable.

        Parameters
        ----------
        dataset : SFrame
            Dataset in the same format as the SFrame used to train the model.
            The columns names and types must be the same as that used in
            training, including the target column.

        Returns
        -------
        out : dict
            Dictionary of evaluation results. For ``squared`` loss, the dictionary
            keys are *rmse* and *max_error*.  For ``logistic`` loss, the dictionary
            keys are *accuracy* and *confusion_table*.

        Raises
        ------
        ToolkitError
            If ``dataset`` lacks the target column, or the model was trained
            with a loss other than ``squared`` or ``logistic``.

        References
        ----------
        - `Wikipedia - confusion matrix
          <http://en.wikipedia.org/wiki/Confusion_matrix>`_
        - `Wikipedia - root-mean-square deviation
          <http://en.wikipedia.org/wiki/Root-mean-square_deviation>`_
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.evaluate')

        fields = self._get_model_description()
        target_column = fields['target_column']
        loss = fields['loss_function']

        # NOTE(review): ToolkitError was referenced here without ever being
        # imported (a NameError at runtime). It is looked up lazily on the
        # module that provides ``run``; confirm it is defined there.
        if target_column not in dataset.column_names():
            raise _graphlab.toolkits.main.ToolkitError(
                "Input dataset must contain a target column for "
                "evaluation of prediction quality.")

        # Reject unsupported losses before paying for a full prediction pass.
        if loss not in ('squared', 'logistic'):
            raise _graphlab.toolkits.main.ToolkitError(
                "VW evaluate currently only supports models trained with "
                "squared or logistic loss.")

        targets = dataset[target_column]
        predictions = self.predict(dataset)

        if loss == 'squared':
            rmse = _graphlab.evaluation.rmse(targets, predictions)
            max_error = _graphlab.evaluation.max_error(targets, predictions)
            return {'rmse': rmse,
                    'max_error': max_error}

        # Logistic loss: map -1 targets to 0 so labels are 0/1, then score the
        # predictions with a decision threshold of 0.
        targets = targets.astype(int).apply(lambda x: 0 if x == -1 else x)
        accuracy = _graphlab.evaluation.accuracy(targets, predictions, threshold=0)
        confusion_table = _graphlab.evaluation.confusion_matrix(targets, predictions, threshold=0)
        return {'accuracy': accuracy,
                'confusion_table': confusion_table}


def create(dataset, target_column,
           loss_function='squared', quadratic=None,
           l1_penalty=0.0, l2_penalty=0.0,
           bigram=False,
           step_size=0.5, num_bits=18, verbose=False, max_iterations=1,
           command_line_args=''):
    """
    Learn a large linear model using Vowpal Wabbit.

    Parameters
    ----------
    dataset: SFrame or pandas.DataFrame
        A data set. Due to the way Vowpal Wabbit creates features from each entry,
        ':' and '|' characters are not allowed in any columns containing strings.
        Each row of the dataset is translated into a string and passed to Vowpal
        Wabbit. Currently, the upper bound on the size of the string is 1MB.
        Based on the type of the SArray column, the values are passed in the
        following ways.

        - *integer* or *float*  : the value is passed directly to VW.
        - *str*                 : the name of the column is used as the namespace,
                                  followed by the entire string.
        - *dict*                : the name of the column is used as the namespace,
                                  and each key-value pair is a feature. The keys of
                                  the dictionary must be string or numeric and the
                                  values must be numeric (integer or float).
        - *array*               : the name of the column is used as the namespace,
                                  the index of the array element is used as the name
                                  of the feature, and only numeric elements in the
                                  array are passed onto VW.
        - *list* (recursive type): the name of the column is used as the namespace,
                                   the index of the list element is used as the name
                                   of the feature, and currently only numeric elements
                                   (integer or float) are passed onto VW.

        Here are more details on `VW input format
        <https://github.com/JohnLangford/vowpal_wabbit/wiki/Input-format>`_.


    target_column: string
        The name of the column in ``dataset`` that is the prediction target.
        This column must have a numeric type.

    quadratic: list of pairs, optional
        This will add `interaction terms <http://en.wikipedia.org/wiki/Interaction_(statistics)>`_
        to a linear model between a pair of columns.
        Quadratic terms add a parameter in the model for the product of two features,
        i.e. if we include an interaction between x_1 and x_2, we can add a parameter b_3

            y_i =  a + b_1 * x_i1 + b_2 * x_i2 + b_3 * x_i1 * x_i2

        Multiple quadratic terms can be added by including multiple pairs, e.g.
        ``quadratic = [('a', 'b'), ('b', 'c')]``
        would add interaction terms between columns named 'a' and 'b' as well as
        terms for interactions between 'b' and 'c'. Defaults to no quadratic
        terms.

        Including ':' as one of the items in the pairs is a shortcut for adding
        quadratic terms for all pairs of features.

        Due to Vowpal Wabbit's implementation, quadratic terms are determined by
        the first letter of the column name.

    loss_function: {squared|hinge|logistic|quantile}
        This defines the `loss function <http://en.wikipedia.org/wiki/Loss_function>`_
        used during optimization. Typical choices:

        - real-valued target: `squared error loss <http://en.wikipedia.org/wiki/Mean_squared_error>`_.

        - binary target (-1/1): `logistic <http://en.wikipedia.org/wiki/Logistic_regression>`_. The target column must only contain -1 or 1.

        The `hinge loss <http://en.wikipedia.org/wiki/Hinge_loss>`_ is also used
        for classification, while `quantile loss <http://en.wikipedia.org/wiki/Quantile_regression>`_
        can be good when one aims to predict quantities other than the mean.

    l1_penalty: float, optional
        L1 regularization strength: how strongly parameters are pushed to be
        exactly zero.

    l2_penalty: float, optional
        L2 regularization strength: how strongly parameters are kept near zero.

        Specifically it adds a penalty of .5 * lambda * |w|_2^2 to the weight vector w,
        where lambda is the provided regularization value.

    bigram: bool, optional
        Add bigram features. For columns containing the text "my name is bob"
        this will add bigram features for "my name", "name is", "is bob".

    step_size: float, optional
        Set the learning rate for online learning.

    num_bits: int, optional
        Number of bits of the hashed feature space (Vowpal Wabbit's ``-b``
        option).

    verbose: bool, optional
        Print first 10 rows as they are seen by VowpalWabbit.
        This is useful for debugging.

    max_iterations: int, optional
        Number of passes to take over the data set.

    command_line_args: string, optional
        Additional arguments to pass to Vowpal Wabbit, just as one would use when
        using VW via the command line.

    Returns
    -------
    out : VWModel
        The trained model. Its training statistics include ``train_rmse``
        (non-logistic losses) or ``train_accuracy`` and ``confusion_table``
        (logistic loss), plus ``train_stats_elapsed_time``.

    Raises
    ------
    TypeError
        If ``dataset`` is not an SFrame or pandas DataFrame, or if
        ``loss_function='logistic'`` and the target column contains values
        other than 1 and -1.
    ValueError
        If ``target_column`` is not a column of ``dataset``.

    Examples
    --------

    .. sourcecode:: python

        # Given an SFrame sf, create a linear model for predicting the 'rating':
        >>> m = graphlab.vowpal_wabbit.create(sf, 'rating')

        # To add quadratic terms between 'user' and 'movie' columns:
        >>> m = graphlab.vowpal_wabbit.create(sf, 'rating', quadratic=[('user', 'movie')])

    Notes
    -----
    Other desired command line arguments can be provided manually through the
    command_line_args keyword argument. See VW documentation for more details:

    http://github.com/JohnLangford/vowpal_wabbit/wiki/Command-line-arguments

    The current implementation of this Python API does not support importance
    weighted learning, and several other Vowpal Wabbit features are not yet supported.
    """

    _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.create')

    if quadratic is None:
        quadratic = []

    if not (isinstance(dataset, _SFrame) or _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame)):
        raise TypeError('dataset input must be a pandas.DataFrame or SFrame')

    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    # An explicit check (not `assert`) so validation survives `python -O`.
    if target_column not in dataset.column_names():
        raise ValueError("Target column '%s' not found in dataset." % (target_column,))

    y = dataset[target_column]

    # Validate logistic targets up front rather than after a full training pass.
    if loss_function == 'logistic':
        is_one_or_neg_one = y.apply(lambda x: x == 1 or x == -1)
        if not all(is_one_or_neg_one):
            raise TypeError('When using `logistic` as a loss function, the target column must contain only 1\'s and -1\'s.')

    # VW identifies a namespace by its first letter, so each pair becomes
    # "-q <first letter of a><first letter of b>".
    quadratic_command = ''.join(' -q ' + feature_a[0] + feature_b[0]
                                for (feature_a, feature_b) in quadratic)

    opts = {'verbose': verbose,
            'target_column': target_column,
            'loss_function': loss_function,
            'quadratic': quadratic_command,
            'step_size': step_size,
            'l1_penalty': l1_penalty,
            'l2_penalty': l2_penalty,
            'num_bits': num_bits,
            'max_iterations': max_iterations,
            'bigram': bigram,
            'extra_command_line_args': command_line_args}

    # Initialize the model with basic parameters
    response = _graphlab.toolkits.main.run("vw_init", opts)
    m = VWModel(response['model'])

    # Train the model on the given data set and retrieve predictions
    opts = {'model': m.__proxy__,
            'data': dataset}
    response = _graphlab.toolkits.main.run("vw_train", opts)
    m = VWModel(response['model'])

    yhat = _SArray(None, _proxy=response['predictions'])

    # Evaluate the model on its training data and record the statistics.
    start_time = time.time()
    if loss_function == 'logistic':
        # Map -1/+1 targets to 0/1 labels.
        y01 = y.apply(lambda x: int(x * .5 + .5))
        # threshold=0 matches the decision threshold used by VWModel.evaluate.
        m = m._set('train_accuracy',
                   _graphlab.evaluation.accuracy(y01, yhat, threshold=0))
        m = m._set('confusion_table',
                   _graphlab.evaluation.confusion_matrix(y01, yhat, threshold=0))
    else:
        m = m._set('train_rmse', _graphlab.evaluation.rmse(y, yhat))
    m = m._set('train_stats_elapsed_time', time.time() - start_time)

    return m


