"""
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SArray as _SArray
from graphlab.data_structures.sframe import SFrame as _SFrame
import pandas as _pd
import json as _json
from graphlab import vowpal_wabbit as _vw


class RegressionModel(Model):
    """
    The abstract class for GraphLab regression models. This class defines
    methods common to all regression models, but leaves unique model details to
    separate model classes.
    """

    def __str__(self):
        """
        Return the name of this model's class.

        Returns
        -------
        out : string
            The class name of the model instance.
        """
        model_class = self.__class__
        return model_class.__name__

    def __repr__(self):
        """
        Return a string description of the model.

        This implementation returns an empty string.
        NOTE(review): concrete model classes presumably override this to
        describe the training data schema, training statistics and model
        hyperparameters -- confirm against the subclasses.

        Returns
        -------
        out : string
            A description of the model (empty in this implementation).
        """
        description = ''
        return description


    def list_fields(self):
        """
        List the names of the fields that can be queried on this model.

        Models whose 'solver' field is 'vw' forward the request to the
        Vowpal Wabbit model wrapper. All other models ask the
        'regression_list_keys' toolkit function and return the keys of its
        response.

        Returns
        -------
        out : list
            A list of fields that can be queried using the ``get`` method.
        """
        if self.get('solver') == 'vw':
            vw_model = _vw.VWModel(self.__proxy__)
            return vw_model.list_fields()

        request = dict(model=self.__proxy__, model_name=self.__name__)
        field_map = _graphlab.toolkits.main.run('regression_list_keys',
                                                request)
        return field_map.keys()

    def get(self, field):
        """
        Look up the value of a single field associated with the model.

        The names accepted here are the ones returned by ``list_fields``.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out :
            The current value of the requested field. The 'coefficients'
            field is returned wrapped in an SFrame; every other field is
            returned exactly as the toolkit (or the Vowpal Wabbit wrapper)
            provides it.

        Raises
        ------
        ValueError
            If 'coefficients' is requested from a model trained with
            Vowpal Wabbit.
        """
        # The solver cannot be checked through self.get('solver') here, since
        # that would call this very method again.
        # NOTE(review): ``use_vw`` is not assigned anywhere in this class;
        # presumably it is set elsewhere -- confirm.
        if self.use_vw:
            if field == 'coefficients':
                raise ValueError("Models trained with Vowpal Wabbit do not "
                                 "contain coefficient values. Please use a "
                                 "different solver to obtain the coefficients.")
            return _vw.VWModel(self.__proxy__).get(field)

        request = dict(model=self.__proxy__,
                       model_name=self.__name__,
                       field=field)
        reply = _graphlab.toolkits.main.run('regression_get_value', request)
        value = reply['value']
        if field != 'coefficients':
            return value

        # The coefficients value is handed to SFrame as its proxy.
        return _SFrame(None, _proxy=value)

    def summary(self):
        """
        Print a summary of the model to standard output.

        For models whose 'solver' field is 'vw' the call is forwarded to the
        ``summary`` method of the Vowpal Wabbit model wrapper. For all other
        models the model description (``__repr__``) is printed, followed by
        two tables built from the 'coefficients' field: the ``topk`` rows
        (k=5) of the 'Coefficient' column that are strictly positive, and
        the reversed ``topk`` rows (k=5) that are strictly negative.

        Returns
        -------
        None
            The summary is printed; nothing is returned.
        """
        if self.get('solver') == 'vw':
            _vw.VWModel(self.__proxy__).summary()
            return

        coefs = self.get('coefficients')

        # Keep only strictly positive values among the top-k rows, so the
        # "positive" table never shows zero or negative coefficients.
        top_coefs = coefs.topk('Coefficient', k=5)
        top_coefs = top_coefs[top_coefs['Coefficient'] > 0]

        # Same idea for the other end of the ranking: only strictly negative
        # values are kept.
        bottom_coefs = coefs.topk('Coefficient', k=5, reverse=True)
        bottom_coefs = bottom_coefs[bottom_coefs['Coefficient'] < 0]

        # Every print below takes exactly one parenthesised argument, which
        # produces the same output as a Python 2 print statement and is also
        # valid with the Python 3 print function.
        print("")
        print("                    Model summary                       ")
        print("--------------------------------------------------------")
        print(self.__repr__())

        print("             Strongest positive coefficients            ")
        print("--------------------------------------------------------")
        if len(top_coefs) > 0:
            print(_SFrame(top_coefs))
        else:
            print("[No positive coefficients]")

        print("             Strongest negative coefficients            ")
        print("--------------------------------------------------------")
        if len(bottom_coefs) > 0:
            print(_SFrame(bottom_coefs))
        else:
            print("[No negative coefficients]")

    def get_default_options(self):
        """
        Return the default value of each option available for the model.

        Models whose 'solver' field is 'vw' report a fixed pair of defaults
        ('max_iterations' and 'step_size'). For all other models the result
        of the 'regression_get_default_options' toolkit function is returned
        unchanged.

        Returns
        -------
        out : dict
             A dictionary with key-value pairs containing the default options.

        """
        solver = self.get('solver')
        if solver == 'vw':
            vw_defaults = {'max_iterations': 10, 'step_size': 1.0}
            return vw_defaults

        request = dict(model=self.__proxy__, model_name=self.__name__)
        return _graphlab.toolkits.main.run(
            'regression_get_default_options', request)

    def get_current_options(self):
        """
        Return the options of this model together with their current values.

        Models whose 'solver' field is 'vw' report 'max_iterations' and
        'step_size' as read from the Vowpal Wabbit model wrapper. For all
        other models the result of the 'regression_get_current_options'
        toolkit function is returned unchanged.

        Returns
        -------
        out : dict
             A dictionary with key-value pairs containing the current options.
        """
        if self.get('solver') == 'vw':
            # A separate wrapper is built for each option lookup.
            return {option: _vw.VWModel(self.__proxy__).get(option)
                    for option in ('max_iterations', 'step_size')}

        request = dict(model=self.__proxy__, model_name=self.__name__)
        return _graphlab.toolkits.main.run(
            'regression_get_current_options', request)

    def predict(self, dataset):
        """
        Return a score prediction for the provided data set.

        Parameters
        ----------
        dataset : SFrame or pandas.DataFrame
            Dataset in the SAME format used for training, except that it need
            not contain a response column. A pandas DataFrame (or any
            subclass of it) is converted to an SFrame before use.

        Returns
        -------
        out : SArray
            An SArray with model predictions. For models whose 'solver'
            field is 'vw', the result of the Vowpal Wabbit wrapper's
            ``predict`` is returned as-is.
        """
        # isinstance (rather than an exact type comparison) so that
        # DataFrame subclasses are converted as well.
        if isinstance(dataset, _pd.DataFrame):
            dataset = _SFrame(dataset)

        # special handling of vw model
        if self.get('solver') == 'vw':
            m = _vw.VWModel(self.__proxy__)
            predictions = m.predict(dataset)
            return predictions

        # Get metadata from C++ object
        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'test_data': dataset,
                'output_type': ""}

        # The init call's response is merged into the options that are
        # passed on to the actual prediction call.
        response = _graphlab.toolkits.main.run('regression_predict_init', opts)
        opts.update(response)
        response = _graphlab.toolkits.main.run('regression_predict', opts)
        return _SArray(None, _proxy=response['predicted'])


    def evaluate(self, dataset):
        r"""
        Evaluate the model on a dataset that contains the response.

        Evaluates the model by first making predictions and then comparing
        these predictions to the ground truth. The default metric used for
        model comparison is root-mean-squared error (RMSE). Let :math:`y` and
        :math:`\hat{y}` denote vectors of length :math:`N` (number of examples)
        with observed/actual and predicted measurements. The RMSE is defined
        as:

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2}

        Parameters
        ----------
        dataset : SFrame or pandas.DataFrame
            Dataset in the same format used for training. The columns names and
            types of the dataset must be the same as that used in training.
            A pandas DataFrame (or any subclass of it) is converted to an
            SFrame before use, matching what ``predict`` accepts.

        Returns
        -------
        out : dict
            Results from model evaluation procedure. For models whose
            'solver' field is 'vw', the result of the Vowpal Wabbit
            wrapper's ``evaluate`` is returned as-is.

        """
        # Accept pandas input the same way predict() does.
        if isinstance(dataset, _pd.DataFrame):
            dataset = _SFrame(dataset)

        if self.get('solver') == 'vw':
            m = _vw.VWModel(self.__proxy__)
            evaluation_results = m.evaluate(dataset)
            return evaluation_results

        opts = {'model': self.__proxy__,
                'dataset': dataset,
                'model_name': self.__name__}

        # Each stage's response is merged into the options handed to the
        # next stage: init -> evaluate -> collect statistics.
        response = _graphlab.toolkits.main.run('regression_evaluate_init', opts)
        opts.update(response)
        response = _graphlab.toolkits.main.run("regression_evaluate", opts)
        opts.update(response)
        return _graphlab.toolkits.main.run("regression_get_evaluate_stats", opts)


    def training_stats(self):
        """
        Return information gathered while the model was being created.

        Models whose 'solver' field is 'vw' forward the request to the
        Vowpal Wabbit model wrapper. All other models return the result of
        the 'regression_get_train_stats' toolkit function.

        Returns
        -------
        out : dict
            Statistics about model training, e.g. runtime.
        """
        request = dict(model=self.__proxy__, model_name=self.__name__)

        if self.get('solver') == 'vw':
            return _vw.VWModel(self.__proxy__).training_stats()

        return _graphlab.toolkits.main.run("regression_get_train_stats",
                                           request)
