from graphlab.data_structures.sgraph import SGraph as _SGraph
import graphlab.toolkits.main as _main
from graphlab.toolkits.graph_analytics.model_base import GraphAnalyticsModel as _ModelBase
import copy

_HAS_IPYTHON = True
try:
    import IPython.core.display
except:
    _HAS_IPYTHON = False


def get_default_options():
    """
    Return the default options for :func:`graphlab.shortest_path.create`.

    Returns
    -------
    out : dict
        Default option values, as reported by the ``sssp_default_options``
        toolkit call.

    Examples
    --------
    >>> graphlab.shortest_path.get_default_options()
    """
    toolkit_name = 'sssp_default_options'
    no_options = {}
    return _main.run(toolkit_name, no_options)


class ShortestPathModel(_ModelBase):
    """
    Model object containing the distance for each vertex in the graph to a
    single source vertex, which is specified during
    :func:`graphlab.shortest_path.create`.

    The model also allows querying for one of the shortest paths from the source
    vertex to any other vertex in the graph.

    Below is a list of queryable fields for this model:

    +----------------+------------------------------------------------------------+
    | Field          | Description                                                |
    +================+============================================================+
    | graph          | SGraph with additional shortest-path information.          |
    +----------------+------------------------------------------------------------+
    | distance       | SFrame with each vertex's distance to the source vertex.   |
    +----------------+------------------------------------------------------------+
    | source_vid     | ID of the source vertex.                                   |
    +----------------+------------------------------------------------------------+
    | weight_field   | Edge field holding edge weights; empty means unit weight.  |
    +----------------+------------------------------------------------------------+
    | max_distance   | The ``max_distance`` option passed to create.              |
    +----------------+------------------------------------------------------------+
    | runtime        | Total runtime of the toolkit                               |
    +----------------+------------------------------------------------------------+

    An instance of this model can be created using
    :func:`graphlab.shortest_path.create`.  Do NOT construct the model
    directly.

    See Also
    --------
    create
    """

    # Destinations whose distance exceeds this bound are rejected by get_path.
    # NOTE(review): presumably meant to catch unreachable vertices (whose
    # distance is a huge sentinel); it also rejects legitimately long weighted
    # paths -- confirm the intended bound.
    _MAX_PATH_DISTANCE = 1e5

    def __init__(self, model):
        '''__init__(self)'''
        # Handle to the underlying toolkit model object.
        self.__proxy__ = model
        # Lookup table for get_path, built lazily by _generate_path_sframe.
        self._path_query_table = None

    def _result_fields(self):
        ret = super(ShortestPathModel, self)._result_fields()
        ret['distance'] = "SFrame with each vertex's distance. See m['distance']"
        return ret

    def _setting_fields(self):
        ret = super(ShortestPathModel, self)._setting_fields()
        for k in ['source_vid', 'weight_field', 'max_distance']:
            ret[k] = self[k]
        return ret

    def _method_fields(self):
        return {'get_path': 'Get shortest path, e.g. m.get_path(vid=target_vid)'}

    def get_path(self, vid, show=False, highlight=None, **kwargs):
        """
        Get the shortest path.
        Return one of the shortest paths between the source vertex defined
        in the model and the query vertex.
        The source vertex is specified by the original call to shortest path.
        Optionally, plots the path with :func:`graphlab.SGraph.show`.

        Parameters
        ----------
        vid : vertex ID
            ID of the destination vertex. The source vertex ID is specified
            when the shortest path result is first computed.

        show : boolean
            Indicates whether the path should be plotted. Default is False.
            Nothing is plotted when the path consists of a single vertex.

        highlight : list
            If the path is plotted, identifies the vertices (by vertex ID) that
            should be highlighted by plotting in a different color. Only
            vertices that lie on the path are highlighted.

        kwargs :
            Additional parameters passed into the :func:`graphlab.SGraph.show`
            when `show` is True.

        Returns
        -------
        path : list
            List of pairs of (vertex_id, distance) in the path, ordered from
            the source vertex to `vid`.

        Raises
        ------
        ValueError
            If `vid` is not found, or its distance is too large to show a path.
        TypeError
            If the path is plotted and `highlight` is given but is not a list.
        RuntimeError
            If the parent pointers do not lead back to the source vertex.

        Examples
        --------
        >>> m.get_path(vid=0, show=True)

        See Also
        --------
        SGraph.show
        """
        if self._path_query_table is None:
            self._path_query_table = self._generate_path_sframe()
        table = self._path_query_table
        source_vid = self['source_vid']

        # Filter once and test for emptiness rather than using ``vid in
        # SArray``, whose semantics are element-wise; this also scans once.
        matches = table[table['vid'] == vid]
        if len(matches) == 0:
            raise ValueError('Destination vertex id ' + str(vid) + ' not found')
        record = matches[0]

        dist = record['distance']
        if dist > self._MAX_PATH_DISTANCE:
            raise ValueError('The distance to {} is too large to show the path.'.format(vid))

        path = [(vid, dist)]
        num_rows = len(table)
        num_steps = 0
        # Walk parent pointers back to the source. Stopping on the source's id
        # (rather than on distance == 0) keeps zero-weight edges from ending
        # the walk early; the step bound guards against parent cycles.
        while record['vid'] != source_vid and num_steps < num_rows:
            parent_id = record['parent_row_id']
            # Table rows are sorted by row_id, so a valid parent id is a row index.
            if not 0 <= parent_id < num_rows:
                raise RuntimeError('Invalid parent row id {} while tracing the path to {}.'
                                   .format(parent_id, vid))
            record = table[parent_id]
            path.append((record['vid'], record['distance']))
            num_steps += 1
        if record['vid'] != source_vid:
            raise RuntimeError('Could not trace a path from the source vertex to {}.'.format(vid))
        path.reverse()

        if show and len(path) > 1:
            # get_edges takes parallel id arrays: pair k selects the edges of
            # hop k, i.e. path[k] -> path[k + 1].
            src_ids = [v for v, _ in path[:-1]]
            dst_ids = [v for v, _ in path[1:]]
            path_edges = self['graph'].get_edges(src_ids=src_ids, dst_ids=dst_ids)
            sub_g = _SGraph().add_edges(path_edges,
                                        src_field='__src_id',
                                        dst_field='__dst_id')

            path_highlight = []
            if highlight is not None:
                if not isinstance(highlight, list):
                    raise TypeError("Input 'highlight' must be a list.")
                path_vids = set(v for v, _ in path)
                path_highlight = list(path_vids.intersection(highlight))

            plot = sub_g.show(vlabel='id', highlight=path_highlight, **kwargs)
            if _HAS_IPYTHON:
                IPython.core.display.display(plot)

        return path

    def _generate_path_sframe(self):
        """
        Build the lookup table used by :meth:`get_path`.

        Returns an SFrame with columns ``vid``, ``distance``, ``row_id`` and
        ``parent_row_id``, sorted by ``row_id`` so that row ``k`` has
        ``row_id == k``. ``parent_row_id`` is the row of a predecessor of the
        vertex on some shortest path from the source, or -1 if none was found.
        """
        source_vid = self['source_vid']
        weight_field = self['weight_field']

        query_table = copy.copy(self.get('distance'))
        # add_row_number returns a new SFrame; the result must be kept.
        query_table = query_table.add_row_number('row_id')

        # Attach distance and row_id to each vertex of the graph.
        g = self.get('graph').add_vertices(query_table)
        # Row id of the chosen parent; -1 means no parent has been found.
        g.vertices['__parent__'] = -1
        if not weight_field:
            # Unweighted run: every edge counts as distance 1.
            weight_field = '__unit_weight__'
            g.edges[weight_field] = 1

        # Single pass over every edge: src is a valid parent of dst when the
        # edge lies on a shortest path, i.e. dist(dst) == dist(src) + w(edge).
        # Ties are broken by keeping the largest row id.
        def traverse_fun(src, edge, dst):
            if src['__id'] == source_vid:
                # Mark the source as its own parent.
                src['__parent__'] = src['row_id']
            if dst['distance'] == src['distance'] + edge[weight_field]:
                dst['__parent__'] = max(dst['__parent__'], src['row_id'])
            return (src, edge, dst)

        g = g.triple_apply(traverse_fun, ['__parent__'])
        query_table = query_table.join(g.get_vertices()[['__id', '__parent__']], '__id').sort('row_id')
        query_table.rename({'__parent__': 'parent_row_id', '__id': 'vid'})
        return query_table


