"""
This package provides the ability to fit models using `Vowpal Wabbit
<https://github.com/JohnLangford/vowpal_wabbit>`_, an open source project
meant for large-scale online learning. This includes implementations of a
variety of models -- linear and logistic regression and others.
Importantly, it is straightforward to flexibly use a large number of
unique features, e.g. words in a document, through `hashing
<https://github.com/JohnLangford/vowpal_wabbit/wiki/Feature-Hashing-and-Extraction>`_.

.. sourcecode:: python

    >>> from graphlab import vowpal_wabbit as vw

    # Given an SFrame sf, create a linear model for predicting the 'rating':
    >>> m = vw.create(sf, 'rating')

    # If a column contains text, each space-separated word is used as a
    # unique feature. It is often useful to also include bigrams as
    # features. This can be done easily with the ``bigram`` argument:
    >>> m = vw.create(sf, 'rating', bigram=True)

    # To add quadratic terms between 'user' and 'movie' columns:
    >>> m = vw.create(sf, 'rating', quadratic=[('user', 'movie')])


While there are a variety of Python wrappers for Vowpal Wabbit, this one
is directly integrated with our disk-backed :py:class:`~graphlab.SFrame`;
this can make it easier to interactively create new features for Vowpal Wabbit models.
"""

import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.data_structures.sarray import SArray as _SArray
from pandas import DataFrame as _DataFrame
import uuid


class VWModel(Model):
    """
    Wrapper around a trained Vowpal Wabbit model.

    The model state lives in ``__proxy__``, a dict-like object. The methods
    of this class read the keys ``'target'``, ``'model_file'``,
    ``'command_line_args'``, ``'response'`` and ``'training_stats'`` from it.
    """

    def __init__(self, model_proxy):
        # Dict-like model state; see the class docstring for the keys used.
        self.__proxy__ = model_proxy

    def _get_wrapper(self):
        """
        Return a callable that wraps a model proxy in a new ``VWModel``.
        """
        def model_wrapper(model_proxy):
            return VWModel(model_proxy)
        return model_wrapper

    def list_fields(self):
        """
        Return the names of the fields stored in the model proxy.
        """
        return self.__proxy__.keys()

    def get(self, field):
        """
        Return the value stored in the model proxy under ``field``.
        """
        return self.__proxy__[field]

    def __str__(self):
        return self.__class__.__name__

    def __repr__(self):
        """
        Build a multi-line summary: target, elapsed time, model file and
        whichever training statistics are available.
        """
        proxy = self.__proxy__
        lines = [
            "Vowpal Wabbit Model: \n",
            "   target: " + proxy['target'] + "\n",
            "   elapsed (s): " + str(proxy['response']['elapsed_time']) + "\n",
            "   model_file: " + proxy['model_file'] + "\n",
        ]

        stats = self.get('training_stats')
        if stats['rmse'] is not None:
            lines.append("\n   training rmse: {a}".format(a=stats['rmse']))
        if stats['accuracy'] is not None:
            lines.append("\n   training accuracy: {a}".format(a=stats['accuracy']))
        if 'confusion_table' in stats:
            cf = stats['confusion_table']
            lines.append("\n     # true positives:  {a}".format(a=cf['true_positive']))
            lines.append("\n     # false negatives: {a}".format(a=cf['false_negative']))
            lines.append("\n     # true negatives:  {a}".format(a=cf['true_negative']))
            lines.append("\n     # false positives: {a}".format(a=cf['false_positive']))

        return "".join(lines)

    def summary(self):
        """
        Return the same text as ``repr(model)``.
        """
        return self.__repr__()

    def training_stats(self):
        """
        Return the ``'training_stats'`` entry of the model proxy.
        """
        return self.__proxy__['training_stats']

    def predict(self, dataset):
        """
        Use the trained :class:`~graphlab.vowpal_wabbit.VWModel` to make
        predictions about the target column that was provided during
        :func:`~graphlab.vowpal_wabbit.create`.

        Parameters
        ----------
        dataset: SFrame or pandas.DataFrame
            A data set that has the same columns that were used during training.
            If the target column exists in ``dataset`` it will be ignored while making
            predictions. A pandas.DataFrame is converted to an SFrame first.

        Returns
        -------
        out: SArray
            The predictions, built from the ``'predictions'`` entry of the
            toolkit response.
        """

        # Accept the same input types as training does.
        if isinstance(dataset, _DataFrame):
            dataset = _SFrame(dataset)

        # Location of the regressor saved during training.
        model_file = self.__proxy__['model_file']

        # If the original target column is present, we need its column
        # number in the test dataset; -1 signals that it is absent.
        target = self.__proxy__['target']
        column_names = dataset.column_names()
        target_col_ix = -1
        if target in column_names:
            target_col_ix = column_names.index(target)

        # Strings are immutable, so extending this local copy leaves the
        # arguments stored in the proxy untouched.
        command_line_args = self.__proxy__['command_line_args']

        # Set VW to be in test mode
        command_line_args += ' -t'

        # Load the model from file unless a regressor is already specified
        if 'initial_regressor' not in command_line_args:
            command_line_args += ' --initial_regressor ' + model_file

        opts = {'data': dataset,
                'target_col_ix': target_col_ix,
                'verbose': 0,
                'test_mode': 1,
                'command_line_args': command_line_args}

        # Compute predictions
        response = _graphlab.toolkits.main.run("vw", opts)

        # Convert predictions to an SArray
        return _SArray(None, _proxy=response['predictions'])


