import graphlab as _graphlab
import numpy as np
import random as random
from pandas import DataFrame as _DataFrame
from graphlab.data_structures.sarray import SArray as _SArray
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.toolkits.recommender.recommender import RecommenderModel as _RecommenderModel

# Matplotlib is an optional dependency. `has_pyplot` records whether it could
# be imported; the plotting helpers in this module rely on `pp` / `rcParams`
# and must only be called when `has_pyplot` is True.
has_pyplot = False
try:
    import matplotlib.pyplot as pp
    from matplotlib import rcParams
    has_pyplot = True
except Exception:
    # Best-effort import: any failure (package missing, backend problems, ...)
    # simply disables plotting. `Exception` is used instead of a bare
    # `except:` so that KeyboardInterrupt / SystemExit still propagate.
    pass


def create(dataset, user, item, target,
           holdout_probability=0.0, verbose=True, plot=False, **kwargs):
    """
    Train the default GraphLab recommender model.

    This is a thin wrapper: after checking the type of `dataset`, every
    argument is forwarded unchanged to `graphlab.matrix_factorization.create`
    and its result is returned.

    Parameters
    ----------
    dataset : pandas.DataFrame/SFrame
        The dataset to use for training the model.

    user : string
        The column name of the dataset that corresponds to user id.

    item : string
        The column name of the dataset that corresponds to item id.

    target : string
        The model will be trained to predict this column of the data set.

    holdout_probability : float, optional
        Proportion of the dataset used for estimating error rate on new/unseen
        data. This portion of the dataset will not be used for training.

    verbose : bool, optional
        Enables verbose output. Default is verbose.

    plot : bool, optional
        If true, display the progress plot.

    kwargs : dict, optional
        Arguments passed on to `matrix_factorization.create`.

    Returns
    -------
    out : MatrixFactorizationModel
        A trained matrix factorization model.

    Raises
    ------
    TypeError
        If `dataset` is neither a pandas.DataFrame nor an SFrame.
    """
    accepted_types = (_DataFrame, _SFrame)

    if isinstance(dataset, accepted_types):
        trainer = _graphlab.matrix_factorization.create
        return trainer(dataset, user, item, target,
                       holdout_probability=holdout_probability,
                       verbose=verbose, plot=plot, **kwargs)

    raise TypeError('dataset input must be a pandas.DataFrame or SFrame')


def __plot_histogram(measurements, means, names=None, metric_name=None):
    """
    Plot histograms of the measurements, overlaid with vertical lines
    representing the means of the measurements.

    Draws onto the current pyplot axes; it does not create a figure itself.

    Parameters
    -------
    measurements : list
        List of measurements (recall, precision or RMSE), one series per model.

    means : list
        List of doubles, intended to be the mean of each list in 'measurements'.

    names : list
        List of model name strings, used as legend labels.

    metric_name : string
        Name of the metric, used as the x-axis label.
    """
    num_measurements = len(measurements)

    # A list of colors for plotting. Cycle through it so that more series than
    # base colors still yields exactly one color per series.
    COLORS_LIST = ['b', 'g', 'r', 'k', 'm', 'c']
    colors = [COLORS_LIST[i % len(COLORS_LIST)] for i in range(num_measurements)]

    # NOTE(review): the `hold` keyword was deprecated in matplotlib 2.0 and
    # removed in 3.0; it is kept here to preserve behavior on the matplotlib
    # versions this module was written against.
    hist_handle = pp.hist(measurements, bins=20,
                          color=colors,
                          label=names, hold=True)
    pp.legend()

    # hist() returns (counts, bin edges, patches). `counts` is a list of arrays
    # when several series are plotted but a single flat array when only one
    # is, so take the maximum with np.max, which handles both shapes. The
    # maximum count tells us how tall the vertical mean lines should be.
    max_count = np.max(hist_handle[0])
    pp.vlines(means, 0, max_count, colors=colors)
    pp.xlabel(metric_name)
    pp.ylabel('Counts')