def create(graph, source_vid, weight_field="", max_distance=1e30, verbose=True, plot=False):
    """
    Compute the single source shortest path distance from the source vertex to
    all vertices in the graph. Note that because SGraph is directed, shortest
    paths are also directed. To find undirected shortest paths, add edges to
    the SGraph in both directions. Return a model object with the distance of
    each vertex in the graph.

    Parameters
    ----------
    graph : SGraph
        The graph on which to compute shortest paths.

    source_vid : vertex ID
        ID of the source vertex.

    weight_field : string, optional
        The edge field representing the edge weights. If empty, uses unit
        weights.

    max_distance : float, optional
        Distance bound forwarded unchanged to the underlying ``sssp`` toolkit
        as its ``max_distance`` option. Defaults to 1e30.

    verbose : bool, optional
        If True, print progress updates.

    plot : bool, optional
        If True, display the progress plot. Plotting is not implemented yet:
        currently only a notice is printed (the flag is still forwarded to
        the toolkit).

    Returns
    -------
    out : ShortestPathModel

    Raises
    ------
    TypeError
        If `graph` is not an SGraph.

    References
    ----------
    - `Wikipedia - ShortestPath <http://en.wikipedia.org/wiki/Shortest_path_problem>`_

    Examples
    --------
    If given an :class:`~graphlab.SGraph` ``g``, we can create
    a :class:`~graphlab.shortest_path.ShortestPathModel` as follows:

    >>> g = graphlab.load_graph('http://snap.stanford.edu/data/email-Enron.txt.gz', format='snap')
    >>> sp = graphlab.shortest_path.create(g, source_vid=1)

    We can obtain the shortest path distance from the source vertex to each
    vertex in the graph ``g`` as follows:

    >>> sp_sframe = sp['distance']   # SFrame

    To get the actual path from the source vertex to any destination vertex:

    >>> path = sp.get_path(vid=10)

    We can obtain an auxiliary graph with additional information corresponding
    to the shortest path from the source vertex to each vertex in the graph
    ``g`` as follows:

    >>> sp_graph = sp['graph']       # SGraph

    See Also
    --------
    ShortestPathModel
    """
    if not isinstance(graph, _SGraph):
        raise TypeError('graph input must be a SGraph object.')

    if plot is True:
        # Progress plotting is not implemented for this toolkit yet. The
        # parenthesized form prints identically on Python 2 and Python 3.
        print("The plot functionality for shortest path is not yet implemented.")

    opts = {'source_vid': source_vid,
            'weight_field': weight_field,
            'max_distance': max_distance,
            'graph': graph.__proxy__}
    params = _main.run('sssp', opts, verbose, plot)
    return ShortestPathModel(params['model'])