def create(dataset, target,
           quadratic=None, bigram=False, regularization=None, loss_function='squared',
           learning_rate=None, verbose=False, num_passes=1,
           command_line_args=''):
    """
    Learn a large linear model using Vowpal Wabbit.

    Parameters
    ----------
    dataset: SFrame or pandas.DataFrame
        A data set. Due to the way Vowpal Wabbit creates features from each entry,
        ':' and '|' characters are not allowed in any columns containing strings.
        A pandas.DataFrame is converted to an SFrame first.

    target: string
        The name of the column in ``dataset`` that is the prediction target

    quadratic: list of pairs, optional
        This will add `interaction terms <http://en.wikipedia.org/wiki/Interaction_(statistics)>`_
        to a linear model between a pair of columns.
        Quadratic terms add a parameter in the model for the product of two features,
        i.e. if we include an interaction between x_1 and x_2, we can add a parameter b_3

            y_i =  a + b_1 * x_i1 + b_2 * x_i2 + b_3 * x_i1 * x_i2

        Multiple quadratic terms can be added by including multiple pairs, e.g.
        ``quadratic = [('a', 'b'), ('b', 'c')]``
        would add interaction terms between columns names 'a' and 'b' as well as
        terms for interactions between 'b' and 'c'.

        Including ':' as one of the items in the pairs is a shortcut for adding
        quadratic terms for all pairs of features.

        Due to Vowpal Wabbit's implementation, quadratic terms are determined by
        the first letter of the column name.

        ``None`` (the default) or an empty list adds no quadratic terms.

    loss_function: {squared|hinge|logistic|quantile}
        This defines the `loss function <http://en.wikipedia.org/wiki/Loss_function>`_
        used during optimization. Typical choices:

        - real-valued target: `squared error loss <http://en.wikipedia.org/wiki/Mean_squared_error>`_.

        - binary target (-1/1): `logistic <http://en.wikipedia.org/wiki/Logistic_regression>`_. The target column must only contain -1 or 1.

        The `hinge loss <http://en.wikipedia.org/wiki/Hinge_loss>`_ is also used
        for classification, while `quantile loss <http://en.wikipedia.org/wiki/Quantile_regression>`_
        can be good when one aims to predict quantities other than the mean.

        Any value other than 'squared' is passed to Vowpal Wabbit via
        ``--loss_function``.

    regularization: float, optional
        This defines how strongly you want to keep parameters near zero.

        Specifically it adds a penalty of .5 * lambda * |w|_2^2 to the weight vector w,
        where lambda is the provided regularization value.

    bigram: bool, optional
        Add bigram features. For columns containing the text "my name is bob"
        this will add bigram features for "my name", "name is", "is bob".

    learning_rate: float, optional
        Set the learning rate for online learning.

    verbose: bool, optional
        Print first 10 rows as they are seen by VowpalWabbit.
        This is useful for debugging.

    num_passes: int, optional
        Number of passes to take over the data set. Must be at least 1.

    command_line_args: string, optional
        Additional arguments to pass to Vowpal Wabbit, just as one would use when
        using VW via the command line. They are appended after the arguments
        generated from the other keyword arguments and are stored with the
        returned model.

    Returns
    -------
    out: VWModel
        The trained model. Its proxy also holds training statistics computed
        by predicting on ``dataset``: accuracy and a confusion table for
        logistic loss, RMSE otherwise.

    Raises
    ------
    TypeError
        If ``dataset`` is neither a pandas.DataFrame nor an SFrame.
    ValueError
        If ``target`` is not a column of ``dataset``, if ``num_passes`` is
        less than 1, or if logistic loss is requested and the target column
        does not contain exactly the values -1 and 1.

    Examples
    --------

    .. sourcecode:: python

        # Given an SFrame sf, create a linear model for predicting the 'rating':
        >>> m = graphlab.vw.create(sf, 'rating')

        # To add quadratic terms between 'user' and 'movie' columns:
        >>> m = graphlab.vw.create(sf, 'rating', quadratic=[('user', 'movie')])

    Notes
    -----
    Other desired command line arguments can be provided manually through the
    command_line_args keyword argument. See VW documentation for more details:

    http://github.com/JohnLangford/vowpal_wabbit/wiki/Command-line-arguments

    The current implementation of this Python API does not support importance
    weighted learning, and several other Vowpal Wabbit features are not yet supported.
    """

    if not isinstance(dataset, (_DataFrame, _SFrame)):
        raise TypeError('dataset input must be a pandas.DataFrame or SFrame')

    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    # Validate with explicit exceptions: asserts vanish under ``python -O``.
    column_names = dataset.column_names()
    if target not in column_names:
        raise ValueError(
            "target column '{0}' not found in dataset".format(target))

    if num_passes < 1:
        raise ValueError('num_passes must be at least 1')

    if loss_function == 'logistic' and set(dataset[target]) != set([-1, 1]):
        raise ValueError('logistic loss requires the target column to '
                         'contain exactly the values -1 and 1')

    if quadratic is None:
        quadratic = []

    # Save the model to a temporary file
    model_file = str(uuid.uuid4()) + '.vwmodel'

    # If verbose, print first lines in vowpal wabbit's text format.
    # Currently defaults to 10 lines.
    num_lines_to_print = 10 if verbose else 0

    opts = {'data': dataset,
            'target_col_ix': column_names.index(target),
            'test_mode': 0,
            'verbose': num_lines_to_print}

    vw_args = ' --hash all -d unused_file --final_regressor ' + model_file

    # Forward every non-default loss, not just 'logistic'; otherwise 'hinge'
    # and 'quantile' would silently train with the default loss.
    if loss_function not in (None, 'squared'):
        vw_args += ' --loss_function ' + str(loss_function)

    if learning_rate is not None:
        vw_args += ' -l ' + str(learning_rate)

    if regularization is not None:
        vw_args += ' --l2 ' + str(regularization)

    for (feature_a, feature_b) in quadratic:
        # VW uses first letter to describe namespace
        vw_args += ' -q ' + feature_a[0] + feature_b[0]

    if bigram:
        vw_args += ' --ngram 2 '

    # Honor the caller's extra arguments instead of discarding them.
    if command_line_args:
        vw_args += ' ' + command_line_args.strip()

    opts['command_line_args'] = vw_args

    # Column header for the progress table. Written as a single-argument
    # print call so it behaves the same on Python 2 and Python 3.
    header = (
        "\n "
        "         average    since         example     example  current        current  current \n "
        "         loss       last          counter      weight    label        predict features "
        "    "
    )
    print(header)

    response = None
    for pass_ix in range(num_passes):
        response = _graphlab.toolkits.main.run("vw", opts)

        # Later passes resume from the regressor written by the first pass.
        # Add the flag exactly once rather than after every pass.
        if pass_ix == 0 and num_passes > 1:
            opts['command_line_args'] += ' --initial_regressor ' + model_file

    # Store the base arguments (without the resume flag) with the model.
    opts.update({'response': response,
                 'target': target,
                 'command_line_args': vw_args,
                 'model_file': model_file})

    m = VWModel(opts)

    # Training statistics are computed by predicting on the training data.
    y = dataset[target]
    yhat = m.predict(dataset)
    stats = {'rmse': None, 'accuracy': None}
    if loss_function == 'logistic':
        stats['accuracy'] = _accuracy(y, yhat)
        stats['confusion_table'] = _confusion_table(y, yhat)
    else:
        stats['rmse'] = _rmse(y, yhat)

    m.__proxy__['rmse'] = stats['rmse']
    m.__proxy__['accuracy'] = stats['accuracy']
    m.__proxy__['training_stats'] = stats
    m.__proxy__['__repr__'] = m.__repr__()

    return m

