"""
This module contains the K-Means++ clustering algorithm, including the
KmeansModel class which provides methods for inspecting the returned cluster
information.
"""

import graphlab.toolkits.main as main
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame
from graphlab.data_structures.sarray import SArray
from pandas import DataFrame, unique
import math

# matplotlib is an optional dependency: record whether it could be imported
# so that plotting code can check `_has_matplotlib` before using it.
try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib import rcParams
except Exception:
    # Kept broad on purpose so that any import-time failure still degrades
    # gracefully, but (unlike a bare `except:`) KeyboardInterrupt and
    # SystemExit are allowed to propagate.
    _has_matplotlib = False
else:
    _has_matplotlib = True


class KmeansModel(Model):
    """
    A model object containing a cluster id for each vertex, and the centers
    of the clusters.

    An instance of this model can be created using :func:`graphlab.kmeans.create`.
    Do NOT construct the model directly.
    """
    def __init__(self, input_params, model):
        """
        Wrap a trained model handle together with its hyperparameters.

        Parameters
        ----------
        input_params : dict
            Hyperparameters the model was created with. Reported by
            :meth:`summary` under the ``'hyperparams'`` key.

        model : object
            Underlying model handle; :meth:`get` delegates to its ``get``
            method.
        """
        self.__proxy__ = model
        self._input = input_params
        self._fields = ['clusterid', 'runtime', 'num_iterations', 'cluster_info']

    def get(self, field):
        """
        Return the value for the queried field.

        Parameters
        ----------
        field : str
            Name of the field to look up; see :meth:`list_fields`.

        Returns
        -------
        out
            Whatever the underlying model handle returns for ``field``.
        """
        return self.__proxy__.get(field)

    def list_fields(self):
        """
        Return the list of fields in the model.

        Returns
        -------
        out : list of str
            A new list on every call, so callers may modify the result
            without affecting the model.
        """
        # Hand out a copy: returning the internal list would let callers
        # mutate the model's own field registry by accident.
        return list(self._fields)

    def summary(self):
        """
        Return a summary of the model including hyperparameters, and runtime
        information.

        Returns
        -------
        out : dict
            Dictionary with the keys ``'hyperparams'`` (the parameters the
            model was created with), ``'runtime'`` and ``'num_iterations'``
            (both looked up through :meth:`get`).
        """
        ret = {'hyperparams': self._input}
        for key in ('runtime', 'num_iterations'):
            ret[key] = self.get(key)
        return ret

    def _get_wrapper(self):
        """
        Return a callable that wraps a model handle in a new KmeansModel
        carrying this instance's hyperparameters.
        """
        return lambda m: KmeansModel(self._input, m)


def create(data, num_clusters, max_iter=10, verbose=True, plot=False):
    """
    Run KMeans++ on the data, compute the cluster centers and the cluster
    assignment for each data point.

    Parameters
    ----------
    data : pandas.DataFrame or SFrame
        Each row is an observation.

    num_clusters : int
        Number of clusters (K).

    max_iter : int, optional
        The maximum number of iterations to run. Prints a warning if the
        algorithm does not converge after max_iter iterations.

    verbose : bool, optional
        If True (default), print model training progress to the screen.

    plot : bool, optional
        If True, display the progress plot.

    Returns
    -------
    out : KmeansModel
        A Model object containing a cluster id for each vertex, and the centers
        of the clusters.

    Raises
    ------
    TypeError
        If `data` is neither a pandas.DataFrame nor an SFrame.

    References
    ----------
    - Arthur, D. and Vassilvitskii, S. (2007) `k-means++: The Advantages of
      Careful Seeding <http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf>`_. In
      Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete
      Algorithms. pp. 1027-1035.
    """

    # The data is handed to the toolkit under a key naming its container
    # type; reject unsupported inputs before doing anything else.
    if isinstance(data, DataFrame):
        data_key = 'dataframe'
    elif isinstance(data, SFrame):
        data_key = 'sframe'
    else:
        raise TypeError("data must be a pandas.DataFrame or an SFrame")

    # The same hyperparameters are sent to the toolkit and stored on the
    # returned model, so build them once.
    hyperparams = {'num_clusters': num_clusters, 'max_iter': max_iter}
    opts = dict(hyperparams)
    opts[data_key] = data

    params = main.run('kmeans', opts, verbose, plot)
    return KmeansModel(hyperparams, params['model'])
