"""
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sarray import SArray as _SArray
from graphlab.data_structures.sframe import SFrame as _SFrame
import pandas as _pd
import json as _json
import array
import numpy as _np
import logging as _logging

class RecommenderModel(Model):
    """
    The abstract class for GraphLab recommender system models. This class
    defines methods common to all recommender system models, but leaves unique
    model details to separate model classes.
    """

    def list_fields(self):
        """
        Get the names of the model's current settings. The names depend on
        the type of model.

        Returns
        -------
        out : list
            The option names reported by the model's current options; their
            values can typically be retrieved with :meth:`get`.
        """
        return list(self._get_current_options().keys())

    def get(self, field):
        """
        Get the value of a particular field.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out
            The current value of the requested field, as reported by the
            ``recsys_get_option_value`` toolkit call. The type depends on the
            field.
        """
        opts = {'model': self.__proxy__, 'field': field}
        response = _graphlab.toolkits.main.run('recsys_get_option_value', opts)
        return response['value']

    def summary(self):
        """Returns a summary of statistics about the model, the data
        used to train the model, and the training performance.

        Returns
        -------
        out : dict
            A dictionary containing model parameters, summary statistics of the
            data set, and summary statistics about training.

        """
        _mt._get_metric_tracker().track('toolkit.recsys.summary')
        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run('recsys_summary', opts)
        return response

    def __str__(self):
        """
        Returns the type of model.

        Returns
        -------
        out : string
            The type of model.
        """
        return self.__class__.__name__

    @staticmethod
    def _format_side_data(stats, side):
        """Render the side-data column listing for ``side`` ('user' or 'item').

        Returns an empty string when ``stats`` reports no side-data columns
        for that side.
        """
        colnames = stats.get(side + '_side_data_column_names', [])
        if len(colnames) == 0:
            return ''
        coltypes = stats.get(side + '_side_data_column_types', [])
        assert len(colnames) == len(coltypes), \
            ("Error importing " + side + " side data: mismatch between the "
             "number of column names and column types.")
        lines = ["  " + side + " side data:\n"]
        for name, ctype in zip(colnames, coltypes):
            lines.append('    {0} : {1}\n'.format(name, ctype))
        return ''.join(lines)

    def __repr__(self):
        """
        Returns a string description of the model, including (where relevant)
        the schema of the training data, description of the training data,
        training statistics, and model hyperparameters.

        Returns
        -------
        out : string
            A description of the model.
        """

        stats = self.summary()
        options = self._get_current_options()

        # Print model type and data schema
        ret = self.__class__.__name__ + '\n'

        # If the model is untrained, there will not be an elapsed_time
        is_trained = 'elapsed_time' in stats

        if is_trained:

            ret += "\nSchema\n"
            for k in ["user_column", "item_column", "target_column"]:
                if k in options:
                    # '' / 'not_specified' / None all mean "column not used".
                    if options[k] not in ['', 'not_specified'] and options[k] is not None:
                        ret += "  %-14s %s\n" % \
                            (k.replace('_', ' ') + ':', options[k])
                    del options[k]

            # Statistics about the training and validation set
            ret += "\nStatistics\n"
            ks = ["num_obs_train",
                  "num_unique_users",
                  "num_unique_items"]
            if all(k in stats for k in ks):
                ret += "  %-17s %11d obs %11d users %11d items\n" % \
                    ("Training set:", stats[ks[0]], stats[ks[1]], stats[ks[2]])

            ret += self._format_side_data(stats, 'user')
            ret += self._format_side_data(stats, 'item')

            # Training summary: emit the header only if at least one of these
            # statistics is present.
            ks = ["elapsed_time",
                  "random_seed",
                  "data_load_elapsed_time"]
            if any(k in stats for k in ks):
                ret += "\nTraining summary\n"
            if 'elapsed_time' in stats:
                ret += "  %-20s %fs\n" % ("training time:", stats["elapsed_time"])
            if 'data_load_elapsed_time' in stats:
                ret += "  %-20s %fs\n" % ("data load time:", stats["data_load_elapsed_time"])

            if "validation_metrics_elapsed_time" in stats:
                ret += "  %-20s %fs\n" % \
                    ("metrics time:", stats["validation_metrics_elapsed_time"])

            # If available, print performance statistics
            for k in stats:
                if 'rmse' in k or 'precision' in k or 'recall' in k:
                    ret += '  %-20s %-10s\n' % (k + ':', stats[k])

        else:
            ret += '\nThis model has yet to be trained.\n'

        # Remove any options that should not be shown under "Settings"
        for k in ['random_seed', 'user_column', 'item_column', 'target_column']:
            options.pop(k, None)

        # Print remaining hyperparameters
        # TODO: Get max width of keys and use it for fixed width formatting.
        if len(options) > 0:
            ret += "\nSettings\n"
            for k, v in options.items():
                ret += "  %-22s %-30s\n" % (k.replace('_', ' ') + ':', str(v))

        return ret

    def _get_default_options(self):
        """
        A dictionary describing all the parameters of the given model.
        For each parameter there may be:
          name
          description
          type (REAL, CATEGORICAL)
          default_value
          possible_values: for reals this includes a lower and upper bound.
                           for categoricals this is a list of possible values.
        """

        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run('recsys_get_default_options', opts)
        # Each value comes back as a JSON string; decode it in place.
        for k in list(response.keys()):
            response[k] = _json.loads(response[k])
        return response

    def _get_current_options(self):
        """
        A dictionary describing all the parameters of the given model
        and their current setting.
        """
        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run('recsys_get_current_options', opts)
        return response

    def _set_current_options(self, options):
        """
        Set current options for a model.

        Parameters
        ----------
        options : dict
            A dictionary of the desired option settings. The key should be the name
            of the option and each value is the desired value of the option.
            The possible options are all those returned by _get_default_options().

        Returns
        -------
        out
            The raw response of the ``recsys_set_current_options`` toolkit call.
        """
        # Start from the current settings so unspecified options keep their values.
        opts = self._get_current_options()
        opts.update(options)
        opts['model'] = self.__proxy__
        response = _graphlab.toolkits.main.run('recsys_set_current_options', opts)
        return response

    def score(self, dataset):
        """
        Return a score prediction for the user ids and item ids in the provided
        data set.

        Parameters
        ----------
        dataset : SFrame or pandas.DataFrame
            Dataset in the same form used for training. A pandas DataFrame is
            converted to an SFrame first.

        Returns
        -------
        out : SArray
            An SArray with scores for each given user-item pair predicted by
            the model.
        """
        _mt._get_metric_tracker().track('toolkit.recsys.score')

        if isinstance(dataset, _pd.DataFrame):
            dataset = _SFrame(dataset)

        # Get metadata from C++ object
        opts = {'data_to_predict': dataset, 'model': self.__proxy__}

        # Call the C++ function for recommender_model
        response = _graphlab.toolkits.main.run('recsys_predict', opts)
        result = _SFrame(None, _proxy=response['data'])
        return result['prediction']

    def recommend(self, users=None, k=10, exclude=None, items=None,
                  new_observation_data=None, new_user_data=None, new_item_data=None,
                  exclude_known=True):
        """
        Recommends the k-highest scored items for each user in users.

        Parameters
        ----------
        users:  SArray, list, or numpy.ndarray; optional
            An :class:`~graphlab.SArray` (or list) of users for which to make
            recommendations.  If 'None', then recommend(...) generates
            predictions for all users in the training set.

        k: int, optional
            The number of recommendations to generate for each user.

        items: SArray, list, or numpy.ndarray; optional
            To restrict the set of items under consideration, the user may pass
            in an SArray of items from which the recommended items will be
            chosen.  This allows one to choose only items that are, for
            instance, in a particular genre or category.  The default is that
            all items are under consideration.

        new_observation_data: SFrame, optional
            ``new_observation_data`` may give additional observation data to
            the model.  Any overlap with the training data will be ignored.
            Must be in the same format as the observation data passed to
            ``create``.

        new_user_data: SFrame, optional
            ``new_user_data`` may give additional user data to the model.  Any
            overlap with the training data will be ignored.  Must be in the
            same format as the user data passed to ``create``.

        new_item_data: SFrame, optional
            ``new_item_data`` may give additional item data to the model.  Any
            overlap with the training data will be ignored.  Must be in the
            same format as the item data passed to ``create``.

        exclude: SFrame, optional
            An :class:`~graphlab.SFrame` of user / item pairs.  The column
            names must be equal to the user and item columns of the main data,
            and it provides the model with user/item pairs to exclude from the
            recommendations.  These user-item-pairs are always excluded from
            the predictions, even if exclude_known is False.

        exclude_known: bool
            By default, all user-item interactions previously seen in the
            training data, or in any new data provided using
            new_observation_data=..., are excluded from the recommendations.
            Passing in ``exclude_known = False`` overrides this behavior.

        Returns
        -------
        out : SFrame
            A SFrame with the top ranked items for each user. The columns are:
            ``user_column``, ``item_column``, *score*, and *rank*, where
            ``user_column`` and ``item_column`` match the user and item column
            names specified at training time.  The rank column is between 1 and
            ``k`` and gives the relative score of that item.  The value of
            score depends on the method used for recommendations.

        Raises
        ------
        TypeError
            If an argument cannot be converted to the required SArray/SFrame.
        """
        _mt._get_metric_tracker().track('toolkit.recsys.recommend')
        assert type(k) == int, "k must be an int."

        # Empty containers signal "not provided" to the toolkit.
        if users is None:
            users = _SArray()
        if exclude is None:
            exclude = _SFrame()
        if items is None:
            items = _SArray()
        if new_observation_data is None:
            new_observation_data = _SFrame()
        if new_user_data is None:
            new_user_data = _SFrame()
        if new_item_data is None:
            new_item_data = _SFrame()

        # Convert common Python / numpy / pandas containers.
        if isinstance(users, (list, _np.ndarray)):
            users = _SArray(users)
        if isinstance(exclude, (_pd.DataFrame, _np.ndarray)):
            exclude = _SFrame(exclude)
        if isinstance(items, (list, _np.ndarray)):
            items = _SArray(items)
        if isinstance(new_observation_data, (_pd.DataFrame, _np.ndarray)):
            new_observation_data = _SFrame(new_observation_data)
        if isinstance(new_user_data, (_pd.DataFrame, _np.ndarray)):
            new_user_data = _SFrame(new_user_data)
        if isinstance(new_item_data, (_pd.DataFrame, _np.ndarray)):
            new_item_data = _SFrame(new_item_data)

        def check_type(arg, arg_name, required_type, allowed_types):
            if not isinstance(arg, required_type):
                raise TypeError("Parameter " + arg_name + " must be of type(s) "
                                + (", ".join(allowed_types) )
                                + "; Type '" + str(type(arg)) + "' not recognized.")

        check_type(users, "users", _SArray, ["SArray", "list", "numpy.ndarray"])
        check_type(exclude, "exclude", _SFrame, ["SFrame"])
        check_type(items, "items", _SArray, ["SArray", "list", "numpy.ndarray"])
        check_type(new_observation_data, "new_observation_data", _SFrame, ["SFrame"])
        check_type(new_user_data, "new_user_data", _SFrame, ["SFrame"])
        check_type(new_item_data, "new_item_data", _SFrame, ["SFrame"])

        opt = {'model': self.__proxy__,
               'users': users,
               'top_k': k,
               'exclude': exclude,
               'items': items,
               'new_data': new_observation_data,
               'new_user_data': new_user_data,
               'new_item_data': new_item_data,
               # The toolkit receives this flag as an integer 0/1.
               'exclude_known': 1 if exclude_known else 0}
        response = _graphlab.toolkits.main.run('recsys_recommend', opt)
        recs = _SFrame(None, _proxy=response['data'])
        return recs

    def training_stats(self):
        """
        Get information about model creation, e.g. time elapsed during
        model fitting, data loading, and more.

        Note: This method will be *deprecated* soon. Please use m.summary()
        instead.

        Returns
        -------
        out : dict
            Statistics about model training, e.g. runtime.

        """
        _logging.warning("This method will be deprecated soon. Please use m.summary().")
        _mt._get_metric_tracker().track('toolkit.recsys.training_stats')

        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run("recsys_get_stats", opts)
        return response

    def _evaluate_precision_recall(self, dataset, cutoffs=None, skip_set=None):
        """
        Compute a model's precision and recall for a particular dataset.

        Parameters
        ----------
        dataset : SFrame or pandas.DataFrame
            Data in the same format as the one used during training.
            This will be compared to the model's recommendations, which exclude
            the (user, item) pairs seen at training time.

        cutoffs : list, optional
            A list of cutoff values for which one wants to evaluate precision
            and recall, i.e. the value of k in precision@k. Defaults to
            ``[5, 10, 15]``.

        skip_set : SFrame, optional
            User / item pairs forwarded to :meth:`recommend` as ``exclude``.

        Returns
        -------
        out : SFrame
            Contains the precision and recall at each cutoff value and each
            user in ``dataset``.
        """
        # Avoid a shared mutable default argument.
        if cutoffs is None:
            cutoffs = [5, 10, 15]

        if isinstance(dataset, _pd.DataFrame):
            dataset = _SFrame(dataset)

        user_column = self.get('user_column')
        item_column = self.get('item_column')

        users = dataset[user_column].unique()
        dataset = dataset[[user_column, item_column]]

        # Ask for enough recommendations to cover the largest cutoff.
        recs = self.recommend(users=users, k=max(cutoffs), exclude=skip_set)
        pr = _graphlab.recommender.util.precision_recall_by_user(dataset, recs, cutoffs)
        return pr

    def _evaluate_rmse(self, dataset):
        """
        Evaluate the prediction error for each user-item pair in the given data
        set.

        Parameters
        ----------
        dataset : SFrame/pandas.DataFrame
            A DataFrame in the same format as the one used during training.

        Returns
        -------
        out : dict
            A dictionary with three items: *rmse_by_user* and *rmse_by_item*, which
            are SFrames; and *rmse_overall*, which is a float.
        """

        if isinstance(dataset, _pd.DataFrame):
            dataset = _SFrame(dataset)

        assert 'target_column' in self.list_fields(), \
            'RMSE evaluation only valid for models trained using a target.'
        target_column = self.get('target_column')
        user_column = self.get('user_column')
        item_column = self.get('item_column')
        column_names = dataset.column_names()

        assert target_column in column_names, \
            ('Provided dataset must contain a target column with the same '
             'name as the target used during training.')
        # Validate the id columns before scoring so a malformed dataset fails
        # here with a clear message rather than inside score().
        assert user_column in column_names and item_column in column_names, \
            ('Provided data set must have a column pertaining to user ids and '
             'item ids, similar to what we had during training.')

        y = dataset[target_column]
        yhat = self.score(dataset)

        result = dataset[[user_column, item_column]]
        result['sq_error'] = (y - yhat).apply(lambda x: x ** 2)

        def _rmse_grouped_by(column):
            # Mean squared error per group, then square-rooted into RMSE.
            grouped = result.groupby(column,
                                     {'rmse': _graphlab.aggregate.AVG('sq_error'),
                                      'count': _graphlab.aggregate.COUNT})
            grouped['rmse'] = grouped['rmse'].apply(lambda x: x ** .5)
            return grouped

        rmse_by_user = _rmse_grouped_by(user_column)
        rmse_by_item = _rmse_grouped_by(item_column)
        overall_rmse = result['sq_error'].mean() ** .5

        return {'rmse_by_user': rmse_by_user,
                'rmse_by_item': rmse_by_item,
                'rmse_overall': overall_rmse}

    def evaluate(self, dataset, verbose=False, **kwargs):
        r"""Evaluate the model's ability to make predictions for another
        dataset. If the model is trained to predict a particular target, the
        default metric used for model comparison is root-mean-squared error
        (RMSE). Suppose :math:`y` and :math:`\widehat{y}` are vectors of length
        :math:`N` with actual and predicted ratings. Then the RMSE is

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2}

        If the model was not trained on a target column, the default
        metrics for model comparison are precision and recall.

        Parameters
        ----------
        dataset : SFrame/pandas.DataFrame
            An SFrame or DataFrame that is in the same format as provided for
            training.

        verbose : bool, optional
            Enables verbose output. Default is False.

        kwargs : dict
            Parameters passed on to internal function
            `_evaluate_precision_recall`.

        Returns
        -------
        out : dict or SFrame
            Results from model evaluation procedure (see Notes).

        Notes
        -----
            If the model is trained on a target (i.e. RMSE is the evaluation
            criterion), a dictionary with three items is returned: items
            *rmse_by_user* and *rmse_by_item* are SFrames with per-user and
            per-item RMSE, while *rmse_overall* is the overall RMSE (a float).
            If the model is trained without a target (i.e. precision and recall
            are the evaluation criteria) an :py:class:`~graphlab.SFrame` is
            returned with both of these metrics for each user at several cutoff
            values.
        """

        _mt._get_metric_tracker().track('toolkit.recsys.evaluate')

        target_column = self.get('target_column')

        # If the model does not have a target column, compute prec-recall.
        # '' / 'not_specified' / None all mean "no target", matching __repr__.
        if target_column is None or target_column in ('', 'not_specified'):
            results = self._evaluate_precision_recall(dataset, **kwargs)
            if verbose:
                by_cutoff = results.groupby(
                    'cutoff',
                    {'mean_precision': _graphlab.aggregate.AVG('precision'),
                     'mean_recall': _graphlab.aggregate.AVG('recall')})
                print("\nPrecision and recall summary statistics by cutoff")
                print(by_cutoff.topk('cutoff', reverse=True))
        else:
            results = self._evaluate_rmse(dataset)
            if verbose:
                # topk(..., reverse=True) returns the lowest values, i.e. best RMSE.
                print("\nOverall RMSE: %s" % results['rmse_overall'])
                print("\nPer User RMSE (best)")
                print(results['rmse_by_user'].topk('rmse', 1, reverse=True))
                print("\nPer User RMSE (worst)")
                print(results['rmse_by_user'].topk('rmse', 1))
                print("\nPer Item RMSE (best)")
                print(results['rmse_by_item'].topk('rmse', 1, reverse=True))
                print("\nPer Item RMSE (worst)")
                print(results['rmse_by_item'].topk('rmse', 1))

        return results