def __plot_overlap_hists(results, label, names, bins=20, alpha=0.3):
    """
    Draw several un-normalized histograms on top of each other in a new
    figure, each one traced with a step outline in its own face color.

    The first series determines the bin edges; every later series is binned
    against those same edges.

    Parameters
    -------
    results : list
        List of list-like objects. Each element is plotted as a separate
        histogram. Must contain at least one element.

    label : string
        Label for the x-axis of the histogram.

    names : list
        Names for each series in `results`, used as legend labels.

    bins : int
        Number of bins. Default is 20.

    alpha : float
        Opacity of the histogram patches. Default is 0.3.
    """
    figure, axes = pp.subplots()

    def _trace_outline(heights, edges, bars, width):
        # Outline the histogram in the color matplotlib picked for its bars.
        # A zero is prepended so there is one height per bin edge.
        outline_color = bars[0].get_facecolor()
        axes.step(edges, np.insert(heights, 0, 0), color=outline_color, lw=width)

    # The first series fixes the bins: `bins` is rebound from a bin count to
    # the array of edges returned by hist() and reused below.
    heights, bins, bars = axes.hist(results[0], bins=bins, alpha=alpha,
                                    lw=0.1, label=names[0])
    _trace_outline(heights, bins, bars, 5)

    # Remaining series (if any) share the edges established above.
    for series, series_name in zip(results[1:], names[1:]):
        heights, bins, bars = axes.hist(series, bins=bins, alpha=alpha,
                                        lw=0.03, label=series_name, fill=True)
        _trace_outline(heights, bins, bars, 4)

    axes.ticklabel_format(style='sci', scilimits=(0, 0), axis='y')
    axes.set_xlabel(label)
    axes.set_ylabel('Frequency')
    axes.legend()
    figure.show()


def _compare_results_precision_recall(results, model_names=None):
    """
    Compare models that output precision/recall by plotting, in a new figure,
    one precision-recall curve per model. Each curve is obtained by averaging
    the model's results within every 'cutoff' group.

    Does nothing when `results` is empty.

    Parameters
    -------
    results : list
        List of dataframes. Each dataframe describes the evaluation results for a
        separate model and must provide 'cutoff', 'recall' and 'precision'
        columns.

    model_names : list
        List of model name strings, one per entry of `results`. Defaults to
        "model 0", "model 1", ...
    """

    num_models = len(results)
    COLORS_LIST = ['b', 'g', 'r', 'k', 'm', 'c']

    if num_models < 1:
        return

    if model_names is None:
        model_names = ["model {}".format(i) for i in range(num_models)]

    pr_curves_by_model = [res.groupby('cutoff').mean() for res in results]
    fig, ax = pp.subplots()

    for i, pr_curve in enumerate(pr_curves_by_model):
        # Wrap around the palette so that more models than base colors reuses
        # colors instead of raising IndexError.
        color = COLORS_LIST[i % len(COLORS_LIST)]
        ax.plot(list(pr_curve['recall']), list(pr_curve['precision']),
                color, label=model_names[i])

    ax.set_title('Precision-Recall Averaged Over Users')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.legend()
    fig.show()


