"""
This module contains the K-Means++ clustering algorithm, including the
KmeansModel class which provides methods for inspecting the returned cluster
information.
"""

import graphlab.toolkits.main as main
from graphlab.toolkits.model import Model
from graphlab.data_structures.sframe import SFrame
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS

# Optional plotting dependencies. `_has_matplotlib` records whether all of
# them could be imported; it stays False if any of the imports fails.
_has_matplotlib = False
try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib import rcParams
    _has_matplotlib = True
except Exception:
    # Best effort: any ordinary import-time failure (missing package, broken
    # backend, ...) just leaves the flag False. Unlike a bare `except:`, this
    # does not swallow KeyboardInterrupt or SystemExit.
    pass

# Candidate values for the 'num_clusters' hyperparameter: 2 through 20
# inclusive. NOTE(review): presumably consumed by a hyperparameter search
# elsewhere -- it is not referenced in this module's visible code.
DEFAULT_HYPER_PARAMETER_RANGE = dict(num_clusters=range(2, 21))


class KmeansModel(Model):
    """
    A k-means model object containing a cluster id for each vertex, and the
    centers of the clusters.

    An instance of this model can be created using
    :func:`graphlab.kmeans.create`. Do NOT construct the model directly.

    Given a number of clusters, k-means++ iteratively chooses the best cluster
    centers and assigns nearby points to the best cluster. If no points change
    cluster membership between iterations, the algorithm terminates. The
    GraphLab k-means toolkit returns a KmeansModel, which contains a
    **clusterid** and a **cluster_info** DataFrame. These objects contain the
    location of the cluster centers and other cluster statistics.
    """
    def __init__(self, input_params, model):
        """
        Wrap a backend k-means model proxy.

        Parameters
        ----------
        input_params : dict
            Hyperparameters associated with the model; reported by
            ``summary`` under the ``'hyperparams'`` key.

        model : object
            Backend model proxy. Every ``get`` call is forwarded to it.
        """
        self.__proxy__ = model
        self._input = input_params
        self._fields = ['clusterid', 'runtime', 'num_clusters',
                        'num_iterations', 'cluster_info']

    def get(self, field):
        """
        Return the value of a given field.

        The list of all queryable fields is detailed below, and can be obtained
        with the ``list_fields`` method.

        +-----------------------+----------------------------------------------+
        |      Field            | Description                                  |
        +=======================+==============================================+
        | clusterid             | Cluster assignment for each data point       |
        +-----------------------+----------------------------------------------+
        | runtime               | Total time taken to cluster the data         |
        +-----------------------+----------------------------------------------+
        | num_clusters          | Number of clusters to use                    |
        +-----------------------+----------------------------------------------+
        | num_iterations        | Total number of iterations performed         |
        +-----------------------+----------------------------------------------+
        | cluster_info          | Cluster centers                              |
        +-----------------------+----------------------------------------------+

        Parameters
        ----------
        field : str
            The name of the field to query.

        Returns
        -------
        out
            Value of the requested field

        See Also
        --------
        list_fields

        Examples
        --------

        >>> model.get("cluster_info")
        d1        d2        d3        d4  __within_distance__  __size__
        0 -0.777484  1.048897  0.523926  0.487775             2.459470         4
        1  0.844906 -0.613151 -0.088785 -0.212908             3.651614         5
        2 -1.114592 -1.129836 -1.651781 -0.886557             0.000000         1

        [3 rows x 6 columns]
        """
        # The lookup is delegated entirely to the backend proxy.
        return self.__proxy__.get(field)

    def list_fields(self):
        """
        List the fields stored in the model, including the number of iterations
        performed, total runtime for the clustering algorithm, and cluster data.

        Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list of str
            Names of the queryable fields.

        See Also
        --------
        get

        Examples
        --------

        >>> model.list_fields()
        ['clusterid', 'runtime', 'num_clusters', 'num_iterations', 'cluster_info']
        """
        # Return a copy so callers cannot mutate the model's own field list.
        return list(self._fields)

    def summary(self):
        """
        Display a summary of the model including hyperparameters, and runtime
        information.

        The summary is printed to standard output; nothing is returned.

        See Also
        --------
        get, list_fields

        Examples
        --------

        >>> model.summary()
        {'hyperparams': {'max_iterations': 10, 'num_clusters': 3},
        'num_iterations': 1,
        'runtime': 0.024857}

        Use the ``get`` method to retrieve these values programmatically, or to
        see more detail about the queryable fields.
        """
        ret = {'hyperparams': self._input}
        for key in ['runtime', 'num_iterations']:
            ret[key] = self.get(key)
        # Parenthesised form prints identically under the Python 2 print
        # statement and is also valid as a Python 3 function call.
        print(ret)

    def _get_wrapper(self):
        """
        Return a callable that wraps a backend model proxy in a new
        KmeansModel carrying this model's input parameters.

        NOTE(review): presumably used by the base ``Model`` machinery (e.g.
        when reloading a model) -- confirm against the base class.
        """
        return lambda m: KmeansModel(self._input, m)

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            A description of the KMeansModel.
        """
        return self.__repr__()

    def __repr__(self):
        """
        Return a string description of the model, shown when the model name is
        entered in the terminal.

        Returns
        -------
        out : string
            One "label: value" line for the class name and for each model
            field. Floats are rounded to 4 decimal places and pandas
            DataFrames are summarized by their shape.
        """

        width = 24
        key_str = "{:<{}}: {}"

        model_fields = [
            ('Cluster assignments', 'clusterid'),
            ('Cluster centers', 'cluster_info'),
            ('Total runtime (seconds)', 'runtime'),
            ('Number of clusters', 'num_clusters'),
            ('Number of iterations', 'num_iterations')]

        out = []
        out.append(key_str.format("Class", width, self.__class__.__name__))

        for k, v in model_fields:
            value = self.get(v)
            if isinstance(value, float):
                try:
                    value = round(value, 4)
                except (TypeError, ValueError, OverflowError):
                    # Best effort: show the unrounded value if it cannot be
                    # rounded.
                    pass
            # Only touch the pandas module when it is actually available,
            # mirroring the guard used by ``create``.
            if _HAS_PANDAS and isinstance(value, _pandas.DataFrame):
                value = "DataFrame with shape %s" % str(value.shape)
            out.append(key_str.format(k, width, value))

        return '\n'.join(out)


def create(dataset, num_clusters, max_iterations=10, verbose=True, plot=False):
    """
    Run k-means++, computing the cluster centers and the cluster assignment for
    each data point in the dataset.

    Parameters
    ----------
    dataset : SFrame or pandas.DataFrame
        Each row is an observation. A pandas DataFrame is only accepted when
        pandas is available.

    num_clusters : int
        Number of clusters (K).

    max_iterations : int, optional
        The maximum number of iterations to run. Prints a warning if the
        algorithm does not converge after max_iterations iterations.

    verbose : bool, optional
        If True (default), print model training progress to the screen.

    plot : bool, optional
        If True, display the progress plot.

    Returns
    -------
    out : KmeansModel
        A Model object containing a cluster id for each vertex, and the centers
        of the clusters.

    Raises
    ------
    TypeError
        If ``dataset`` is neither an SFrame nor a pandas DataFrame.

    References
    ----------
    - `Wikipedia - k-means clustering
      <http://en.wikipedia.org/wiki/K-means_clustering>`_
    - Arthur, D. and Vassilvitskii, S. (2007) `k-means++: The Advantages of
      Careful Seeding <http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf>`_. In
      Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete
      Algorithms. pp. 1027-1035.

    Examples
    --------

    >>> sf = graphlab.SFrame({
    ...     "d1": [ 0.46973508, 0.0063261, 0.14143399, 0.35025834,
    ...             0.83728709, 0.81438336, 0.74205833, 0.36273747,
    ...             0.00793858, 0.02298716],
    ...     "d2": [ 0.51050977, 0.82167952, 0.61451765, 0.51179513,
    ...             0.35223035, 0.59366481, 0.48848649, 0.90026032,
    ...             0.78798728, 0.40125452],
    ...     "d3": [ 0.71716265, 0.54163387, 0.55577274, 0.12619953,
    ...             0.80172228, 0.21519973, 0.21014113, 0.54207596,
    ...             0.65432528, 0.00754797],
    ...     "d4": [ 0.69486673, 0.92585721, 0.95461882, 0.72658554,
    ...             0.86590678, 0.18017175, 0.60361348, 0.89223113,
    ...             0.37992791, 0.44700959]
    ...     })

    It's important to standardize our columns to get the best results
    possible from the k-means algorithm.

    >>> for col in ['d1', 'd2', 'd3', 'd4']:
    ...     sf[col] = (sf[col] - sf[col].mean()) / sf[col].std()
    >>> model = graphlab.kmeans.create(sf, num_clusters=3)
    """

    hyperparams = {'num_clusters': num_clusters,
                   'max_iterations': max_iterations}

    # The toolkit options get their own copy so that the dict stored on the
    # returned model holds only the hyperparameters, not the dataset entry.
    opts = dict(hyperparams)
    if _HAS_PANDAS and isinstance(dataset, _pandas.DataFrame):
        opts['dataframe'] = dataset
    elif isinstance(dataset, SFrame):
        opts['sframe'] = dataset
    else:
        raise TypeError("dataset must be an SFrame or a pandas DataFrame")

    params = main.run('kmeans', opts, verbose, plot)

    return KmeansModel(hyperparams, params['model'])
