
from graphlab.data_structures.sgraph import SGraph
import graphlab.toolkits.main as main
from graphlab.toolkits.model import Model


class GraphColoringModel(Model):
    """
    A Model object containing a color ID for each vertex and the total number
    of colors used.

    An instance of this model can be created using :func:`graphlab.graph_coloring.create`.
    Do NOT construct the model directly.
    """
    def __init__(self, input_params, model):
        '''__init__(self)'''
        # Handle to the underlying toolkit model; ``get`` forwards field
        # lookups to it.
        self.__proxy__ = model
        # Parameters the model was created with; ``summary`` reports them
        # under the 'hyperparams' key.
        self._input = input_params
        # Field names advertised by ``list_fields``.
        self._fields = ['num_colors', 'runtime', 'colorid', 'graph']

    def list_fields(self):
        """
        List of fields stored in the model. Each of these fields can be queried
        using the ``get`` function.

        Returns
        -------
        out : list
            A list of fields that can be queried using the ``get`` method.
            A new list is returned on every call, so modifying it does not
            affect the model.
        """
        # Return a copy so callers cannot mutate the model's internal list.
        return list(self._fields)

    def _get_graph(self):
        """Wrap the underlying model's '__graph__' entry in an SGraph."""
        return SGraph(_proxy=self.__proxy__.get('__graph__'))

    def get(self, field):
        """
        Get the value of a given field. The list of all queryable fields is
        detailed below, and can be obtained programmatically using the
        :func:`~graphlab.graph_analytics.GraphColoringModel.list_fields`
        method.

        Each of these fields can be queried in one of two ways:

        >>> out = m['field']
        >>> out = m.get('field')  # equivalent to previous line

        +-----------------+----------------------------------------------------+
        |      Field      | Description                                        |
        +=================+====================================================+
        | colorid         | SFrame with the color ID of each vertex            |
        +-----------------+----------------------------------------------------+
        | graph           | Input graph with color ID as a vertex property     |
        +-----------------+----------------------------------------------------+
        | num_colors      | Total number of colors used                        |
        +-----------------+----------------------------------------------------+
        | runtime         | Runtime of the graph coloring computation          |
        +-----------------+----------------------------------------------------+

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out : [various]
            The current value of the requested field.
        """
        if field == 'colorid':
            # Keep only the 'colorid' vertex property, then return the
            # vertex data.
            return self._get_graph().select_fields(['colorid']).get_vertices()
        elif field == 'graph':
            return self._get_graph()
        else:
            # Every other field is read directly from the underlying model.
            return self.__proxy__.get(field)

    def summary(self):
        """Return a summary of the model including hyperparameters, and runtime information."""

        ret = {'hyperparams': self._input}
        for key in ['runtime', 'num_colors']:
            ret[key] = self.get(key)
        return ret

    def _get_wrapper(self):
        # Returns a callable that wraps a raw toolkit model in a
        # GraphColoringModel, reusing this model's input parameters.
        return lambda m: GraphColoringModel(self._input, m)


def create(graph, verbose=True, plot=False):
    """
    Compute the graph coloring. Assign a color to each vertex such that no
    adjacent vertices have the same color. Return a model object with total
    number of colors used as well as the color ID for each vertex in the graph.
    This algorithm is greedy and is not guaranteed to find the **minimum** graph
    coloring. It is also not deterministic, so successive runs may return
    different answers.

    Parameters
    ----------
    graph : SGraph
        The graph on which to compute the coloring.

    verbose : bool, optional
        If True, print progress updates.

    plot : bool, optional
        If True, display the progress plot.

    Returns
    -------
    out : GraphColoringModel
        Model object that contains the color id for each vertex.

    Raises
    ------
    TypeError
        If ``graph`` is not an SGraph.

    References
    ----------
    - `Wikipedia - graph coloring <http://en.wikipedia.org/wiki/Graph_coloring>`_

    Examples
    --------
    If given an :class:`~graphlab.SGraph` ``g``, we can create
    a :class:`~graphlab.graph_coloring.GraphColoringModel` as follows:

    >>> gc = graph_coloring.create(g)

    We can obtain the ``color id`` corresponding to each vertex in the graph ``g``
    as follows:

    >>> color_id = gc['colorid']  # SFrame

    We can obtain the total number of colors required to color the graph ``g``
    as follows:

    >>> num_colors = gc['num_colors']
    """
    if not isinstance(graph, SGraph):
        raise TypeError('graph input must be a SGraph object.')

    if plot:
        # Plotting is not implemented for this toolkit; the flag is still
        # forwarded to the runner below. The parenthesized single-argument
        # form prints identically under Python 2 and Python 3.
        print("The plot functionality for graph coloring is not yet implemented.")

    params = main.run('graph_coloring', {'graph': graph.__proxy__}, verbose,
                      plot)

    # The toolkit takes no user-facing hyperparameters, so the model records
    # an empty parameter dict.
    return GraphColoringModel({}, params['model'])