def _compare_results_rmse2(results, model_names=None):
    """
    Plot an RMSE comparison of several models: overlapping histograms of
    per-user RMSE, overlapping histograms of per-item RMSE, and a bar chart
    of each model's overall RMSE.

    Returns silently, without plotting, when `results` is empty or holds more
    entries than there are colors in matplotlib's 'axes.color_cycle' setting.

    Parameters
    -------
    results : list
        Evaluation results, one per model. Each entry is indexed with
        'rmse_by_user', 'rmse_by_item' and 'rmse_overall'.

    model_names : list
        List of model name strings. Defaults to "model 0", "model 1", ...
    """

    model_count = len(results)

    # Nothing to draw.
    if model_count < 1:
        return
    # More models than available cycle colors: skip plotting altogether.
    if model_count > len(rcParams['axes.color_cycle']):
        return

    if model_names is None:
        model_names = ["model {}".format(idx) for idx in range(model_count)]

    # One overlapping-histogram figure per breakdown (users first, then items).
    breakdowns = (('rmse_by_user', 'Per-User RMSE'),
                  ('rmse_by_item', 'Per-Item RMSE'))
    for result_key, axis_label in breakdowns:
        rmse_series = [res[result_key].to_dataframe()['rmse'] for res in results]
        __plot_overlap_hists(rmse_series, axis_label, model_names, bins=100)

    # Bar chart of the overall RMSE, one bar per model.
    overall_scores = [res['rmse_overall'] for res in results]

    BAR_WIDTH = 0.3
    bar_positions = np.arange(model_count) + BAR_WIDTH
    figure, axes = pp.subplots()
    axes.bar(bar_positions, overall_scores, BAR_WIDTH)
    axes.set_xticks(bar_positions + BAR_WIDTH/2)
    axes.set_xticklabels(model_names)
    axes.set_title('Overall RMSE')
    figure.show()


def _compare_results_rmse(results, model_names=None):
    """
    Plot an RMSE comparison of several models using pyplot's interactive mode:
    one figure with two stacked histogram panels (per-user RMSE on top,
    per-item RMSE below, each with vertical lines at the per-model means) and
    a second figure with a bar chart of overall RMSE.

    Does nothing when `results` is empty.

    Parameters
    -------
    results : list
        Evaluation results, one per name in model_names. Each entry is indexed
        with 'rmse_by_user', 'rmse_by_item' and 'rmse_overall'.

    model_names : list
        List of model name strings.
    """

    model_count = len(results)
    if model_count < 1:
        return

    def _series_and_means(result_key):
        # Pull the 'rmse' column for every model, plus the mean of each column.
        columns = [res[result_key].to_dataframe()['rmse'] for res in results]
        return columns, [column.mean() for column in columns]

    pp.ion()
    pp.figure()

    # Top panel: per-user RMSE.
    pp.subplot(2, 1, 1)
    per_user, per_user_means = _series_and_means('rmse_by_user')
    __plot_histogram(per_user, per_user_means, model_names, 'Per User RMSE')

    # Bottom panel: per-item RMSE.
    per_item, per_item_means = _series_and_means('rmse_by_item')
    pp.subplot(2, 1, 2)
    __plot_histogram(per_item, per_item_means, model_names, 'Per Item RMSE')

    # pp.title applies to the current axes, i.e. the bottom panel.
    pp.title('Histograms of Per User and Per Item RMSE')

    # Second figure: overall RMSE, one bar per model.
    overall_scores = [res['rmse_overall'] for res in results]
    figure, axes = pp.subplots()
    BAR_WIDTH = 0.3
    bar_positions = np.arange(model_count) + BAR_WIDTH
    pp.bar(bar_positions, overall_scores, BAR_WIDTH)
    axes.set_xticks(bar_positions + BAR_WIDTH/2)
    axes.set_xticklabels(model_names)
    pp.title('Overall RMSE')

    pp.show()


