"""
Methods for creating and using an SVM model.
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.data_structures.sframe import SArray as _SArray
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS
from graphlab.toolkits.supervised_learning import SupervisedLearningModel

# Default solver settings; the values match the defaults documented in the
# ``solver_options`` table of :func:`create`.
DEFAULT_SOLVER_OPTIONS = {
    'convergence_threshold': 1e-2,
    'max_iterations': 10,
    'lbfgs_memory_level': 10,
}

# Candidate values for the ``penalty`` hyper-parameter, spaced by powers of
# ten. Not referenced elsewhere in the visible part of this module --
# presumably consumed by a hyper-parameter search utility (TODO confirm).
DEFAULT_HYPER_PARAMETER_RANGE = {
    'penalty': [0.001, 0.01, 0.1, 1.0, 10.0, 100.0],
}

def create(dataset, target, features=None, penalty=1.0, solver='auto',
    feature_rescaling=True, solver_options=None, verbose=True):
    """
    Create a :class:`~graphlab.svm.SVMModel` to predict the class of a binary
    target variable based on a model of which side of a hyperplane the example
    falls on. In addition to standard numeric and categorical types, features
    can also be extracted automatically from list- or dictionary-type SFrame
    columns.

    The loss function for the SVM model is the sum of an L1 mis-classification
    loss (multiplied by the 'penalty' term) and an L2-norm on the weight
    vectors.

    Parameters
    ----------
    dataset : SFrame
        Dataset for training the model. A pandas DataFrame is also accepted
        and is converted to an SFrame.

    target : string
        Name of the column containing the target variable. The values in this
        column must be 0 or 1, of integer type.

    features : list[string], optional
        Names of the columns containing features. 'None' (the default) indicates
        that all columns except the target variable should be used as features.

        The features are columns in the input SFrame that can be of the
        following types:

        - *Numeric*: values of numeric type integer or float.

        - *Categorical*: values of type string.

        - *Array*: list of numeric (integer or float) values. Each list element
          is treated as a separate feature in the model.

        - *Dictionary*: key-value pairs with numeric (integer or float) values.
          Each key of a dictionary is treated as a separate feature and the
          value in the dictionary corresponds to the value of the feature.
          Dictionaries are ideal for representing sparse data.

        Columns of type *list* are not supported. Convert them to array in
        case all entries in the list are of numeric types and separate them
        out into different columns if they are of mixed type.

    penalty : float, optional
        Penalty term on the mis-classification loss of the model. The smaller
        this weight, the more the model coefficients shrink toward 0.  The
        smaller the penalty, the lower is the emphasis placed on misclassified
        examples, and the classifier would spend more time maximizing the
        margin for correctly classified examples. The default value is 1.0;
        this parameter must be set to a value of at least 1e-10.

    solver : string, optional
        Name of the solver to be used to solve the problem. See the
        references for more detail on each solver. Available solvers are:

        - *auto (default)*: automatically chooses the best solver (from the
          ones listed below) for the data and model parameters.
        - *lbfgs*: Limited memory BFGS (``lbfgs``) is a robust solver for wide
          datasets (i.e. datasets with many coefficients).

        The solvers are all automatically tuned and the default options should
        function well. See the solver options guide for setting additional
        parameters for each of the solvers.

    feature_rescaling : bool, optional
        Feature rescaling is an important pre-processing step that ensures
        that all features are on the same scale. An L2-norm rescaling is
        performed to make sure that all features are of the same norm.
        Categorical features are also rescaled by rescaling the dummy
        variables that are used to represent them. The coefficients are
        returned in original scale of the problem. The default is True.

    solver_options : dict, optional
        Solver options. Option names are matched case-insensitively. The
        options and their default values are as follows:

        +-----------------------+---------+-----------------------------------------+
        |      Option           | Default |        Description                      |
        +=======================+=========+=========================================+
        | convergence_threshold |    1e-2 | Desired training accuracy               |
        +-----------------------+---------+-----------------------------------------+
        | max_iterations        |     10  | Max number of solver iterations         |
        +-----------------------+---------+-----------------------------------------+
        | lbfgs_memory_level    |     10  | Memory used by lbfgs (lbfgs only)       |
        +-----------------------+---------+-----------------------------------------+

        convergence_threshold:

        Convergence is tested using variation in the training objective. The
        variation in the training objective is calculated using the difference
        between the objective values between two steps. Consider reducing this
        below the default value (0.01) for a more accurately trained model.
        Beware of overfitting (i.e. a model that works well only on the
        training data) if this parameter is set to a very low value.

        max_iterations:

        The maximum number of allowed passes through the data. More passes over
        the data can result in a more accurately trained model. Consider
        increasing this (the default value is 10) if the training accuracy is
        low and the *Grad-Norm* in the display is large.

        lbfgs_memory_level:

        The L-BFGS algorithm keeps track of gradient information from the
        previous ``lbfgs_memory_level`` iterations. The storage requirement for
        each of these gradients is the ``num_coefficients`` in the problem.
        Increasing the ``lbfgs_memory_level`` can help improve the quality of
        the model trained. Setting this to more than ``max_iterations`` has the
        same effect as setting it to ``max_iterations``.

    verbose : bool, optional
        If True, print progress updates.

    Returns
    -------
    out : SVMModel
        A trained model of type
        :class:`~graphlab.svm.SVMModel`.

    Raises
    ------
    TypeError
        If ``dataset`` is neither an SFrame nor a pandas DataFrame, if the
        target column is not of integer type, or if ``features`` is not an
        iterable of strings.

    ValueError
        If the target column contains values other than 0 and 1, or if the
        training step fails.

    See Also
    --------
    SVMModel

    Notes
    -----
    - Categorical variables are encoded by creating dummy variables. For
      a variable with :math:`K` categories, the encoding creates :math:`K-1`
      dummy variables, while the first category encountered in the data is used
      as the baseline.

    - For prediction and evaluation of SVM models with sparse dictionary
      inputs, new keys/columns that were not seen during training are silently
      ignored.

    - The penalty parameter is analogous to the 'C' term in the C-SVM. See the
      reference on training SVMs for more details.

    - Any 'None' values in the data will result in an error being thrown.

    - A constant term of '1' is automatically added for the model intercept to
      model the bias term.

    - Note that the hinge loss is approximated by the scaled logistic loss
      function. (See user guide for details)


    References
    ----------
    - `Wikipedia - Support Vector Machines
      <http://en.wikipedia.org/wiki/svm>`_

    - Zhang et al. - Modified Logistic Regression: An Approximation to
      SVM and its Applications in Large-Scale Text Categorization (ICML 2003)


    Examples
    --------

    Given an :class:`~graphlab.SFrame` ``sf``, a list of feature columns
    [``feature_1`` ... ``feature_K``], and a target column ``target`` with 0 and
    1 values, create a
    :class:`~graphlab.svm.SVMModel` as follows:

    >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')
    >>> data['is_expensive'] = data['price'] > 30000
    >>> model = graphlab.svm.create(data, 'is_expensive')
    """

    if verbose:
        # Single-argument parenthesised form prints identically whether
        # ``print`` is a statement (Python 2) or a function.
        print("Preparing the data...")

    model_name = "supervised_learning_svm"
    _mt._get_metric_tracker().track(
        'toolkit.classification.svm.create')

    # Step 1: Setup model and options to make sure they are sensible
    # -------------------------------------------------------------------------
    # Check data matrix type and convert to SFrame
    if not (isinstance(dataset, _SFrame) or (_HAS_PANDAS and
      isinstance(dataset, _pandas.DataFrame))):
        raise TypeError('Dataset input must be an SFrame or a pandas DataFrame.')
    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    # Make the solver name and all solver options lower case.
    # Build a fresh dict so the caller's solver_options is never modified.
    solver = solver.lower()
    if solver_options is not None:
        _solver_options = {k.lower(): v for k, v in solver_options.items()}
    else:
        _solver_options = {}

    # Make sure target contains 0/1
    sf_target = dataset.select_columns([target])
    target_column = sf_target[target]
    if target_column.dtype() != int:
        raise TypeError("Target column must be type int.")
    # Element-wise sum of the two indicator columns is non-zero exactly when
    # the value is 0 or 1.
    zero_or_one = (target_column == 0) + (target_column == 1)
    if not all(zero_or_one):
        raise ValueError("The target column for SVM classification" + \
            " must contain only 0 and 1 values.")

    # Select the features.
    if features is None:
        # Default: every column except the target. Built as a new list so the
        # list returned by column_names() is not mutated.
        features = [name for name in dataset.column_names() if name != target]
    if not hasattr(features, '__iter__'):
        raise TypeError("Input features must be an iterable.")
    if not all(isinstance(x, str) for x in features):
        raise TypeError("Invalid type: Feature names must be string.")
    sf_features = dataset.select_columns(features)


    # Init the model
    # -------------------------------------------------------------------------
    # Set up options dictionary and initialize the model. The explicit entries
    # are applied after the solver options so they always take precedence.
    opts = {}
    opts.update(_solver_options)
    opts.update({'target'     : sf_target,
                'features'    : sf_features,
                'model_name'  : model_name,
                'solver'      : solver,
                'feature_rescaling'  : feature_rescaling,
                'penalty'  : penalty})

    # Call the C++ init function
    ret = _graphlab.toolkits.main.run("supervised_learning_train_init", opts)
    opts.update(ret)

    # Train the model
    # -------------------------------------------------------------------------
    try:
        ret = _graphlab.toolkits.main.run("supervised_learning_train", opts, verbose)
    except Exception as e:
        # Catch Exception only, so KeyboardInterrupt/SystemExit propagate, and
        # keep the underlying message so the failure can be diagnosed.
        raise ValueError("Model failed to train. %s" % e)

    model_proxy = ret['model']
    model = SVMModel(model_proxy)
    return model


class SVMModel(SupervisedLearningModel):
    r"""
    Support Vector Machines can be used to predict a binary target variable
    using several feature variables.  An instance of this model can be created
    using :func:`graphlab.svm.create`.  Do not construct the model directly.
    Additional details about the model construction (along with code samples)
    are documented with the :func:`graphlab.svm.create` function.

    The :py:class:`~graphlab.svm.SVMModel` model predicts a binary target
    variable given one or more feature variables. In an SVM model, the examples
    are represented as points in space, mapped so that the examples from the
    two classes being classified are divided by a linear separator.

    Given a set of features :math:`x_i`, and a label :math:`y_i \in \{0,1\}`,
    SVM minimizes the loss function:

        .. math::
          f_i(\theta) =  \max(1 - \theta^T x, 0)

    An intercept term is added by appending a column of 1's to the features.
    Regularization is often required to prevent over fitting by penalizing
    models with extreme parameter values. The composite objective being
    optimized for is the following:

        .. math::
           \min_{\theta} \lambda \sum_{i = 1}^{n} f_i(\theta) + ||\theta||^{2}_{2}

    where :math:`\lambda` is the ``penalty`` parameter, i.e. the weight on the
    mis-classification loss.


    Examples
    --------

    .. sourcecode:: python

        # Load the data (From an S3 bucket)
        >>> import graphlab as gl
        >>> data =  gl.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        # Make sure the target is binary 0/1
        >>> data['is_expensive'] = data['price'] > 30000

        # Make an SVM model
        >>> model = gl.svm.create(data, target='is_expensive'
                                        , features=['bath', 'bedroom', 'size'])

        # Extract the coefficients
        >>> coefficients = model['coefficients']     # an SFrame

        # Make predictions (as margins, or class)
        >>> predictions = model.predict(data)    # Predicts 0/1
        >>> predictions = model.predict(data, output_type='margin')

        # Evaluate the model
        >>> results = model.evaluate(data)               # a dictionary

    See Also
    --------
    create

    """
    def __init__(self, model_proxy):
        '''__init__(self)'''
        # Handle to the trained model returned by the toolkit engine; it is
        # passed back to the engine on every subsequent toolkit call.
        self.__proxy__ = model_proxy
        # Engine-side model name, sent as 'model_name' with every toolkit call.
        self.__name__ = "supervised_learning_svm"

    def _get_wrapper(self):
        """
        Return a callable that wraps a model proxy in a new ``SVMModel``.

        Returns
        -------
        out : function
            Function taking a model proxy and returning an ``SVMModel``.
        """
        def model_wrapper(model_proxy):
            return SVMModel(model_proxy)
        return model_wrapper

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            A description of the model.
        """
        return self.__repr__()

    def __repr__(self):
        """
        Return a string description of the model, shown when the model name
        is entered in the terminal.

        Returns
        -------
        out : string
            One ``label: value`` line per reported field, grouped into model,
            solver and training sections separated by blank lines.
        """

        width = 20
        key_str = "{:<{}}: {}"
        # Each entry is (display label, field name understood by ``get``).
        model_fields = [
            ("Penalty", 'penalty'),
            ("Examples", 'num_examples'),
            ("Features", 'num_features'),
            ("Coefficients", 'num_coefficients')]

        solver_fields = [
            ("Solver", 'solver'),
            ("Solver iterations", 'train_iters'),
            ("Solver status", 'solver_status'),
            ("Training time (sec)", 'train_time')]

        train_fields = [
            ("Train Loss", 'train_loss')]

        ret = [key_str.format("Class", width, self.__class__.__name__)]

        for tranche_fields in [model_fields, solver_fields, train_fields]:
            for k, v in tranche_fields:
                value = self.get(v)
                if isinstance(value, float):
                    # Keep the display compact; fall back to the raw value if
                    # it cannot be rounded.
                    try:
                        value = round(value, 4)
                    except (TypeError, ValueError, OverflowError):
                        pass
                ret.append(key_str.format(k, width, value))
            # Blank line between sections.
            ret.append("")
        return '\n'.join(ret)

    def get(self, field):
        """
        Return the value of a given field. The list of all queryable fields is
        detailed below, and can be obtained programmatically with the
        :func:`~graphlab.svm.SVMModel.list_fields` method.


        +-----------------------+----------------------------------------------+
        |      Field            | Description                                  |
        +=======================+==============================================+
        | coefficients          | classification coefficients                  |
        +-----------------------+----------------------------------------------+
        | convergence_threshold | Desired solver accuracy                      |
        +-----------------------+----------------------------------------------+
        | feature_rescaling     | Bool indicating l2-rescaling of features     |
        +-----------------------+----------------------------------------------+
        | features              | Feature column names                         |
        +-----------------------+----------------------------------------------+
        | lbfgs_memory_level    | Number of updates to store (lbfgs only)      |
        +-----------------------+----------------------------------------------+
        | max_iterations        | Maximum number of solver iterations          |
        +-----------------------+----------------------------------------------+
        | num_coefficients      | Number of coefficients in the model          |
        +-----------------------+----------------------------------------------+
        | num_examples          | Number of examples used for training         |
        +-----------------------+----------------------------------------------+
        | num_features          | Number of dataset columns used for training  |
        +-----------------------+----------------------------------------------+
        | penalty               | Misclassification penalty term               |
        +-----------------------+----------------------------------------------+
        | solver                | Type of solver                               |
        +-----------------------+----------------------------------------------+
        | solver_status         | Solver status after training                 |
        +-----------------------+----------------------------------------------+
        | target                | Target column name                           |
        +-----------------------+----------------------------------------------+
        | train_iters           | Number of solver iterations                  |
        +-----------------------+----------------------------------------------+
        | train_loss            | Training loss of the final model             |
        +-----------------------+----------------------------------------------+
        | train_time            | Training time (excludes preprocessing)       |
        +-----------------------+----------------------------------------------+
        | trained               | Indicates whether model is trained           |
        +-----------------------+----------------------------------------------+

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out
            Value of the requested field. ``coefficients`` is returned as an
            SFrame.

        See Also
        --------
        list_fields

        Examples
        --------

        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])
        >>> print(model['num_features'])
        3
        >>> print(model.get('num_features'))      # equivalent to previous line
        3
        """

        _mt._get_metric_tracker().track(
            'toolkit.classification.svm.get')

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'field': field}
        response = _graphlab.toolkits.main.run('supervised_learning_get_value', opts)

        # Coefficients returns a unity SFrame. Cast to an SFrame.
        # --------------------------------------------------------------------
        if field == 'coefficients':
            return _SFrame(None, _proxy=response['value'])
        return response['value']

    def summary(self):
        """
        Display a summary of the SVMModel.

        Prints the model description followed by the strongest positive and
        the strongest negative coefficients.

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])
        >>> model.summary()
        """

        _mt._get_metric_tracker().track(
            'toolkit.classification.svm.summary')
        coefs = self.get('coefficients')

        # Largest coefficients, keeping only the strictly positive ones.
        # NOTE(review): k=6 here versus k=5 below -- confirm the asymmetry is
        # intentional.
        top_coefs = coefs.topk('Coefficient', k=6)
        top_coefs = top_coefs[top_coefs['Coefficient'] > 0]

        # Smallest coefficients, keeping only the strictly negative ones.
        bottom_coefs = coefs.topk('Coefficient', k=5, reverse=True)
        bottom_coefs = bottom_coefs[bottom_coefs['Coefficient'] < 0]

        # Single-argument parenthesised prints behave identically whether
        # ``print`` is a statement (Python 2) or a function.
        print("")
        print("                    Model summary                       ")
        print("--------------------------------------------------------")
        print(self.__repr__())

        print("             Strongest positive coefficients            ")
        print("--------------------------------------------------------")
        if len(top_coefs) > 0:
            print(_SFrame(top_coefs))
        else:
            print("[No positive coefficients]")

        print("             Strongest negative coefficients            ")
        print("--------------------------------------------------------")
        if len(bottom_coefs) > 0:
            print(_SFrame(bottom_coefs))
        else:
            print("[No negative coefficients]")

    def get_default_options(self):
        """
        Return a dictionary with the model's default options.

        Returns
        -------
        out : dict
            Dictionary with model default options.

        See Also
        --------
        get_current_options, list_fields, get

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])
        >>> default_options = model.get_default_options()

        """
        _mt._get_metric_tracker().track(
                'toolkit.classification.svm.get_default_options')
        return super(SVMModel, self).get_default_options()

    def get_current_options(self):
        """
        Return a dictionary with the options used to define and train the model.

        Returns
        -------
        out : dict
            Dictionary with options used to define and train the model.

        See Also
        --------
        get_default_options, list_fields, get

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])
        >>> current_options = model.get_current_options()
        """
        _mt._get_metric_tracker().track('toolkit.classification.svm.get_current_options')
        return super(SVMModel, self).get_current_options()

    def predict(self, dataset, output_type='class', missing_value_action='impute'):
        """
        Return predictions for ``dataset``, using the trained SVM model.
        Predictions can be generated as class labels (0 or 1), or margins
        (i.e. the distance of the observations from the hyperplane separating
        the classes). By default, the predict method returns class labels.

        For each new example in ``dataset``, the margin---also known as the
        linear predictor---is the inner product of the example and the model
        coefficients plus the intercept term. Predicted classes are obtained by
        thresholding the margins at 0.

        Parameters
        ----------
        dataset : SFrame
            Dataset of new observations. Must include columns with the same
            names as the features used for model training, but does not require
            a target column. Additional columns are ignored.

        output_type : {'margin', 'class'}, optional
            Form of the predictions.

        missing_value_action: str, optional
            Action to perform when missing values are encountered. This can be
            one of:

            - 'impute': Proceed with evaluation by filling in the missing
                        values with the mean of the training data. Missing
                        values are also imputed if an entire column of data is
                        missing during evaluation.
            - 'error' : Do not proceed with prediction and terminate with
                        an error message.

        Returns
        -------
        out : SArray
            An SArray with model predictions.

        Raises
        ------
        ValueError
            If ``output_type`` is neither 'class' nor 'margin'.

        See Also
        --------
        create, evaluate

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])

        >>> class_predictions = model.predict(data)
        >>> margin_predictions = model.predict(data, output_type='margin')

        """

        if output_type not in ['class', 'margin']:
            raise ValueError("Output type '{}' is not supported.".format(output_type) + \
                             " Please select 'class' or 'margin'.")
        _mt._get_metric_tracker().track('toolkit.classification.svm.predict')
        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'dataset': dataset,
                'missing_value_action': missing_value_action,
                'output_type': output_type}

        ## Compute the predictions
        # The init call's results are merged into the options before the
        # actual prediction call.
        init_opts = _graphlab.toolkits.main.run('supervised_learning_predict_init', opts)
        opts.update(init_opts)
        target = _graphlab.toolkits.main.run('supervised_learning_predict', opts)
        return _SArray(None, _proxy=target['predicted'])


    def evaluate(self, dataset, missing_value_action = 'impute'):
        """
        Evaluate the model by making predictions of target values and comparing
        these to actual values.

        Two metrics are used to evaluate SVM. The confusion table contains the
        cross-tabulation of actual and predicted classes for the target
        variable. Classification accuracy is the fraction of examples whose
        predicted and actual classes match.

        Parameters
        ----------
        dataset : SFrame
            Dataset of new observations. Must include columns with the same
            names as the target and features used for model training. Additional
            columns are ignored.

        missing_value_action: str, optional
            Action to perform when missing values are encountered. This can be
            one of:

            - 'impute': Proceed with evaluation by filling in the missing
                        values with the mean of the training data. Missing
                        values are also imputed if an entire column of data is
                        missing during evaluation.
            - 'error' : Do not proceed with evaluation and terminate with
                        an error message.

        Returns
        -------
        out : dict
            Dictionary of evaluation results. The dictionary keys are *accuracy*
            and *confusion_table*. The confusion table is itself a dictionary
            with the keys *true_positive*, *true_negative*, *false_positive*
            and *false_negative*.

        See Also
        --------
        create, predict

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])

        >>> results = model.evaluate(data)
        >>> print(results['accuracy'])
        """

        _mt._get_metric_tracker().track(
            'toolkit.classification.svm.evaluate')
        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'missing_value_action': missing_value_action,
                'dataset': dataset}

        ## Compute the predictions
        # Three engine calls in sequence; each call's results are merged into
        # the options passed to the next.
        init_opts = _graphlab.toolkits.main.run('supervised_learning_evaluate_init', opts)
        opts.update(init_opts)
        response = _graphlab.toolkits.main.run('supervised_learning_evaluate', opts)
        opts.update(response)
        results = _graphlab.toolkits.main.run("supervised_learning_get_evaluate_stats", opts)

        # Return the accuracy and confusion tables
        keys = ['true_positive','true_negative','false_positive','false_negative']
        return {
            'accuracy': results['accuracy'],
            'confusion_table': {k: results[k] for k in keys},
        }


    def list_fields(self):
        """
        List the fields stored in the model, including data, model, and training
        options. Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list
            List of fields queryable with the ``get`` method.

        See Also
        --------
        get

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')

        >>> data['is_expensive'] = data['price'] > 30000
        >>> model = graphlab.svm.create(data,
                                  target='is_expensive',
                                  features=['bath', 'bedroom', 'size'])

        >>> model.list_fields()
        """

        _mt._get_metric_tracker().track(
            'toolkit.classification.svm.list_fields')
        return super(SVMModel, self).list_fields()
