"""
Methods for creating a logistic regression model.
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.regression.regression import RegressionModel
from graphlab.data_structures.sframe import SFrame as _SFrame
from pandas import DataFrame as _DataFrame
from graphlab import vowpal_wabbit as _vw

# TODO: Outstanding work items for this file
#------------------------------------------------------------------------------
# 1. Better input/output messages in the comments.
# 2. Error handling of input/output parameters
# 3. Default option parser.
# 4. APIs for std errors, p-values etc.
# 5. Change None to depict default step size.


# Default solver settings for ``create``. When building the Vowpal Wabbit
# call, any of 'quadratic', 'bigram', 'regularization', 'step_size' and
# 'max_iters' missing from the caller's ``solver_options`` falls back to the
# value stored here.
DEFAULT_SOLVER_OPTIONS = dict(
    convergence_threshold=1e-2,
    step_size=None,        # None -> VW is called with learning_rate=None
    max_iters=1,           # forwarded to VW as num_passes
    quadratic=[],          # pairs for quadratic-interaction terms
    bigram=False,
    regularization=None,
)


def _normalize_solver_options(solver_options):
    """
    Return a private, normalized copy of ``solver_options``.

    String keys are lower-cased. A missing or ``None`` ``'step_size'`` is
    replaced by ``1.0`` in the copy, because the copy is what gets forwarded
    to the ``regression_train_init`` toolkit call.

    The input dict is never modified. This keeps the caller's dict and the
    shared ``DEFAULT_SOLVER_OPTIONS`` default intact across calls.

    Returns
    -------
    (dict, bool)
        The normalized copy, and a flag that is True when no usable step size
        was supplied. In that case VW must be called with
        ``learning_rate=None``.
    """
    options = {}
    for key, value in solver_options.items():
        # Only string-like keys can be lower-cased; leave anything else as-is.
        options[key.lower() if hasattr(key, 'lower') else key] = value

    step_size_none_flag = options.get('step_size') is None
    if step_size_none_flag:
        # Corner case: the step size is unset. Give the backend a concrete
        # value here; VW itself is later told learning_rate=None.
        options['step_size'] = 1.0
    return options, step_size_none_flag


def _append_to_model(model, model_name, key, value):
    """Store ``key``/``value`` on ``model`` via ``regression_append_model``
    and return the updated model proxy."""
    args = {'key'       : key,
            'value'     : value,
            'model_name': model_name,
            'model'     : model}
    return _graphlab.toolkits.main.run("regression_append_model", args)["model"]


def create(dataset,
           response,
           predictors=None,
           solver='vw',
           verbose=True,
           plot=False,
           solver_options=DEFAULT_SOLVER_OPTIONS):
    """
    Create a :class:`~graphlab.logistic_regression.LogisticRegressionModel`
    that predicts a binary 'response' variable using one or more
    'predictor' variables. The probability of each outcome of the binary
    variable is modeled as a logistic function of the predictor variables.

    .. note:: The response column for logistic regression must contain only the values +1 and -1.

    .. note:: Logistic regression automatically adds a constant term for the bias.

    .. note:: Logistic regression supports categorical variables (provided as string inputs) if the solver is set to 'vw'.


    Parameters
    ----------
    dataset : pandas.DataFrame/SFrame
        The dataset to use for training the model. A pandas DataFrame is
        converted to an SFrame first.

    response : string
        The column name, in the dataset, that corresponds to the response in
        the model.

        The response column must be +1 or -1.

    predictors : iterable of strings (default None)
        Column names that correspond to the predictors in the model. A value
        of None means that every column in the dataset except the response
        column is used as a predictor.

    solver : string
        Name of the solver to be used to solve the regression
        (case-insensitive; 'vowpal-wabbit' is accepted as an alias of 'vw').

        +-------------------------------------------------------------+
        |                      Available solvers                      |
        +=====================+=======================================+
        |'vw' (default)       | Vowpal Wabbit.                        |
        +---------------------+---------------------------------------+

    verbose : bool, optional
        If True, print progress updates.

    plot : bool, optional
        Not yet implemented; if True, a notice is printed and no plot is
        produced.

    solver_options : dict
        Solver options. Keys are case-insensitive. The options and their
        default values are listed below. ``None`` is treated as the defaults.
        The dict passed in is never modified.

        .. note:: Default values are used for options that are not explicitly provided.

        +------------------------+----------------+-------------------------------------------------+
        |                                Common options                                             |
        +========================+================+=================================================+
        |      Option name       | Default value  |      Description                                |
        +------------------------+----------------+-------------------------------------------------+
        | 'max_iters'            |     1          |  Max iterations of the method.                  |
        +------------------------+----------------+-------------------------------------------------+
        | 'step_size'            |     None       |  Initial step size (a.k.a learning rate).       |
        +------------------------+----------------+-------------------------------------------------+

        If 'step_size' is None or missing, Vowpal Wabbit is called with
        ``learning_rate=None``. 'max_iters' is passed to Vowpal Wabbit as
        ``num_passes``.

        See :class:`~graphlab.vowpal_wabbit.create` for more details on options.

        +------------------------+----------------+-------------------------------------------------+
        |                                VW Options                                                 |
        +========================+================+=================================================+
        |      Option name       | Default value  |      Description                                |
        +------------------------+----------------+-------------------------------------------------+
        | 'quadratic'            |     []         |  List of pairs for quadratic-interaction terms. |
        +------------------------+----------------+-------------------------------------------------+
        | 'regularization'       |     None       |  L-2 regularization penalty.                    |
        +------------------------+----------------+-------------------------------------------------+
        | 'bigram'               |     False      |  Add bigram features (Boolean).                 |
        +------------------------+----------------+-------------------------------------------------+


    Returns
    -------
    out : LogisticRegressionModel
        A trained LogisticRegressionModel.

    Raises
    ------
    TypeError
        If ``dataset`` is neither an SFrame nor a pandas DataFrame, if
        ``predictors`` is not an iterable of column names (a single string is
        rejected), or if any predictor name is not a ``str``.
    ValueError
        If ``solver`` is not supported.

    Examples
    --------
    If given an :class:`~graphlab.SFrame` ``sf`` with a list of columns
    [``pred_1`` ... ``pred_K``] denoting predictors and a response column
    ``response`` with +1/-1 as values, then we can create
    a :class:`~graphlab.logistic_regression.LogisticRegressionModel` as
    follows:

    >>> m = logistic_regression.create(sf, 'response', ['pred_1' ... 'pred_K'])

    With this model object one can make predictions for data in ``sf``:

    >>> pred = m.predict(sf)

    The model can be saved to disk as follows:

    >>> m.save(filename)

    For more, see the documentation for
    :class:`~graphlab.logistic_regression.LogisticRegressionModel`.

    """
    _mt._get_metric_tracker().track('toolkit.regression.logistic_regression.create')

    if not isinstance(dataset, (_DataFrame, _SFrame)):
        raise TypeError('Dataset input must be an SFrame or a pandas DataFrame.')

    if not isinstance(dataset, _SFrame):
        dataset = _SFrame(dataset)

    if plot is True:
        print("Plot functionality not yet implemented.")

    model_name = "regression_logistic_regression"
    solver = solver.lower()

    if solver_options is None:
        solver_options = DEFAULT_SOLVER_OPTIONS
    # Work on a private copy: the caller's dict, and the shared default
    # dict, must not be mutated.
    _solver_options, step_size_none_flag = _normalize_solver_options(solver_options)

    # This should throw an error if the column does not exist.
    response_sframe = dataset.select_columns([response])

    if predictors is None:
        # Default: every column except the response.
        predictors = [name for name in dataset.column_names() if name != response]

    # A single string is not a list of column names.
    if isinstance(predictors, str):
        raise TypeError("Input 'predictors' must be an iterable of column names, not a single string")
    if not hasattr(predictors, '__iter__'):
        raise TypeError("Input 'predictors' must be an iterable")
    # Materialize once so tuples/generators survive the type check and the
    # list concatenation further down.
    predictors = list(predictors)
    if not all(isinstance(x, str) for x in predictors):
        raise TypeError("Invalid predictor key type: must be str")
    predictors_sframe = dataset.select_columns(predictors)

    # Note: Options must be a flat dictionary.
    opts = {'response'    : response_sframe,
            'predictors'  : predictors_sframe,
            'model_name'  : model_name,
            'solver'      : solver}
    opts.update(_solver_options)

    # Model checking
    model = _graphlab.toolkits.main.run("regression_train_init", opts)['model']

    if solver not in ('vw', 'vowpal-wabbit'):
        raise ValueError("Unsupported solver")

    # Make the SFrame into VW format
    sf = dataset.select_columns(predictors + [response])

    # Each VW option comes from the caller, falling back to the defaults.
    vw_options = {key: _solver_options.get(key, DEFAULT_SOLVER_OPTIONS[key])
                  for key in ('quadratic', 'bigram', 'regularization',
                              'step_size', 'max_iters')}
    if step_size_none_flag:
        vw_options['step_size'] = None

    # Note: VW checks that the response column has +1 or -1
    vw_model = _vw.create(sf, response,
                          quadratic=vw_options['quadratic'],
                          bigram=vw_options['bigram'],
                          regularization=vw_options['regularization'],
                          loss_function='logistic',
                          learning_rate=vw_options['step_size'],
                          verbose=verbose,
                          num_passes=vw_options['max_iters'])

    # Write the VW results into our model, one key at a time.
    vw_proxy = vw_model.__proxy__
    for key, value in [('model_file', vw_proxy['model_file']),
                       ('__repr__', vw_proxy['__repr__']),
                       ('rmse', vw_proxy['rmse']),
                       ('accuracy', vw_proxy['accuracy']),
                       ('target', vw_proxy['target']),
                       ('trained', 1)]:
        model = _append_to_model(model, model_name, key, value)

    return LogisticRegressionModel(model)


class LogisticRegressionModel(RegressionModel):
    """
    Use a logistic model to predict a response based on predictors.

    Wraps a toolkit model proxy (``create`` returns one of these around the
    trained backend model). The proxy is stored in ``__proxy__``, and the
    model identifies itself as ``"regression_logistic_regression"``.
    """

    def __init__(self, model_proxy):
        # Identifier and backend handle for this model; RegressionModel's
        # __init__ is intentionally not invoked (matches original behavior).
        self.__proxy__ = model_proxy
        self.__name__ = "regression_logistic_regression"

    def _get_wrapper(self):
        """
        Return a callable that turns a model proxy into a
        LogisticRegressionModel (presumably used by the base class to rewrap
        proxies returned from the backend — confirm against RegressionModel).
        """
        return lambda model_proxy: LogisticRegressionModel(model_proxy)