def compare_models(dataset, models, model_names=None, user_sample=1.0, **kwargs):
    """
    Compare models with respect to a common validation set.
    Models that are trained to predict ratings are compared separately from
    models that are trained without target ratings.  The ratings prediction
    models are compared on RMSE, and the rest are compared on precision-recall.

    Every model is evaluated on the (optionally user-sampled) dataset. When
    matplotlib is available the results are also plotted; otherwise a warning
    is printed and only the results are returned.

    Parameters
    -------
    dataset : pandas.DataFrame/SFrame
        Validation dataset.

    models : list
        List of trained RecommenderModels. Must not be empty.

    model_names : list
        List of model name strings for display. Defaults to 'M0', 'M1', ...

    user_sample : double
        Sampling proportion of unique users to use in estimating model
        performance. Defaults to 1.0, i.e. use all users in the dataset.

    kwargs : dict, optional
        Additional keyword arguments forwarded to each model's `evaluate`.

    Returns
    -------
    out : list of pandas.DataFrame/SFrame:
        A list of results where each one is an sframe of evaluation results of
        the respective model on the given dataset

    Raises
    ------
    ValueError
        If `models` is empty, if `model_names` does not have one entry per
        model, or if sampling is requested and the first model has no
        'user_column' set.

    Examples
    --------
    If you have created two models ``m1`` and ``m2`` and have an :class:`~graphlab.SFrame` ``test_data``,
    then you may compare the performance of the two models on test data using:

    >>> graphlab.recommender.compare_models(test_data, [m1, m2])

    When evaluating recommender models you typically want to recommend items that a user has not previously viewed.
    In this case you can do

    >>> graphlab.recommender.compare_models(test_data, [m1, m2],
                                            skip_set=training_data)
    """

    num_models = len(models)

    if num_models < 1:
        raise ValueError("Must pass in at least one recommender model to evaluate")

    if model_names is None:
        model_names = ['M' + str(i) for i in range(num_models)]
    elif len(model_names) != num_models:
        raise ValueError("Must pass in the same number of model names as models")

    # if we are asked to sample the users, come up with a list of unique users
    if user_sample < 1.0:
        # The user column is taken from the first model only.
        user_column_name = models[0].get('user_column')
        if user_column_name is None:
            raise ValueError("user_column not set in model(s)")
        user_sa = dataset[user_column_name]
        unique_users = list(set(user_sa))
        ntake = int(round(user_sample * len(unique_users)))
        # Shuffle then take a prefix: a uniform sample without replacement.
        random.shuffle(unique_users)
        # Single-argument print(...) behaves identically as a Python 2 print
        # statement and as a Python 3 function call.
        print("compare_models: using " + str(ntake) +
              " users to estimate model performance")
        users = frozenset(unique_users[:ntake])
        # Keep only the rows whose user was sampled.
        ix = [u in users for u in user_sa]
        dataset_subset = dataset[_SArray(ix) == True]
    else:
        dataset_subset = dataset

    results = []
    for m in models:
        r = m.evaluate(dataset_subset, verbose=False,
                       cutoffs=list(range(2, 50, 2)), **kwargs)
        results.append(r)

    if has_pyplot:
        # A model with an empty 'target_column' was trained without ratings and
        # is compared on precision/recall; all others are compared on RMSE.
        is_pr = [m.get('target_column') == '' for m in models]
        results_pr = [results[i].to_dataframe()
                      for i in range(num_models) if is_pr[i]]
        results_rmse = [results[i] for i in range(num_models) if not is_pr[i]]
        model_names_pr = [model_names[i]
                          for i in range(num_models) if is_pr[i]]
        model_names_rmse = [model_names[i]
                            for i in range(num_models) if not is_pr[i]]

        if len(results_pr) > 0:
            _compare_results_precision_recall(results_pr, model_names_pr)
        if len(results_rmse) > 0:
            _compare_results_rmse2(results_rmse, model_names_rmse)
        pp.show()
    else:
        # Previously this was a bare string expression and was never shown.
        print("Warning: Matplotlib could not be imported - no plot output.")

    return results


def rmse(true_values, predicted_values):
    """
    Compute the root-mean-squared error between two lists (or Numpy arrays).

    Parameters
    -------
    true_values : list-like
        Observed values

    predicted_values : list-like
        Predicted values. Must have the same length as `true_values`.

    Returns
    -------
    out : double
        The root-mean-squared error.

    Raises
    ------
    ValueError
        If the two inputs do not have the same length.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would let numpy silently broadcast mismatched inputs.
    if len(true_values) != len(predicted_values):
        raise ValueError("true_values and predicted_values must have the "
                         "same length")

    # Convert up front: plain Python lists do not support element-wise `-`,
    # and float arrays avoid integer overflow/wrap-around when squaring.
    observed = np.asarray(true_values, dtype=float)
    predicted = np.asarray(predicted_values, dtype=float)
    return np.mean((observed - predicted) ** 2) ** 0.5
