"""
Methods for creating a linear regression model.
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.regression.regression import RegressionModel
from graphlab.data_structures.sframe import SFrame as _SFrame
from pandas import DataFrame as _DataFrame
from graphlab import vowpal_wabbit as _vw

# TODO: Outstanding work for this file
#------------------------------------------------------------------------------
# 1. Clearer input/output messages in the docstrings and comments.
# 2. Error handling for input/output parameters.
# 3. A default option parser.
# 4. APIs for standard errors, p-values, etc.


# Default solver options for `create`. This whole dict is used as the solver
# options when the caller supplies none, and the Vowpal Wabbit path of
# `create` reads per-key fallbacks from it for options the caller omitted.
DEFAULT_SOLVER_OPTIONS = dict(
    # Options shared by the iterative solvers.
    convergence_threshold=1e-2,
    step_size=1.0,
    lbfgs_memory_level=3,
    mini_batch_size=1000,
    max_iters=1,
    # Options forwarded to Vowpal Wabbit.
    quadratic=[],
    bigram=False,
    regularization=None,
)


def create(dataset,
           response,
           predictors=None,
           solver='newton',
           verbose=True,
           plot=False,
           solver_options=None):
    """
    Create an :class:`~graphlab.linear_regression.LinearRegressionModel` that
    predicts a scalar 'response' variable as a linear function of one or more
    'predictor' variables.


    .. note:: Linear regression automatically adds a constant term for the bias.

    .. note:: Linear regression supports categorical variables (provided as string inputs) if the solver is set to 'vw'.


    Parameters
    ----------
    dataset : pandas.DataFrame/SFrame
        The dataset to use for training the model. A pandas DataFrame is
        converted to an SFrame first.

    response: string
        The column name, in the dataset, that corresponds to the response in
        the model.

    predictors: iterable of strings (default None)
        Column names that correspond to the predictors in the model. A value
        of 'None' indicates that all columns in the dataset (except the
        response column) are used as predictors. A single string is rejected;
        wrap a single column name in a list.

    solver: string
        Name of solver to be used to solve the regression (case-insensitive).
        'vowpal-wabbit' is accepted as an alias for 'vw'.


        +-------------------------------------------------------------+
        |                      Available solvers                      |
        +=====================+=======================================+
        |'newton' (default)   | Directly solve the normal-equations.  |
        +---------------------+---------------------------------------+
        |'lbfgs'              | Limited memory BFGS.                  |
        +---------------------+---------------------------------------+
        |'vw'                 | Vowpal Wabbit.                        |
        +---------------------+---------------------------------------+
        |'fista'              | Accelerated gradient descent.         |
        +---------------------+---------------------------------------+
        |'sgd'                | Stochastic gradient.                  |
        +---------------------+---------------------------------------+
        |'gd'                 | Gradient Descent.                     |
        +---------------------+---------------------------------------+


        .. note:: Newton method is the best choice for all datasets with a small number of features.

    verbose: bool, optional
        If True, print progress updates.

    plot: bool, optional
        If True, display the progress plot. (Not yet implemented; a message
        is printed and plotting is disabled.)

    solver_options: dict, optional
        Solver options. Option names are case-insensitive. If None (the
        default), ``DEFAULT_SOLVER_OPTIONS`` is used. The options and their
        default values are provided below.

        .. note:: Default values are used for options that are not explicitly provided.

        +------------------------+----------------+-------------------------------------------------+
        |                                Common options                                             |
        +========================+================+=================================================+
        |      Option name       | Default value  |      Description                                |
        +------------------------+----------------+-------------------------------------------------+
        | 'convergence_threshold'|     1e-2       |  Convergence criterion.                         |
        +------------------------+----------------+-------------------------------------------------+
        | 'max_iters'            |     1          |  Max iterations of the method.                  |
        +------------------------+----------------+-------------------------------------------------+
        | 'step_size'            |     1.0        |  Initial step size (a.k.a learning rate).       |
        +------------------------+----------------+-------------------------------------------------+
        | 'lbfgs_memory_level'   |     3          |  Number of previous updates to store in L-BFGS. |
        +------------------------+----------------+-------------------------------------------------+
        | 'mini_batch_size'      |     1000       |  Number of examples in an SGD mini-batch        |
        +------------------------+----------------+-------------------------------------------------+
        | 'auto_tuning'          |     True       |  Turn on/off SGD step-size auto-tuner.          |
        +------------------------+----------------+-------------------------------------------------+

        See :class:`~graphlab.vowpal_wabbit.create` for more details on the following options.

        +------------------------+----------------+-------------------------------------------------+
        |                                VW Options                                                 |
        +========================+================+=================================================+
        |      Option name       | Default value  |      Description                                |
        +------------------------+----------------+-------------------------------------------------+
        | 'quadratic'            |     []         |  List of pairs for quadratic-interaction terms. |
        +------------------------+----------------+-------------------------------------------------+
        | 'regularization'       |     None       |  L-2 regularization penalty.                    |
        +------------------------+----------------+-------------------------------------------------+
        | 'bigram'               |     False      |  Add bigram features (Boolean).                 |
        +------------------------+----------------+-------------------------------------------------+


    Returns
    -------
    out : LinearRegressionModel
        A trained LinearRegressionModel.

    Raises
    ------
    TypeError
        If `dataset` is not an SFrame or pandas DataFrame, if `predictors` is
        not a non-string iterable, or if any predictor name is not a str.

    Examples
    --------
    If given an :class:`~graphlab.SFrame` ``sf`` with a list of columns
    [``pred_1`` ... ``pred_K``] denoting predictors and a response column
    ``response``, then we can create
    a :class:`~graphlab.linear_regression.LinearRegressionModel` as
    follows:

    >>> m = linear_regression.create(sf, 'response', ['pred_1' ... 'pred_K'])

    You can change the default options as follows.

    >>> m = linear_regression.create(sf, 'response', ['pred_1' ... 'pred_K'], solver='newton', solver_options = {'max_iters': 10})

    With this model object one can make predictions for data in ``sf``:

    >>> pred = m.predict(sf)

    The model can be saved to disk as follows:

    >>> m.save(filename)

    For more, see the documentation for
    :class:`~graphlab.linear_regression.LinearRegressionModel`.

    """

    _mt._get_metric_tracker().track('toolkit.regression.linear_regression.create')

    if not isinstance(dataset, (_DataFrame, _SFrame)):
        raise TypeError('Dataset input must be an SFrame or a pandas DataFrame.')

    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    if plot:
        print("Plot functionality not yet implemented.")
        plot = False

    model_name = "regression_linear_regression"
    solver = solver.lower()

    # A None default avoids sharing one mutable dict across calls.
    if solver_options is None:
        solver_options = DEFAULT_SOLVER_OPTIONS

    # Normalize option names to lower case so e.g. 'Max_Iters' is honored.
    _solver_options = {(k.lower() if hasattr(k, 'lower') else k): v
                       for k, v in solver_options.items()}

    # This should throw an error if the columns do not exist
    response_sframe = dataset.select_columns([response])

    # Default to every column except the response.
    if predictors is None:
        predictors = dataset.column_names()
        predictors.remove(response)

    # A bare string is iterable on Python 3 but almost certainly a mistake.
    if isinstance(predictors, str) or not hasattr(predictors, '__iter__'):
        raise TypeError("Input 'predictors' must be an iterable")
    # Materialize once: a generator would otherwise be exhausted by the type
    # check below, and a tuple could not be concatenated with a list later.
    predictors = list(predictors)
    if not all(isinstance(x, str) for x in predictors):
        raise TypeError("Invalid predictor key type: must be str")

    predictors_sframe = dataset.select_columns(predictors)

    # Note: Options must be a flat dictionary.
    opts = {'response'    : response_sframe,
            'predictors'  : predictors_sframe,
            'model_name'  : model_name,
            'solver'      : solver}
    opts.update(_solver_options)

    # Model checking / initialization. `opts` then carries the initialized
    # model (and anything else the backend returned) into training.
    ret = _graphlab.toolkits.main.run("regression_train_init", opts)
    opts.update(ret)
    model = ret['model']

    if solver in ('vw', 'vowpal-wabbit'):
        model = _train_vw(dataset, response, predictors, _solver_options,
                          verbose, model, model_name)
    else:
        # Initialization already happened above, so train directly.
        ret = _graphlab.toolkits.main.run("regression_train", opts,
                verbose, plot)
        model = ret['model']

    return LinearRegressionModel(model)


def _train_vw(dataset, response, predictors, solver_options, verbose, model,
              model_name):
    """
    Train with Vowpal Wabbit and copy the trained VW model's fields onto the
    backend regression `model`; return the updated backend model handle.

    `solver_options` keys are expected to be lower case already. Options not
    supplied fall back to DEFAULT_SOLVER_OPTIONS.
    """
    sf = dataset.select_columns(predictors + [response])

    vw_options = {}
    for key in ('quadratic', 'bigram', 'regularization', 'max_iters',
                'step_size'):
        vw_options[key] = solver_options.get(key, DEFAULT_SOLVER_OPTIONS[key])

    vw_model = _vw.create(sf, response,
                          quadratic=vw_options['quadratic'],
                          bigram=vw_options['bigram'],
                          regularization=vw_options['regularization'],
                          loss_function='squared',
                          learning_rate=vw_options['step_size'],
                          verbose=verbose,
                          num_passes=vw_options['max_iters'])

    # Fields appended to our model, in the order the backend receives them.
    proxy = vw_model.__proxy__
    fields = [('model_file', proxy['model_file']),
              ('__repr__', proxy['__repr__']),
              ('rmse', proxy['rmse']),
              ('accuracy', proxy['accuracy']),
              ('target', proxy['target']),
              ('trained', 1)]
    for key, value in fields:
        args = {'key'       : key,
                'value'     : value,
                'model_name': model_name,
                'model'     : model}
        model = _graphlab.toolkits.main.run("regression_append_model",
                                            args)['model']
    return model


class LinearRegressionModel(RegressionModel):
    """

    Linear regression is an approach for modeling a scalar response :math:`y`
    as a linear function of one or more explanatory/predictor variables denoted
    :math:`X`.

    Instances wrap a backend model handle (the "proxy") returned by the
    regression toolkit; see ``create`` for how one is trained.

    """

    def __init__(self, model_proxy):
        # Toolkit model name, matching the one used when training.
        self.__name__ = "regression_linear_regression"
        # Handle to the underlying backend model.
        self.__proxy__ = model_proxy

    def _get_wrapper(self):
        """Return a callable that wraps a raw model proxy in this class."""
        return lambda proxy: LinearRegressionModel(proxy)
