
from graphlab.data_structures.graph import Graph
import graphlab.toolkits.main as main
from graphlab.toolkits.model import Model
from collections import deque


class ShortestPathModel(Model):
    """
    A Model object containing the distance from the source vertex to each
    vertex in the graph.

    The model wraps an engine-side proxy object (``self.__proxy__``) together
    with the hyperparameters it was built with (``self._input``; when created
    by :func:`create` this has the keys ``'source_vid'`` and ``'edge_attr'``).
    """
    def __init__(self, input_params, model):
        self.__proxy__ = model
        self._input = input_params
        self._fields = ['runtime', 'distance', 'graph']

    def list_fields(self):
        """Return the list of fields in the model."""

        return self._fields

    def get(self, field):
        """
        Return the value for the queried field.

        ``'graph'`` returns a Graph built from the proxy's ``'__graph__'``
        entry; ``'distance'`` returns that graph's vertex data restricted to
        the ``'distance'`` field; any other field is forwarded to the proxy.
        """
        if field == 'distance':
            return self.get('graph').select_fields(['distance']).get_vertices()
        elif field == 'graph':
            return Graph(_proxy=self.__proxy__.get('__graph__'))
        return self.__proxy__.get(field)

    def summary(self):
        """Return a summary of the model including hyperparameters, and runtime information."""

        ret = {'hyperparams': self._input}
        for key in ('runtime', 'distance'):
            ret[key] = self.get(key)
        return ret

    def get_path(self, vid, show=False, highlight=None, **kwargs):
        """
        Return one of the shortest paths between two vertices in a graph. The
        source vertex is specified by the original call to shortest path, and
        the destination vertex is specified in this function. Optionally, plots
        the path.

        Parameters
        ----------
        vid : vertex ID
            ID of the destination vertex. The source vertex ID is specified when the
            shortest path result is first computed.

        show : boolean
            Indicates whether the path should be plotted. Default is False.

        highlight : list
            If the path is plotted, identifies the vertices (by vertex ID) that
            should be highlighted by plotting in a different color.

        **kwargs
            Forwarded to ``Graph.show`` when ``show`` is True.

        Returns
        -------
        path : list
            ``(vertex_id, distance)`` tuples ordered from the source vertex to
            ``vid``.

        Raises
        ------
        ValueError
            If ``vid`` is not a vertex of the graph, or its distance exceeds
            1e5 (presumably the engine's marker for "unreachable" -- TODO
            confirm).
        RuntimeError
            If no in-neighbor consistent with the stored distances is found
            while walking back toward the source.
        """
        distance = self.get('distance').to_dataframe().set_index('__id')
        g = self.get('graph')
        if vid not in distance.index:
            # format() so that non-string (e.g. integer) ids do not raise TypeError.
            raise ValueError('Vertex id {} not found'.format(vid))

        dist = distance.loc[vid, 'distance']
        if dist > 1e5:
            raise ValueError('The distance to {} is too large to show the path.'.format(vid))

        weight_field = self._input['edge_attr']
        path = deque()
        current_id = vid

        # Walk backwards from `vid`, at each step moving to an in-neighbor that
        # lies on a shortest path, until the source (distance 0) is reached.
        while dist > 0:
            path.appendleft((current_id, dist))
            edges = g.get_edges(dst_ids=[current_id]).to_dataframe()
            src_ids = edges['__src_id'].values
            src_dist = distance.loc[src_ids, 'distance'].values

            if weight_field == "":
                # Unit weights: an in-neighbor strictly closer to the source is
                # exactly one hop closer, hence a valid predecessor.
                on_path = src_dist < dist
            else:
                # Weighted: the predecessor's distance plus the edge cost must
                # reproduce the current distance (exact comparison).
                on_path = src_dist + edges[weight_field].values == dist

            candidates = src_ids[on_path]
            if len(candidates) == 0:
                raise RuntimeError(
                    'Could not find a predecessor of {} on a shortest path.'.format(current_id))
            current_id = candidates[0]
            # Look the distance up in the uniquely-indexed vertex table; a
            # per-edge lookup would return several rows for duplicate edges.
            dist = distance.loc[current_id, 'distance']

        path.appendleft((current_id, dist))
        path = list(path)

        if show is True and len(path) > 1:
            sub_g = Graph()
            # Path entries are (vertex_id, distance) pairs; only the ids
            # identify the edges to copy into the plotted subgraph.
            for (src_id, _), (dst_id, _) in zip(path, path[1:]):
                sub_g = sub_g.add_edges(g.get_edges(src_ids=[src_id], dst_ids=[dst_id]),
                                        src_field='__src_id',
                                        dst_field='__dst_id')

            path_highlight = []
            if highlight is not None:
                path_names = set(x[0] for x in path)
                path_highlight = list(path_names.intersection(highlight))

            sub_g.show(vlabel='id', highlight=path_highlight, **kwargs)

        return path

    def _get_wrapper(self):
        # Factory that re-wraps a raw engine model proxy in this class, reusing
        # this instance's hyperparameters.
        return lambda m: ShortestPathModel(self._input, m)


def create(graph, source_vid, weight_field="", verbose=False, plot=False):
    """
    Compute the single-source shortest-path distance from the source vertex to
    every vertex in the graph. Because a Graph is directed, shortest paths are
    directed too; to find undirected shortest paths, add each edge to the
    Graph in both directions. Returns a model object holding the distance of
    each vertex in the graph.

    Parameters
    ----------
    graph : Graph
        The graph on which to compute shortest paths.

    source_vid : vertex ID
        ID of the source vertex.

    weight_field : string, optional
        The edge field representing the edge weights. If empty, uses unit
        weights.

    verbose : bool, optional
        If True, print a message when the computation starts. The engine
        itself is then run with verbosity 0.

    plot : bool, optional
        Plotting is not yet implemented. If True, a notice is printed and the
        computation proceeds without a plot.

    Returns
    -------
    out : Model
        A Model object that contains the shortest path distance for each vertex
        to the source vertex.

    Raises
    ------
    TypeError
        If ``graph`` is not a Graph.

    Examples
    --------
    If given an :class:`~graphlab.Graph` ``g``, we can create
    a :class:`~graphlab.shortest_path.ShortestPathModel` as follows:

    >>> sp = shortest_path.create(g, source_vid=123)

    We can obtain the shortest path from the source vertex to each vertex in
    the graph ``g`` as follows:

    >>> sp_sframe = sp.get('distance')   # SFrame

    We can obtain an auxiliary graph with additional information corresponding
    to the shortest path from the source vertex to each vertex in the graph
    ``g`` as follows:

    >>> sp_graph = sp.get('graph')       # Graphlab graph
    """
    if not isinstance(graph, Graph):
        raise TypeError('graph input must be a Graph object.')

    # Single-argument print(...) behaves identically on Python 2 and 3.
    if plot is True:
        print("The plot functionality for shortest path is not yet implemented.")
        plot = False

    if verbose is True:
        print("Starting shortest path computation.")
        # NOTE(review): the engine is then invoked with verbosity 0 --
        # presumably to suppress its own progress output; confirm intent.
        verbose = 0

    hyperparams = {'source_vid': source_vid, 'edge_attr': weight_field}
    # Engine options are the hyperparameters plus the graph's proxy; build a
    # copy so the dict stored on the model never carries the proxy.
    opts = dict(hyperparams, graph=graph.__proxy__)
    params = main.run('sssp', opts, verbose, plot)
    return ShortestPathModel(hyperparams, params['model'])
