##\internal
"""@package graphlab.toolkits
This module defines the (internal) functions used by the supervised_learning_models.
"""
import graphlab as _graphlab

from graphlab.toolkits._model import Model
from graphlab.toolkits._internal_utils import _toolkits_select_columns
from graphlab.toolkits._internal_utils import _raise_error_if_not_sframe
from graphlab.toolkits._internal_utils import _map_unity_proxy_to_object

class SupervisedLearningModel(Model):
    """
    Supervised learning module to predict a target variable as a function of
    several feature variables.

    The class is a thin Python wrapper: every method packs its arguments into
    an options dictionary (always including the model proxy and the model
    name) and dispatches it to a named toolkit function through
    ``graphlab.toolkits._main.run``.
    """
    def __init__(self, model_proxy = None, name = None):
        """
        Store the handle to the trained model and the toolkit model name.

        Parameters
        ----------
        model_proxy : object, optional
            Handle to the underlying model; sent to the toolkit as 'model'.
        name : string, optional
            Name of the model; sent to the toolkit as 'model_name'.
        """
        self.__proxy__ = model_proxy
        self.__name__ = name

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            The name of the model's class.
        """
        return self.__class__.__name__

    def __repr__(self):
        """
        Return a string representation of the model.

        Returns
        -------
        out : string
            The name of the model's class.
        """
        return self.__class__.__name__


    def summary(self):
        """
        Return a short summary of the model.

        Returns
        -------
        out : string
            The name of the model's class. Nothing is printed by this method.
        """
        return self.__class__.__name__


    def get_current_options(self):
        """
        Return a dictionary with the options used to define and train the model.

        Returns
        -------
        out : dict
            Dictionary with options used to define and train the model.

        Examples
        --------

        >>> options = m.get_current_options()
        """
        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        return _graphlab.toolkits._main.run(
            'supervised_learning_get_current_options', opts)

    def predict(self, dataset, missing_value_action = 'error',
                                      output_type='', options = None,
                                      **kwargs):
        """
        Return predictions for ``dataset``, using the trained
        supervised_learning model.

        Parameters
        ----------
        dataset : SFrame
            Dataset of new observations. Must include columns with the same
            names as the features used for model training, but does not require
            a target column. Additional columns are ignored.
        missing_value_action: str, optional
            Action to perform when missing values are encountered. This can be
            one of:

            - 'impute': Proceed with evaluation by filling in the missing
                        values with the mean of the training data. Missing
                        values are also imputed if an entire column of data is
                        missing during evaluation.
            - 'error' : Do not proceed with prediction and terminate with
                        an error message.
        output_type : str, optional
            Output type that may be needed by some of the toolkits; it is
            forwarded unchanged.
        options : dict, optional
            Additional options to be passed in to prediction. The dictionary
            is copied and never modified. ``None`` (the default) means no
            additional options.
        kwargs : dict
            Additional options to be passed into prediction. These override
            entries of ``options`` that have the same key.

        Returns
        -------
        out : SArray
            An SArray with model predictions.

        Raises
        ------
        Whatever ``_raise_error_if_not_sframe`` raises when ``dataset`` is not
        an SFrame.
        """

        _raise_error_if_not_sframe(dataset, "dataset")

        # Work on a private copy so the caller's dictionary is left untouched.
        options = {} if options is None else options.copy()
        options.update(kwargs)
        # The explicit arguments are applied last, so they take precedence
        # over same-named keys supplied through ``options`` or ``kwargs``.
        options.update({'model': self.__proxy__,
                        'model_name': self.__name__,
                        'dataset': dataset,
                        'missing_value_action' : missing_value_action,
                        'output_type' : output_type
                        })

        target = _graphlab.toolkits._main.run('supervised_learning_predict', options)
        return _map_unity_proxy_to_object(target['predicted'])


    def evaluate(self, dataset, metric = "auto", missing_value_action = 'error',
                                                      options = None, **kwargs):
        """
        Evaluate the model by making predictions of target values and comparing
        these to actual values.

        Parameters
        ----------
        dataset : SFrame
            Dataset in the same format used for training. The columns names and
            types of the dataset must be the same as that used in training.
        metric : str, optional
            Name of the evaluation metric; forwarded to the toolkit under the
            'metric' key. Defaults to "auto".
        missing_value_action: str, optional
            Action to perform when missing values are encountered. This can be
            one of:

            - 'impute': Proceed with evaluation by filling in the missing
                        values with the mean of the training data. Missing
                        values are also imputed if an entire column of data is
                        missing during evaluation.
            - 'error' : Do not proceed with prediction and terminate with
                        an error message.
        options : dict, optional
            Additional options to be passed in to evaluation. The dictionary
            is copied and never modified. ``None`` (the default) means no
            additional options.
        kwargs : dict
            Additional options to be passed into evaluation. These override
            entries of ``options`` that have the same key.

        Returns
        -------
        out : [various]
            The toolkit's evaluation results after conversion by
            ``_map_unity_proxy_to_object`` (presumably a dictionary keyed by
            metric name -- confirm against the toolkit).

        Raises
        ------
        Whatever ``_raise_error_if_not_sframe`` raises when ``dataset`` is not
        an SFrame.
        """
        _raise_error_if_not_sframe(dataset, "dataset")

        # Work on a private copy so the caller's dictionary is left untouched.
        options = {} if options is None else options.copy()
        options.update(kwargs)
        # The explicit arguments are applied last, so they take precedence
        # over same-named keys supplied through ``options`` or ``kwargs``.
        options.update({'model': self.__proxy__,
                        'dataset': dataset,
                        'model_name': self.__name__,
                        'missing_value_action' : missing_value_action,
                        'metric' : metric
                        })
        results = _graphlab.toolkits._main.run('supervised_learning_evaluate', options)
        return _map_unity_proxy_to_object(results)

    def _training_stats(self):
        """
        Return the statistics collected during model training.

        Returns
        -------
        out : [various]
            The response of the 'supervised_learning_get_train_stats' toolkit
            call after conversion by ``_map_unity_proxy_to_object``.
        """
        opts = {'model': self.__proxy__, 'model_name': self.__name__}
        results = _graphlab.toolkits._main.run("supervised_learning_get_train_stats", opts)
        return _map_unity_proxy_to_object(results)

    def list_fields(self):
        """
        List the fields stored in the model, including data, model, and
        training options. Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list
            Sorted list of fields queryable with the ``get`` method.

        Examples
        --------
        >>> fields =  m.list_fields()
        """
        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        response = _graphlab.toolkits._main.run('supervised_learning_list_keys', opts)
        return sorted(response.keys())

    def get(self, field):
        """
        Get the value of a given field.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out : [various]
            The current value of the requested field.

        See Also
        --------
        list_fields
        """
        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'field': field}
        response = _graphlab.toolkits._main.run('supervised_learning_get_value',
                                               opts)
        return _map_unity_proxy_to_object(response['value'])


class Classifier(SupervisedLearningModel):
    """
    Classifier module to predict a discrete target variable as a function of
    several feature variables.
    """

    def classify(self, dataset, missing_value_action = 'error'):
        """
        Classify the observations in ``dataset`` with the trained model.

        The request is dispatched to the 'supervised_learning_classify'
        toolkit function and the 'classify' entry of its response is returned.

        Parameters
        ----------
        dataset : SFrame
            Dataset of new observations. Must include columns with the same
            names as the features used for model training, but does not require
            a target column. Additional columns are ignored.
        missing_value_action: str, optional
            Action to perform when missing values are encountered. This can be
            one of:

            - 'impute': Proceed with evaluation by filling in the missing
                        values with the mean of the training data. Missing
                        values are also imputed if an entire column of data is
                        missing during evaluation.
            - 'error' : Do not proceed with prediction and terminate with
                        an error message.

        Returns
        -------
        out : SFrame
            An SFrame with model predictions.
        """
        _raise_error_if_not_sframe(dataset, "dataset")

        # Everything the toolkit needs, assembled in a single literal.
        request = {
            'model': self.__proxy__,
            'model_name': self.__name__,
            'dataset': dataset,
            'missing_value_action': missing_value_action,
        }
        response = _graphlab.toolkits._main.run(
            'supervised_learning_classify', request)

        classified = response['classify']
        return _map_unity_proxy_to_object(classified)



def create(dataset, target, model_name, features=None,
           validation_set = None, verbose = True, **kwargs):
    """
    Create a :class:`~graphlab.toolkits.SupervisedLearningModel`.

    This is a generic function that allows you to create any model that
    implements SupervisedLearningModel. It is normally not called directly;
    call the specific model's create function instead.

    Parameters
    ----------
    dataset : SFrame
        Dataset for training the model.

    target : string
        Name of the column containing the target variable.

    model_name : string
        Name of the model; forwarded to the trainer as 'model_name'.

    features : list[string], optional
        Names of the columns used as features. When omitted, every column of
        ``dataset`` except ``target`` is used.

    validation_set : SFrame, optional
        Held-out data. When given, its feature and target columns are
        forwarded to the trainer as 'features_validation' and
        'target_validation'.

    verbose : boolean
        Whether to print out messages during training.

    kwargs : dict
        Additional parameter options that can be passed. Keys are lower-cased
        before being forwarded; the 'target', 'features' and 'model_name'
        entries built by this function override same-named keys.

    Returns
    -------
    out : SupervisedLearningModel
        Wrapper around the trained model returned by the toolkit.

    Raises
    ------
    TypeError
        If ``features`` is not iterable or contains a name that is not a
        ``str``.
    """

    _raise_error_if_not_sframe(dataset, "training dataset")

    # Target
    target_sframe = _toolkits_select_columns(dataset, [target])

    # Features
    if features is None:
        # Build a fresh list rather than calling ``remove`` on the list
        # returned by ``column_names()``: this tolerates a missing target and
        # never mutates a list owned by the SFrame.
        features = [col for col in dataset.column_names() if col != target]
    if not hasattr(features, '__iter__'):
        raise TypeError("Input 'features' must be a list.")
    for feature in features:
        if not isinstance(feature, str):
            # Report the offending entry itself. The 1-tuple keeps the
            # formatting correct even when the entry is itself a tuple.
            raise TypeError(
                "Invalid feature %s: Feature names must be of type str"
                % (feature,))
    features_sframe = _toolkits_select_columns(dataset, features)

    # User-supplied options are case-insensitive; the entries assembled here
    # are applied afterwards so they always win.
    options = dict((key.lower(), value) for key, value in kwargs.items())
    options.update({'target': target_sframe,
                    'features': features_sframe,
                    'model_name': model_name})

    if validation_set is not None:
        options.update({
            'features_validation' : _toolkits_select_columns(validation_set, features),
            'target_validation' : _toolkits_select_columns(validation_set, [target])})

    ret = _graphlab.toolkits._main.run("supervised_learning_train", options, verbose=verbose)
    return SupervisedLearningModel(ret['model'], model_name)

def create_with_model_selector(dataset, target, model_selector,
    features = None, verbose = True):
    """
    Create a supervised learning model whose type is chosen by
    ``model_selector``.

    The selector is shown (a sample of) the feature columns and returns the
    name of the model to train. The trained model is returned wrapped in the
    class that matches that name.

    Parameters
    ----------
    dataset : SFrame
        Dataset for training the model.

    target : string
        Name of the column containing the target variable.

    model_selector: function
        Callable that receives an SFrame holding the feature columns (sampled
        down when there are more than 1e5 rows) and returns the name of the
        model to train.

    features : list[string], optional
        Names of the columns used as features. When omitted, every column of
        ``dataset`` except ``target`` is used.

    verbose : boolean
        Whether to print out messages during training.

    Returns
    -------
    out : model
        The trained model, as the model class that corresponds to the
        selected model name.

    Raises
    ------
    TypeError
        If ``features`` is not iterable or contains a name that is not a
        ``str``.
    ToolkitError
        If the selector returns a model name this function does not know how
        to wrap. Note that the model has already been trained at that point.
    """

    # Error checking
    _raise_error_if_not_sframe(dataset, "training dataset")
    if features is None:
        features = dataset.column_names()
        if target in features:
            features.remove(target)
    if not hasattr(features, '__iter__'):
        raise TypeError("Input 'features' must be a list.")
    for feature in features:
        if not isinstance(feature, str):
            # Report the offending entry itself. The 1-tuple keeps the
            # formatting correct even when the entry is itself a tuple.
            raise TypeError(
                "Invalid feature %s: Feature names must be of type str"
                % (feature,))

    # Sample the data: the selector sees roughly 1e5 rows at most. The fixed
    # seed keeps the sample reproducible.
    features_sframe = _toolkits_select_columns(dataset, features)
    if features_sframe.num_rows() > 1e5:
        fraction = 1.0 * 1e5 / features_sframe.num_rows()
        features_sframe = features_sframe.sample(fraction, seed = 0)

    # Run the model selector.
    selected_model_name = model_selector(features_sframe)

    # The neural net classifier has its own training entry point.
    if selected_model_name == 'neuralnet_classifier':
        return _graphlab.classifier.neuralnet_classifier.create(
            dataset, target, features = features, verbose = verbose)

    # Create the model
    model = create(dataset,
                   target,
                   selected_model_name,
                   features = features,
                   verbose = verbose)

    # Return the model, re-wrapped in the class matching the selected name.
    if selected_model_name == 'boosted_trees_regression':
        return _graphlab.boosted_trees_regression.BoostedTreesRegression(
            model.__proxy__)
    elif selected_model_name == 'regression_linear_regression':
        return _graphlab.linear_regression.LinearRegression(
            model.__proxy__)
    elif selected_model_name == 'boosted_trees_classifier':
        return _graphlab.boosted_trees_classifier.BoostedTreesClassifier(
            model.__proxy__)
    elif selected_model_name == 'classifier_logistic_regression':
        return _graphlab.logistic_classifier.LogisticClassifier(
            model.__proxy__)
    elif selected_model_name == 'classifier_svm':
        return _graphlab.svm_classifier.SVMClassifier(model.__proxy__)

    # ToolkitError is not imported by this module, so it is reached through
    # the already-imported package.
    # NOTE(review): assumes ToolkitError is defined in
    # graphlab.toolkits._main -- confirm.
    raise _graphlab.toolkits._main.ToolkitError(
        "Internal error: Incorrect model returned.")
