"""
"""

import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame as _SFrame
import pandas as _pd
import json as _json


class RecommenderModel(Model):
    """
    The abstract class for GraphLab recommender system models. This class
    defines methods common to all recommender system models, but leaves unique
    model details to separate model classes.
    """

    def list_fields(self):
        """
        Get the current settings of the model. The keys depend on the type of
        model.

        Returns
        -------
        out : dict
            A dictionary that maps model options to the values used during
            training.
        """
        return self._get_default_options().keys()

    def get(self, field):
        """
        Get the value of a particular field.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out : string
            The current value of the requested field.
        """
        opts = {'model': self.__proxy__, 'field': field}
        response = _graphlab.toolkits.main.run('recsys_get_option_value', opts)
        return response['value']

    def summary(self):
        """Returns a summary of training statistics."""
        return self.training_stats()

    def __str__(self):
        """
        Returns the type of model.

        Returns
        -------
        out : string
            The type of model.
        """
        return self.__class__.__name__

    def __repr__(self):
        """
        Returns a string description of the model, including (where relevant)
        the schema of the training data, description of the training data,
        training statistics, and model hyperparameters.

        Returns
        -------
        out : string
            A description of the model.
        """

        stats = self.training_stats()
        options = self._get_current_options()

        # Print model type and data schema
        ret = self.__class__.__name__ + '\n'

        # If the model is untrained, there will not be an elapsed_time
        is_trained = 'elapsed_time' in stats.keys()

        # TODO: Refactor into functions
        if is_trained:

            ret += "\nSchema\n"
            for k in ["user_column", "item_column", "target_column"]:
                if k in options.keys():
                    if options[k] not in ['', 'not_specified']:
                        ret += "  %-14s %s\n" % \
                            (k.replace('_', ' ') + ':', options[k])
                    del options[k]

            # Statistics about the training and validation set
            ks = ["num_obs_train",
                  "num_unique_users_train",
                  "num_unique_items_train",
                  "num_obs_validate",
                  "num_unique_users_validate",
                  "num_unique_items_validate"]
            if all([k in stats.keys() for k in ks]):
                ret += "\nStatistics\n"
                ret += "  %-17s %11d obs %11d users %11d items\n" % \
                    ("Training set:", stats[ks[0]], stats[ks[1]], stats[ks[2]])

                # Only show validation description if there are observations
                # present
                if stats["num_obs_validate"] > 0:
                    ret += "  %-17s %11d obs %11d users %11d items\n" % \
                        ("Validation set:", stats[ks[3]],
                         stats[ks[4]], stats[ks[5]])

            # Training summary
            ks = ["elapsed_time", "random_seed", "holdout_probability",
                  "data_load_elapsed_time"]
            if all([k in stats.keys()]):
                ret += "\nTraining summary\n"
                ret += "  %-20s %fs\n" % ("training time:", stats["elapsed_time"])
                ret += "  %-20s %fs\n" % ("data load time:", stats["data_load_elapsed_time"])

                if "validation_metrics_elapsed_time" in stats.keys():
                    ret += "  %-20s %fs\n" % \
                           ("metrics time:", stats["validation_metrics_elapsed_time"])

                # Print holdout probability only if nonzero
                p = options['holdout_probability']
                if p > 0:
                    ret += "  heldout %.3f of data for validation\n" % p

                # Print random seed only if nonzero
                if options['random_seed'] != 0:
                    ret += "  random seed: %20d \n" % options['random_seed']

            # If available, print performance statistics
            for k in stats.keys():
                if 'rmse' in k or 'precision' in k or 'recall' in k:
                    ret += '  %-20s %-10s\n' % (k + ':', stats[k])

        else:
            ret += '\nThis model has yet to be trained.\n'

        # Remove any options that should not be shown under "Settings"
        to_ignore = ['random_seed', 'holdout_prob', 'user_column',
                     'item_column', 'target_column']
        for k in to_ignore:
            if k in options:
                del options[k]

        # Print remaining hyperparameters
        # TODO: Get max width of keys and use it for fixed width formatting.
        if len(options) > 0:
            ret += "\nSettings\n"
            for k, v in options.iteritems():
                ret += "  %-22s %-30s\n" % (k.replace('_', ' ') + ':', str(v))

        return ret

    def _get_default_options(self):
        """
        A dictionary describing all the parameters of the given model.
        For each parameter there may be:
          name
          description
          type (REAL, CATEGORICAL)
          default_value
          possible_values: for reals this includes a lower and upper bound.
                           for categoricals this is a list of possible values.
        """

        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run('recsys_get_default_options', opts)
        for k in response.keys():
            response[k] = _json.loads(response[k])
        return response

    def _get_current_options(self):
        """
        A dictionary describing all the parameters of the given model
        and their current setting.
        """
        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run('recsys_get_current_options', opts)
        return response

    def _set_current_options(self, options):
        """
        Set current options for a model.

        Parameters
        ----------
        options : dict
            A dictionary of the desired option settings. The key should be the name
            of the option and each value is the desired value of the option.
            The possible options are all those returne dy get_default_options().
        """
        opts = self._get_current_options()
        opts.update(options)
        opts['model'] = self.__proxy__
        response = _graphlab.toolkits.main.run('recsys_set_current_options', opts)
        return response

    def predict(self, dataset):
        """
        Return a score prediction for the user ids and item ids in the provided
        data set.

        Parameters
        ----------
        dataset : SFrame
            Dataset in the same form used for training.

        Returns
        -------
        out : DataFrame/SFrame:
            A DataFrame with model predictions, with columns user_id, item_id, and
            score
        """

        if type(dataset) == _pd.DataFrame:
            dataset = _SFrame(dataset)

        # Get metadata from C++ object
        opts = {'data_to_predict': dataset, 'model': self.__proxy__}

        # Call the C++ function for recommender_model
        response = _graphlab.toolkits.main.run('recsys_predict', opts)
        result = _SFrame(None, _proxy=response['data'])
        return result

    def recommend(self, dataset, k=30):
        """
        Return the k best recommend items for each user, excluding items for
        which the user already has an entry. Typically this is used after the
        model has been trained. For a RecommenderModel object m, we obtain the
        highest ranked items for each user in our training set by running:

        >>> m.train(data, 'user', 'item')
        >>> recs = m.recommend(data)

        The returned object has four columns: user, item, score, and rank (0 is
        the highest). The observations in **dataset** are automatically excluded
        from the returned recommendations. In some situations we may have new
        observations that we want to be able to use to make recommendations, and
        the recommend function accepts a *new* dataset to acoommodate this.

        >>> recs = m.recommend(new_data)

        Parameters
        ----------
        dataset : SFrame/pandas.DataFrame
            A dataset of observed data for which recommendations are desired.
            The k highest ranked items are returned for each user in this
            dataset, and all observations in this data set will be excluded from
            the recommendations.

        k : int, optional
            The number of items to return for each user (default is 30).

        Returns
        -------
        out : SFrame
            A SFrame with the top ranked items for each user. The
            columns are: *user*, *item*, *score*, and *rank*.
        """

        if type(dataset) == _pd.DataFrame:
            dataset = _SFrame(dataset)
        elif type(dataset) != _SFrame:
            raise NotImplemented

        opt = {}
        opt["model"] = self.__proxy__
        opt["top_k"] = k
        opt['available_data'] = dataset
        response = _graphlab.toolkits.main.run('recsys_recommend', opt)
        recs = _SFrame(None, _proxy=response['data'])
        return recs

    def training_stats(self):
        """
        Get information about model creation, e.g. time elapsed during
        model fitting, data loading, and more.

        Returns
        -------
        out : dict
            Statistics about model training, e.g. runtime.
        """
        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run("recsys_get_stats", opts)
        return response

    def get_validation_split(self, dataset):
        """
        Retrieve the subset of the training dataset that was used internally for
        validation testing.

        Parameters
        ----------
        dataset : DataFrame/SFrame
            A dataset in the same format as the one used during training.

        Returns
        -------
        out : dict
            A dictionary with two items. The *training_data* item contains the
            data used for training, and the *validation_data* item contains the
            data used for internal validation.
        """

        if type(dataset) == _pd.DataFrame:
            dataset = _SFrame(dataset)
        elif type(dataset) != _SFrame:
            raise NotImplementedError

        opts = {'model': self.__proxy__, 'data': dataset}
        response = _graphlab.toolkits.main.run('recsys_get_validation_set', opts)
        ret = {'training_data': _SFrame(None, _proxy=response['training_data']),
               'validation_data': _SFrame(None, _proxy=response['validation_data'])}
        return ret

    def _evaluate_precision_recall(self, dataset, cutoffs=[5, 10, 15],
                                   skip_set=None):
        """
        Compute a model's precision and recall for a particular dataset.

        Parameters
        ----------
        dataset : pandas.DataFrame/SFrame
            A DataFrame in the same format as the one used during training.

        cutoffs : list
            A list of cutoff values for which one wants to evaluate precision
            and recall, i.e. the value of k in precision@k.

        skip_set : SFrame/DataFrame
            A dataset in the same format as the one used for training. Each (user,
            item) pair in the skip_set will be excluded from the set of
            recommendations.

        Returns
        -------
        out : dict
            A dictionary containing two items, both SFrames. Item *ranked_items*
            contains the recommendations for each user, and item *results* contains
            the precision and recall at each cutoff value.
        """

        if type(dataset) == _pd.DataFrame:
            dataset = _SFrame(dataset)
        elif type(dataset) != _SFrame:
            raise NotImplementedError
        if skip_set is None:
            skip_set = _SFrame(_pd.DataFrame())
        if type(skip_set) == _pd.DataFrame:
            skip_set = _SFrame(skip_set)
        elif type(skip_set) != _SFrame:
            raise NotImplementedError

        opts = {}
        opts['model'] = self.__proxy__
        opts['validation_data'] = dataset
        opts['available_data'] = skip_set
        opts['cutoffs'] = _pd.DataFrame({'cutoff': cutoffs})
        opts['top_k'] = max(cutoffs)
        response = _graphlab.toolkits.main.run("recsys_precision_recall", opts)
        res = {'ranked_items': _SFrame(None, _proxy=response['ranked_items']),
               'results': _SFrame(None, _proxy=response['results'])}
        return res

    def _evaluate_rmse(self, dataset):
        """
        Evaluate the prediction error for each user-item pair in the given data
        set.

        Parameters
        ----------
        dataset : SFrame/pandas.DataFrame
            A DataFrame in the same format as the one used during training.

        Returns
        -------
        out : dict
            A dictionary with three items: *rmse_by_user* and *rmse_by_item*, which
            are SFrames; and *rmse_overall*, which is a float.
        """

        if type(dataset) == _pd.DataFrame:
            dataset = _SFrame(dataset)

        opts = {}
        opts['model'] = self.__proxy__
        opts['dataset'] = dataset
        response = _graphlab.toolkits.main.run('recsys_rmse', opts)
        rmse_by_user = _SFrame(None, _proxy=response['rmse_by_user'])
        rmse_by_item = _SFrame(None, _proxy=response['rmse_by_item'])
        rmse_overall = response['rmse_overall']

        return {'rmse_by_user': rmse_by_user,
                'rmse_by_item': rmse_by_item,
                'rmse_overall': rmse_overall}

    def evaluate(self, dataset, verbose=True, **kwargs):
        r"""Evaluate the model's ability to make predictions for another
        dataset. If the model is trained to predict a particular target, the
        default metric used for model comparison is root-mean-squared error
        (RMSE). Suppose :math:`y` and :math:`\widehat{y}` are vectors of length
        :math:`N` with actual and predicted ratings. Then the RMSE is

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2}

        If the model was not trained on a target column during, the default
        metrics for model comparison are precision and recall.

        Parameters
        ----------
        dataset : SFrame/pandas.DataFrame
            An SFrame or DataFrame that is in the same format as provided for
            training.

        verbose : bool, optional
            Enables verbose output. Default is verbose.

        kwargs : dict
            Parameters passed on to internal function
            `_evaluate_precision_recall`.

        Returns
        -------
        out : SFrame
            Results from  model evaluation procedure.

        Notes
        -----
            If the model is trained on a target (i.e. RMSE is the evaluation
            criterion), a dictionary with three items is returned: items
            *rmse_by_user* and *rmse_by_item* are SFrames with per-user and
            per-item RMSE, while *rmse_overall* is the overall RMSE (a float).
            If the model is trained without a target (i.e. precision and recall
            are the evaluation criteria) and SFRame is returned with both of
            these metrics for each user at several cutoff values.
        """

        # If the model does not have a target column, compute prec-recall.
        if self.get('target_column') == '':
            results = self._evaluate_precision_recall(dataset, **kwargs)['results']
            if verbose:
                print "\nPrecision and recall summary statistics by cutoff"
                print results.to_dataframe().groupby('cutoff').describe()[['precision', 'recall']]
        else:
            results = self._evaluate_rmse(dataset)
            if verbose:
                print "\nOverall RMSE: ", results['rmse_overall']
                print "\nPer User RMSE (best)"
                print results['rmse_by_user'].topk('rmse', 1, reverse=True)
                print "\nPer User RMSE (worst)"
                print results['rmse_by_user'].topk('rmse', 1)
                print "\nPer Item RMSE (best)"
                print results['rmse_by_item'].topk('rmse', 1, reverse=True)
                print "\nPer Item RMSE (worst)"
                print results['rmse_by_item'].topk('rmse', 1)

        return results