def _rmse(y, yhat):
    """
    Return the root mean squared error between ``y`` and ``yhat``.

    The elementwise difference ``y - yhat`` is squared value by value via
    ``apply``, averaged with ``mean`` and then square-rooted.
    """
    residuals = y - yhat
    squared_residuals = residuals.apply(lambda r: r**2)
    mean_squared_error = squared_residuals.mean()
    return mean_squared_error ** .5

def _accuracy(y, yhat, threshold=0):
    """
    Return the proportion of predictions that match the labels in ``y``.

    Each value of ``yhat`` is first turned into a label: -1 when it is
    below ``threshold`` and 1 otherwise.
    """
    def to_label(score):
        if score < threshold:
            return -1
        return 1

    predicted_labels = yhat.apply(to_label)
    is_correct = (y == predicted_labels)
    return is_correct.mean()

def _confusion_table(y, yhat, threshold=0):
    """
    Count the four outcomes of thresholded predictions against ``y``.

    Values of ``yhat`` below ``threshold`` become -1, all others become 1.
    Each row is then encoded as ``2 * yhat + y``, which gives -3, -1, 1 or 3
    when ``y`` holds -1/1 labels. The counts of those codes are returned in a
    dict with keys 'true_negative', 'true_positive', 'false_negative' and
    'false_positive'.
    """
    table = _SFrame()
    table['y'] = y

    # Turn raw scores into -1/1 labels.
    table['yhat'] = yhat.apply(lambda score: -1 if score < threshold else 1)

    # One code per (prediction, label) combination.
    table['category'] = 2 * table['yhat'] + table['y']

    outcome_codes = (
        ('true_negative', -3),
        ('true_positive', 3),
        ('false_negative', -1),
        ('false_positive', 1),
    )
    counts = {}
    for outcome, code in outcome_codes:
        counts[outcome] = (table['category'] == code).sum()
    return counts


