"""
Methods for creating and using a supervised learning model.
"""
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SArray as _SArray
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS


class SupervisedLearningModel(Model):
    """
    Supervised learning module to predict a target variable as a function of
    several feature variables.

    This base class forwards most queries to the toolkit backend through
    ``graphlab.toolkits.main.run``. The methods ``__repr__``, ``summary`` and
    ``evaluate`` raise ``NotImplementedError`` here and are expected to be
    provided by concrete model classes.
    """

    def _get_base_opts(self):
        """
        Build the options shared by every toolkit call made by this class.

        Returns
        -------
        out : dict
            A new dictionary with the keys ``'model'`` (the model proxy) and
            ``'model_name'`` (the model name). A fresh dictionary is returned
            on every call so callers can add keys without affecting others.
        """
        return {'model': self.__proxy__,
                'model_name': self.__name__}

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            The name of the model's class.
        """
        return self.__class__.__name__

    def __repr__(self):
        """
        Returns a string description of the model, including (where relevant)
        the schema of the training data, description of the training data,
        training statistics, and model hyperparameters.

        Returns
        -------
        out : string
            A description of the model.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def summary(self):
        """
        Display a summary of the model, training options, and training
        statistics.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def get_default_options(self):
        """
        Return a dictionary with the model's default options.

        Returns
        -------
        out : dict
            Dictionary with model default options.
        """
        return _graphlab.toolkits.main.run(
            'supervised_learning_get_default_options', self._get_base_opts())

    def get_current_options(self):
        """
        Return a dictionary with the options used to define and train the model.

        Returns
        -------
        out : dict
            Dictionary with options used to define and train the model.
        """
        return _graphlab.toolkits.main.run(
            'supervised_learning_get_current_options', self._get_base_opts())

    def predict(self, dataset):
        """
        Return predictions for ``dataset``, using the trained model.

        Parameters
        ----------
        dataset : SFrame or pandas.DataFrame
            Dataset of new observations. Must include columns with the same
            names as the features used for model training, but does not require
            a target column. Additional columns are ignored. A pandas
            DataFrame (or an instance of a DataFrame subclass) is converted to
            an SFrame first, when pandas is available.

        Returns
        -------
        out : SArray
            An SArray with model predictions, wrapping the ``'predicted'``
            entry of the toolkit response.
        """
        # isinstance (rather than an exact type comparison) so that instances
        # of DataFrame subclasses are converted as well.
        if _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame):
            dataset = _SFrame(dataset)

        opts = self._get_base_opts()
        opts['dataset'] = dataset

        # The init call returns extra options that must be merged into the
        # request before the actual prediction call is made.
        init_opts = _graphlab.toolkits.main.run(
            'supervised_learning_predict_init', opts)
        opts.update(init_opts)
        target = _graphlab.toolkits.main.run(
            'supervised_learning_predict', opts)
        return _SArray(None, _proxy=target['predicted'])

    def evaluate(self, dataset):
        """
        Evaluate the model by making predictions of target values and comparing
        these to actual values.

        Parameters
        ----------
        dataset : SFrame
            Dataset to evaluate the model on.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def training_stats(self):
        """
        Return a dictionary containing statistics collected during model
        training. These statistics are also available with the ``get`` method,
        and are described in more detail in the documentation for that method.

        Returns
        -------
        out : dict
            Dictionary of statistics collected during model training.
        """
        return _graphlab.toolkits.main.run(
            "supervised_learning_get_train_stats", self._get_base_opts())

    def list_fields(self):
        """
        List the fields stored in the model, including data, model, and
        training options. Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list
            Sorted list of fields queryable with the ``get`` method.
        """
        response = _graphlab.toolkits.main.run(
            'supervised_learning_list_keys', self._get_base_opts())
        return sorted(response.keys())
