"""
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SArray as _SArray
import pandas as _pd
import json as _json
from graphlab import vowpal_wabbit as _vw

# TODO: List of TODOs for this file
#------------------------------------------------------------------------------
# 1. Repr implementation for regression

class RegressionModel(Model):
    """
    The abstract class for GraphLab regression models. This class defines
    methods common to all regression models, but leaves unique model details to
    separate model classes.

    Every method forwards its request to the toolkit backend through
    ``graphlab.toolkits.main.run``, identifying the model by ``__proxy__`` and
    ``__name__``. Those attributes are not set in this class; presumably the
    ``Model`` base class or the code that creates the model provides them.
    """

    def __str__(self):
        """
        Returns the type of model.

        Returns
        -------
        out : string
            The name of the model's class.
        """
        return self.__class__.__name__

    def __repr__(self):
        """
        Returns a string description of the model.

        For models trained with the ``vw`` solver the description is whatever
        the model stores under its ``__repr__`` field. For all other solvers
        the description lists the size of the training data, the solver and
        its status, and the training statistics.

        Returns
        -------
        out : string
            A description of the model.
        """
        # Fetched once and reused below; every ``get`` is a backend call.
        solver = self.get('solver')

        ret = []
        ret.append("Class %s" % self.__class__.__name__)
        ret.append("-----------------------------------------------------------")

        # VW models supply their own pre-formatted description.
        if solver == 'vw':
            ret.append(self.get('__repr__'))
            return '\n'.join(ret)

        ret.append("# Examples : %s" % self.get('num_examples'))
        ret.append("# Features : %s" % self.get('num_features'))
        ret.append("")

        ret.append("Solver         : %s" % solver)
        ret.append("Solver Status  : %s" % self.get('solver_status'))
        ret.append("")

        ret.append("Train time     : %s" % self.get('train_time'))
        ret.append("Training RMSE  : %s" % self.get('training_rmse'))
        ret.append("Gradient Norm  : %s" % self.get('norm_grad'))
        ret.append("Loss           : %s" % self.get('func_value'))
        ret.append("Iterations     : %s" % self.get('iters'))
        ret.append("Function evals : %s" % self.get('num_function_evaluations'))
        ret.append("Gradient evals : %s" % self.get('num_gradient_evaluations'))

        return '\n'.join(ret)

    def list_fields(self):
        """
        List the fields stored in the model, e.g. the size of the data, the
        solver chosen and the solver options chosen.

        Returns
        -------
        out : list
            The names of the fields that can be queried using the ``get``
            method.
        """
        _mt._get_metric_tracker().track('toolkit.regression.list_fields')
        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        response = _graphlab.toolkits.main.run('regression_list_keys', opts)
        return response.keys()

    def get(self, field):
        """
        Get the value of a particular field associated with the model. The list
        of fields that can be queried using the get method can be obtained
        using the list_fields method.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out :
            The current value of the requested field.
        """
        _mt._get_metric_tracker().track('toolkit.regression.get')
        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'field': field}
        response = _graphlab.toolkits.main.run('regression_get_value', opts)
        return response['value']

    def summary(self):
        """
        Returns a summary of the model.

        Returns
        -------
        out : string
            The same description that ``__repr__`` produces.
        """
        _mt._get_metric_tracker().track('toolkit.regression.summary')
        return self.__repr__()

    def _get_default_options(self):
        """
        Query the backend for the default options of the model.

        Returns
        -------
        out :
            The backend response, returned unchanged (presumably a dictionary
            with key-value pairs containing the default options).
        """
        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        return _graphlab.toolkits.main.run('regression_get_default_options',
                                           opts)

    def _get_current_options(self):
        """
        Query the backend for the options of the model and their current
        values.

        Returns
        -------
        out :
            The backend response, returned unchanged (presumably a dictionary
            with key-value pairs containing the current options).
        """
        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        return _graphlab.toolkits.main.run('regression_get_current_options',
                                           opts)

    def predict(self, dataset):
        """
        Return predictions for the provided dataset.

        Parameters
        ----------
        dataset : SFrame or pandas.DataFrame
            Dataset in the SAME format used for training. A pandas DataFrame
            is converted to an SFrame before it is sent to the backend.

        Returns
        -------
        out : SArray
            The model predictions. For models trained with the ``vw`` solver
            the result is whatever ``VWModel.predict`` returns.
        """
        _mt._get_metric_tracker().track('toolkit.regression.predict')

        # isinstance (rather than an exact type comparison) so that DataFrame
        # subclasses are converted as well.
        if isinstance(dataset, _pd.DataFrame):
            dataset = _graphlab.SFrame(dataset)

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'test_data': dataset}

        # Whatever the init step returns is merged into the options that are
        # passed on to the prediction step.
        init_response = _graphlab.toolkits.main.run('regression_predict_init',
                                                    opts)
        opts.update(init_response)

        # VW or Graphlab
        if self.get('solver') == 'vw':
            # Reconstruct the VW object from the fields stored in the model
            # and let it do the prediction.
            vw_model_proxy = {'response': self.get("response"),
                              'target': self.get("target"),
                              'command_line_args': '',
                              'model_file': self.get("model_file")}
            vw_model = _vw.VWModel(vw_model_proxy)
            return vw_model.predict(dataset)

        predict_response = _graphlab.toolkits.main.run('regression_predict',
                                                       opts)
        return _SArray(None, _proxy=predict_response['predicted'])

    def evaluate(self, observed, predicted, verbose=True):
        r"""
        Evaluate predictions against the observed values.

        The default metric used for model comparison is root-mean-squared error
        (RMSE).  Let :math:`y` and :math:`\hat{y}` denote vectors of length
        :math:`N` with observed/actual and predicted measurements. The RMSE is
        defined as:

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2}

        Parameters
        ----------
        observed : SArray
            An SArray with the observed values.

        predicted : SArray
            An SArray with the predicted values.

        verbose : bool, optional
            Accepted for API compatibility; this method does not currently
            forward it to the backend. (Default is verbose).

        Returns
        -------
        out : dict
            Results from the model evaluation procedure.
        """
        _mt._get_metric_tracker().track('toolkit.regression.evaluate')
        opts = {'model': self.__proxy__,
                'observed': observed,
                'predicted': predicted,
                'model_name': self.__name__}

        # Whatever the init step returns is merged into the options that are
        # passed on to the evaluation step.
        init_response = _graphlab.toolkits.main.run('regression_evaluate_init',
                                                    opts)
        opts.update(init_response)
        return _graphlab.toolkits.main.run("regression_evaluate", opts)

    def training_stats(self):
        """
        Get information about model creation, e.g. time elapsed during model
        fitting, data loading, and more.

        Returns
        -------
        out : dict
            Statistics about model training. For models trained with the
            ``vw`` solver this holds the keys ``train_rmse`` and
            ``train_accuracy``; otherwise it is the backend response.
        """
        _mt._get_metric_tracker().track('toolkit.regression.training_stats')

        # VW models expose their statistics as ordinary model fields.
        if self.get('solver') == 'vw':
            return {'train_rmse': self.get('rmse'),
                    'train_accuracy': self.get('accuracy')}

        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        return _graphlab.toolkits.main.run("regression_get_train_stats", opts)
