"""
"""
import graphlab.connect as _mt
import graphlab as _graphlab
from graphlab.toolkits.model import Model
from graphlab.data_structures.sarray import SArray as _SArray
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS
from graphlab.deps import numpy as _numpy, HAS_NUMPY as _HAS_NUMPY
import json as _json
import logging as _logging

class RecommenderModel(Model):
    """
    The abstract class for GraphLab recommender system models. This class
    defines methods common to all recommender system models, but leaves unique
    model details to separate model classes.
    """

    def list_fields(self):
        """
        List the names of the fields that can be queried on this model.

        The set of fields depends on the type of model. Names that begin with
        an underscore are left out of the result.

        Returns
        -------
        out : list
            A list of fields that can be queried using the ``get`` method.
        """

        response = _graphlab.toolkits.main.run('recsys_list_fields',
                                               {'model': self.__proxy__})

        # Keep only the fields whose name does not start with an underscore.
        visible = []
        for name in response['value']:
            if name.startswith("_"):
                continue
            visible.append(name)
        return visible

    def get(self, field):
        """
        Get the value of a particular field.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out
            The current value of the requested field. When the toolkit
            reports a dictionary result, a dict is returned in which any
            SFrame proxy value has been wrapped in an SFrame.

        Raises
        ------
        AssertionError
            If the toolkit response carries a result type other than
            "value" or "dict".

        Examples
        --------
        >>> data = graphlab.SFrame({'user_id': ["0", "0", "0", "1", "1", "2", "2", "2"],
        ...                         'item_id': ["a", "b", "c", "a", "b", "b", "c", "d"],
        ...                         'rating': [1, 3, 2, 5, 4, 1, 4, 3]})
        >>> m = graphlab.recommender.create(data, "user_id", "item_id", "rating",
        ...                                 method = "matrix_factorization")
        >>> d = m.get("coefficients")
        >>> U1 = d['user_id']
        >>> U2 = d['item_id']
        """

        opts = {'model': self.__proxy__, 'field': field}
        response = _graphlab.toolkits.main.run('recsys_get_value', opts)

        # The response says how its payload is laid out.
        result_type = response["__result_type__"]

        if result_type == "value":
            return response["value"]

        if result_type == "dict":

            def _translate_type(x):
                # SFrames come back as raw proxies; wrap them so that callers
                # receive a usable SFrame. Everything else passes through.
                if isinstance(x, _graphlab.cython.cy_sframe.UnitySFrameProxy):
                    return _graphlab.SFrame(_proxy = x)
                return x

            # The entries of the dictionary are listed under "__keys__".
            keys = response["__keys__"]
            return dict((k, _translate_type(response[k])) for k in keys)

        # Raised explicitly (instead of ``assert False``) so that the check
        # is not stripped under ``python -O``, which would make this method
        # silently return None. The exception type is unchanged.
        raise AssertionError("Result type not specified.")



    def summary(self):
        """
        Summarize the model, its training data and its training performance.

        Returns
        -------
        out : dict
            A dictionary containing model parameters, summary statistics of the
            data set, and summary statistics about training.

        """
        tracker = _mt._get_metric_tracker()
        tracker.track('toolkit.recsys.summary')
        return _graphlab.toolkits.main.run('recsys_summary',
                                           {'model': self.__proxy__})

    def __str__(self):
        """
        Return the name of this model's class.

        Returns
        -------
        out : string
            The type of model.
        """
        model_class = self.__class__
        return model_class.__name__

    def __repr__(self):
        """
        Returns a string description of the model, including (where relevant)
        the schema of the training data, description of the training data,
        training statistics, and model hyperparameters.

        The description is assembled from ``summary()`` and the current
        options. A model whose summary has no ``elapsed_time`` entry is
        reported as untrained.

        Returns
        -------
        out : string
            A description of the model.
        """

        stats = self.summary()
        options = self._get_current_options()

        def _side_data_lines(kind):
            # Describe the `kind` ("user" or "item") side data columns, one
            # line per column. Returns [] when the summary reports none, so a
            # missing entry cannot make repr() fail with a KeyError.
            names_key = '%s_side_data_column_names' % kind
            types_key = '%s_side_data_column_types' % kind
            if names_key not in stats or len(stats[names_key]) == 0:
                return []

            colnames = stats[names_key]
            coltypes = stats[types_key]
            assert len(colnames) == len(coltypes), \
                ("Error importing %s side data: mismatch between the "
                 "number of column names and column types." % kind)

            lines = ["  %s side data:\n" % kind]
            for name, coltype in zip(colnames, coltypes):
                lines.append('    {0} : {1}\n'.format(name, coltype))
            return lines

        # Model type first; each later section is appended as its own piece
        # and the pieces are joined once at the end.
        parts = [self.__class__.__name__ + '\n']

        # If the model is untrained, there will not be an elapsed_time
        if 'elapsed_time' in stats:

            # Data schema. These options are dropped from `options` further
            # down, so they never show up again under "Settings".
            parts.append("\nSchema\n")
            for k in ["user_id", "item_id", "target"]:
                if k in options:
                    value = options[k]
                    if value is not None and value not in ['', 'not_specified']:
                        parts.append("  %-14s %s\n" %
                                     (k.replace('_', ' ') + ':', value))

            # Statistics about the training set
            parts.append("\nStatistics\n")
            ks = ["num_obs_train",
                  "num_unique_users",
                  "num_unique_items"]
            if all(k in stats for k in ks):
                parts.append("  %-17s %11d obs %11d users %11d items\n" %
                             ("Training set:",
                              stats[ks[0]], stats[ks[1]], stats[ks[2]]))

            parts.extend(_side_data_lines("user"))
            parts.extend(_side_data_lines("item"))

            # Training summary. The header depends on every key in `ks`; the
            # previous code tested a single stale loop variable instead.
            ks = ["elapsed_time",
                  "random_seed",
                  "data_load_elapsed_time"]
            if any(k in stats for k in ks):
                parts.append("\nTraining summary\n")

            timings = [("elapsed_time", "training time:"),
                       ("data_load_elapsed_time", "data load time:"),
                       ("validation_metrics_elapsed_time", "metrics time:")]
            for key, label in timings:
                if key in stats:
                    parts.append("  %-20s %fs\n" % (label, stats[key]))

            # If available, print performance statistics
            for k in stats.keys():
                if 'rmse' in k or 'precision' in k or 'recall' in k:
                    parts.append('  %-20s %-10s\n' % (k + ':', stats[k]))

        else:
            parts.append('\nThis model has yet to be trained.\n')

        # Remove any options that should not be shown under "Settings"
        to_ignore = ['random_seed',
                     'user_id',
                     'item_id',
                     'target']
        for k in to_ignore:
            if k in options:
                del options[k]

        # Print remaining hyperparameters
        # TODO: Get max width of keys and use it for fixed width formatting.
        if len(options) > 0:
            parts.append("\nSettings\n")
            for k, v in options.items():
                parts.append("  %-22s %-30s\n" %
                             (k.replace('_', ' ') + ':', str(v)))

        return ''.join(parts)

    def _get_default_options(self):
        """
        Return a dictionary describing all the parameters of the given model.

        Every value in the toolkit response is a JSON document; each one is
        decoded in place before the response is returned. For each parameter
        there may be:
          name
          description
          type (REAL, CATEGORICAL)
          default_value
          possible_values: for reals this includes a lower and upper bound.
                           for categoricals this is a list of possible values.
        """

        response = _graphlab.toolkits.main.run('recsys_get_default_options',
                                               {'model': self.__proxy__})
        for name in response.keys():
            encoded = response[name]
            response[name] = _json.loads(encoded)
        return response

    def _get_current_options(self):
        """
        Return a dictionary describing all the parameters of the given model
        together with their current setting.
        """
        return _graphlab.toolkits.main.run('recsys_get_current_options',
                                           {'model': self.__proxy__})

    def _set_current_options(self, options):
        """
        Set current options for a model.

        The given settings are laid over the model's current options, and the
        merged dictionary is sent to the toolkit together with the model.

        Parameters
        ----------
        options : dict
            A dictionary of the desired option settings. Each key is the name
            of an option and each value is the desired value of that option.
            The possible options are all those returned by
            ``_get_default_options()``.
        """
        merged = self._get_current_options()
        merged.update(options)
        # Set last, so the 'model' entry always refers to this model.
        merged['model'] = self.__proxy__
        return _graphlab.toolkits.main.run('recsys_set_current_options', merged)

    def predict(self, dataset,
                new_observation_data=None, new_user_data=None, new_item_data=None):
        """
        Return a score prediction for the user ids and item ids in the provided
        data set.

        Parameters
        ----------
        dataset : SFrame
            Dataset in the same form used for training. A pandas DataFrame
            is converted to an SFrame.

        new_observation_data : SFrame, optional
            ``new_observation_data`` gives additional observation data
            to the model, which may be used by the models to improve
            score accuracy.  Must be in the same format as the
            observation data passed to ``create``.  How this data is
            used varies by model.

        new_user_data : SFrame, optional
            ``new_user_data`` may give additional user data to the
            model.  If present, scoring is done with reference to this
            new information.  If there is any overlap with the side
            information present at training time, then this new side
            data is preferred.  Must be in the same format as the user
            data passed to ``create``.

        new_item_data : SFrame, optional
            ``new_item_data`` may give additional item data to the
            model.  If present, scoring is done with reference to this
            new information.  If there is any overlap with the side
            information present at training time, then this new side
            data is preferred.  Must be in the same format as the item
            data passed to ``create``.

        Returns
        -------
        out : SArray
            An SArray with predicted scores for each given observation
            predicted by the model.

        Raises
        ------
        TypeError
            If one of the ``new_*`` arguments is neither None nor an SFrame,
            pandas DataFrame or numpy array.

        See Also
        --------
        recommend, evaluate
        """
        _mt._get_metric_tracker().track('toolkit.recsys.predict')

        # isinstance rather than an exact type comparison, so that DataFrame
        # subclasses are converted as well (as for the other arguments).
        if _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame):
            dataset = _SFrame(dataset)

        def _as_sframe(arg, arg_name):
            # Normalise an optional side-data argument to an SFrame: None
            # becomes an empty SFrame, pandas/numpy inputs are converted and
            # anything else is rejected.
            if arg is None:
                return _SFrame()
            if (_HAS_PANDAS and isinstance(arg, _pandas.DataFrame)) or \
               (_HAS_NUMPY and isinstance(arg, _numpy.ndarray)):
                arg = _SFrame(arg)
            if not isinstance(arg, _SFrame):
                raise TypeError("Parameter " + arg_name + " must be of type(s) "
                                + (", ".join(["SFrame"]))
                                + "; Type '" + str(type(arg)) + "' not recognized.")
            return arg

        new_observation_data = _as_sframe(new_observation_data,
                                          "new_observation_data")
        new_user_data = _as_sframe(new_user_data, "new_user_data")
        new_item_data = _as_sframe(new_item_data, "new_item_data")

        # Arguments for the C++ toolkit call
        opts = {'data_to_predict': dataset,
                'model': self.__proxy__,
                'new_data': new_observation_data,
                'new_user_data': new_user_data,
                'new_item_data': new_item_data}

        # Call the C++ function for recommender_model
        response = _graphlab.toolkits.main.run('recsys_predict', opts)
        result = _SFrame(None, _proxy=response['data'])
        return result['prediction']

    def recommend(self, users=None, k=10, exclude=None, items=None,
                  new_observation_data=None, new_user_data=None, new_item_data=None,
                  exclude_known=True,
                  verbose=True):
        """
        Recommends the k-highest scored items for each user in users.

        Parameters
        ----------
        users :  SArray or list; optional
            An :class:`~graphlab.SArray` (or list, or numpy array) of users
            for which to make recommendations.  If 'None', then
            recommend(...) generates predictions for all users in the
            training set.  The type of the SArray or list must be the same
            as the type of the user_id column in the training data.

        k : int, optional
            The number of recommendations to generate for each user.

        exclude : SFrame, optional
            An :class:`~graphlab.SFrame` of user / item pairs.  The
            column names must be equal to the user and item columns of
            the main data, and it provides the model with user/item
            pairs to exclude from the recommendations.  These
            user-item-pairs are always excluded from the predictions,
            even if exclude_known is False.

        items : SArray or list; optional
            To restrict the set of items under consideration, the user
            may pass in an SArray (or list, or numpy array) of items from
            which the recommended items will be chosen.  This allows one
            to choose only items that are, for instance, in a particular
            genre or category.  The default is that all items are under
            consideration.

        new_observation_data : SFrame, optional
            ``new_observation_data`` gives additional observation data
            to the model, which may be used by the models to improve
            score and recommendation accuracy.  Must be in the same
            format as the observation data passed to ``create``.  How
            this data is used varies by model.

        new_user_data : SFrame, optional
            ``new_user_data`` may give additional user data to the
            model.  If present, scoring is done with reference to this
            new information.  If there is any overlap with the side
            information present at training time, then this new side
            data is preferred.  Must be in the same format as the user
            data passed to ``create``.

        new_item_data : SFrame, optional
            ``new_item_data`` may give additional item data to the
            model.  If present, scoring is done with reference to this
            new information.  If there is any overlap with the side
            information present at training time, then this new side
            data is preferred.  Must be in the same format as the item
            data passed to ``create``.

        exclude_known : bool, optional
            By default, all user-item interactions previously seen in
            the training data, or in any new data provided using
            new_observation_data=..., are excluded from the
            recommendations.  Passing in ``exclude_known = False``
            overrides this behavior.

        verbose : bool, optional
            If True, print the progress of generating recommendation.

        Returns
        -------
        out : SFrame
            A SFrame with the top ranked items for each user. The
            columns are: ``user_id``, ``item_id``, *score*,
            and *rank*, where ``user_id`` and ``item_id``
            match the user and item column names specified at training
            time.  The rank column is between 1 and ``k`` and gives
            the relative score of that item.  The value of score
            depends on the method used for recommendations.

        Raises
        ------
        TypeError
            If ``k`` is not an int, or if one of the data arguments has a
            type that cannot be converted to the required SArray / SFrame.

        See Also
        --------
        predict
        evaluate
        """
        _mt._get_metric_tracker().track('toolkit.recsys.recommend')

        def _reject(arg, arg_name, allowed_types):
            # Single place that builds the "wrong argument type" error.
            raise TypeError("Parameter " + arg_name + " must be of type(s) "
                            + (", ".join(allowed_types))
                            + "; Type '" + str(type(arg)) + "' not recognized.")

        def _as_sarray(arg, arg_name):
            # None -> empty SArray; list / numpy array -> SArray.
            if arg is None:
                return _SArray()
            if isinstance(arg, list) or \
               (_HAS_NUMPY and isinstance(arg, _numpy.ndarray)):
                arg = _SArray(arg)
            if not isinstance(arg, _SArray):
                _reject(arg, arg_name, ["SArray", "list", "numpy.ndarray"])
            return arg

        def _as_sframe(arg, arg_name):
            # None -> empty SFrame; pandas DataFrame / numpy array -> SFrame.
            if arg is None:
                return _SFrame()
            if (_HAS_PANDAS and isinstance(arg, _pandas.DataFrame)) or \
               (_HAS_NUMPY and isinstance(arg, _numpy.ndarray)):
                arg = _SFrame(arg)
            if not isinstance(arg, _SFrame):
                _reject(arg, arg_name, ["SFrame"])
            return arg

        # Explicit check instead of ``assert``: an assert is stripped under
        # ``python -O`` and carries no message. The exact-type test is kept
        # on purpose (bool and other int subclasses stay rejected) because
        # the value is handed straight to the toolkit backend.
        if type(k) is not int:
            _reject(k, "k", ["int"])

        users = _as_sarray(users, "users")
        exclude = _as_sframe(exclude, "exclude")
        items = _as_sarray(items, "items")
        new_observation_data = _as_sframe(new_observation_data,
                                          "new_observation_data")
        new_user_data = _as_sframe(new_user_data, "new_user_data")
        new_item_data = _as_sframe(new_item_data, "new_item_data")

        opt = {'model': self.__proxy__,
               'users': users,
               'top_k': k,
               'exclude': exclude,
               'items': items,
               'new_data': new_observation_data,
               'new_user_data': new_user_data,
               'new_item_data': new_item_data,
               'exclude_known': exclude_known}
        response = _graphlab.toolkits.main.run('recsys_recommend', opt, verbose=verbose)
        recs = _SFrame(None, _proxy=response['data'])
        return recs

    def training_stats(self):
        """
        Get information about model creation, e.g. time elapsed during
        model fitting, data loading, and more.

        Note: This method will be *deprecated* soon. Please use m.summary()
        instead. A warning saying so is logged on every call.

        Returns
        -------
        out : dict
            Statistics about model training, e.g. runtime.

        """
        _logging.warning("This method will be deprecated soon. Please use m.summary().")

        tracker = _mt._get_metric_tracker()
        tracker.track('toolkit.recsys.training_stats')

        return _graphlab.toolkits.main.run("recsys_get_stats",
                                           {'model': self.__proxy__})

    def evaluate_precision_recall(self, dataset, cutoffs=[5, 10, 15], skip_set=None, exclude_known=True, verbose=True):
        """
        Compute a model's precision and recall scores for a particular dataset.

        Parameters
        ----------
        dataset : SFrame
            An SFrame in the same format as the one used during training.
            This will be compared to the model's recommendations, which exclude
            the (user, item) pairs seen at training time. A pandas DataFrame
            is converted to an SFrame.

        cutoffs : list, optional
            A list of cutoff values for which one wants to evaluate precision
            and recall, i.e. the value of k in "precision at k". The largest
            cutoff is the number of recommendations generated per user.

        skip_set : SFrame, optional
            Passed to :meth:`recommend` as ``exclude``.

        exclude_known : bool, optional
            Passed to :meth:`recommend` as ``exclude_known``. If True, exclude
            training item from recommendation.

        verbose : bool, optional
            Enables verbose output. Default is verbose.

        Returns
        -------
        out : SFrame
            Contains the precision and recall at each cutoff value and each
            user in ``dataset``.

        See Also
        --------
        precision_recall_by_user
        """

        _mt._get_metric_tracker().track('toolkit.recsys.evaluate_precision_recall')

        # Accept pandas input, consistent with predict() and evaluate_rmse().
        if _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame):
            dataset = _SFrame(dataset)

        # Work on a private copy so that the shared default list can never be
        # modified by downstream code.
        cutoffs = list(cutoffs)

        # Look the column names up once instead of once per use.
        user_column = self.get('user_id')
        item_column = self.get('item_id')

        users = dataset[user_column].unique()
        dataset = dataset[[user_column, item_column]]

        recs = self.recommend(users=users, k=max(cutoffs), exclude=skip_set,
                              exclude_known=exclude_known, verbose=verbose)
        pr = _graphlab.recommender.util.precision_recall_by_user(dataset, recs, cutoffs)
        return pr

    def evaluate_rmse(self, dataset, target):
        """
        Evaluate the prediction error for each user-item pair in the given data
        set.

        Parameters
        ----------
        dataset : SFrame
            An SFrame in the same format as the one used during training.
            A pandas DataFrame is converted to an SFrame.

        target : str
            The name of the target rating column in `dataset`.

        Returns
        -------
        out : dict
            A dictionary with three items: 'rmse_by_user' and 'rmse_by_item',
            which are SFrames containing the average rmse for each user and
            item, respectively; and 'rmse_overall', which is a float.

        Raises
        ------
        ValueError
            If `dataset` lacks the `target` column, or the user or item
            column the model was trained with.

        See Also
        --------
        graphlab.evaluation.rmse
        """

        _mt._get_metric_tracker().track('toolkit.recsys.evaluate_rmse')

        # isinstance so that DataFrame subclasses are converted as well.
        if _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame):
            dataset = _SFrame(dataset)

        # Validate all required columns before the prediction call. Explicit
        # raises instead of asserts: asserts are stripped under ``python -O``.
        columns = dataset.column_names()
        if target not in columns:
            raise ValueError('Provided dataset must contain a target column '
                             'with the same name as the target used during '
                             'training.')

        user_column = self.get('user_id')
        item_column = self.get('item_id')
        if user_column not in columns or item_column not in columns:
            raise ValueError('Provided data set must have a column pertaining '
                             'to user ids and item ids, similar to what we '
                             'had during training.')

        y = dataset[target]
        yhat = self.predict(dataset)

        result = dataset[[user_column, item_column]]
        result['sq_error'] = (y - yhat).apply(lambda x: x**2)

        def _rmse_by(column):
            # Mean squared error per group, then its square root, together
            # with the number of observations in each group.
            grouped = result.groupby(column,
                    {'rmse':_graphlab.aggregate.AVG('sq_error'),
                     'count':_graphlab.aggregate.COUNT})
            grouped['rmse'] = grouped['rmse'].apply(lambda x: x**.5)
            return grouped

        rmse_by_user = _rmse_by(user_column)
        rmse_by_item = _rmse_by(item_column)
        overall_rmse = result['sq_error'].mean() ** .5

        return {'rmse_by_user': rmse_by_user,
                'rmse_by_item': rmse_by_item,
                'rmse_overall': overall_rmse}

    def evaluate(self, dataset, metric='auto',
                 exclude_known_for_precision_recall=True,
                 target=None,
                 verbose=False, **kwargs):
        r"""
        Evaluate the model's ability to make rating predictions or 
        recommendations. 

        If the model is trained to predict a particular target, the
        default metric used for model comparison is root-mean-squared error
        (RMSE). Suppose :math:`y` and :math:`\widehat{y}` are vectors of length
        :math:`N`, where :math:`y` contains the actual ratings and 
        :math:`\widehat{y}` the predicted ratings. Then the RMSE is defined as

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2} .

        If the model was not trained on a target column, the default metrics for
        model comparison are precision and recall. Let
        :math:`p_k` be a vector of the :math:`k` highest ranked recommendations
        for a particular user, and let :math:`a` be the set of items for that 
        user in the groundtruth `dataset`. The "precision at cutoff k" is 
        defined as

        .. math:: P(k) = \frac{ | a \cap p_k | }{k}

        while "recall at cutoff k" is defined as

        .. math:: R(k) = \frac{ | a \cap p_k | }{|a|}

        Parameters
        ----------
        dataset : SFrame
            An SFrame that is in the same format as provided for training.

        metric : str, {'auto', 'rmse', 'precision_recall'}, optional
            Metric to use for evaluation. The default automatically chooses
            'rmse' for models trained with a `target`, and 'precision_recall'
            otherwise.

        exclude_known_for_precision_recall : bool, optional            
            A useful option for evaluating precision-recall. Recommender models
            have the option to exclude items seen in the training data from the
            final recommendation list. Set this option to True when evaluating
            on test data, and False when evaluating precision-recall on training
            data.

        target : str, optional        
            The name of the target column for evaluating rmse. If the model is
            trained with a target column, the default is to using the same
            column. If the model is trained without a target column and `metric`
            is set to 'rmse', this option must provided by user.

        verbose : bool, optional
            Enables verbose output. Default is verbose.

        **kwargs
            When `metric` is set to 'precision_recall', these parameters
            are passed on to :meth:`evaluate_precision_recall`.

        Returns
        -------
        out : SFrame or dict
            Results from the model evaluation procedure. If the model is trained
            on a target (i.e. RMSE is the evaluation criterion), a dictionary
            with three items is returned: items *rmse_by_user* and
            *rmse_by_item* are SFrames with per-user and per-item RMSE, while
            *rmse_overall* is the overall RMSE (a float). If the model is
            trained without a target (i.e. precision and recall are the
            evaluation criteria) an :py:class:`~graphlab.SFrame` is returned
            with both of these metrics for each user at several cutoff values.

        See Also
        --------
        evaluate_precision_recall, evaluate_rmse, precision_recall_by_user
        """

        _mt._get_metric_tracker().track('toolkit.recsys.evaluate')

        if metric is 'auto':
            if self.get('target') is None or self.get('target') == '':
                metric = 'precision_recall'
            else:
                metric = 'rmse'

        # If the model does not have a target column, compute prec-recall.
        if metric == 'precision_recall':
            results = self.evaluate_precision_recall(dataset,
                                                      exclude_known=exclude_known_for_precision_recall,
                                                      verbose=verbose,
                                                      **kwargs)
            if verbose:
                print "\nPrecision and recall summary statistics by cutoff"
                print results.groupby('cutoff', {'mean_precision': _graphlab.aggregate.AVG('precision'),
                                                 'mean_recall': _graphlab.aggregate.AVG('recall')}).topk('cutoff', reverse=True)
        elif metric == 'rmse':
            if target is None:
                target = self.get('target')
            if target is None or target == "":
                raise ValueError('Target column cannot be None. Please provide the target if the model \
                    is not trained with a target column')
            results = self.evaluate_rmse(dataset, target)
            if verbose:
                print "\nOverall RMSE: ", results['rmse_overall']
                print "\nPer User RMSE (best)"
                print results['rmse_by_user'].topk('rmse', 1, reverse=True)
                print "\nPer User RMSE (worst)"
                print results['rmse_by_user'].topk('rmse', 1)
                print "\nPer Item RMSE (best)"
                print results['rmse_by_item'].topk('rmse', 1, reverse=True)
                print "\nPer Item RMSE (worst)"
                print results['rmse_by_item'].topk('rmse', 1)
        else:
            raise ValueError('Unknown evaluation metric %s, supported metrics are [\"rmse\", \"precision_recall\"]' % metric)

        return results
