from graphlab.data_structures.sgraph import SGraph as _SGraph
import graphlab.aggregate as _Aggregate
import graphlab.toolkits.main as main
from graphlab.toolkits.model import Model as _Model
import copy

# IPython is an optional dependency, used only to display plots inline.
# Record whether it could be imported so callers can skip the display step.
# `except Exception` keeps this best-effort (any import-time failure simply
# disables the feature) without swallowing KeyboardInterrupt / SystemExit the
# way a bare `except:` would.
try:
    import IPython.core.display
    _HAS_IPYTHON = True
except Exception:
    _HAS_IPYTHON = False


class ShortestPathModel(_Model):
    """
    A model object containing the distance for each vertex in the graph to the
    source vertex.

    An instance of this model can be created using :func:`graphlab.shortest_path.create`.
    Do NOT construct the model directly.
    """
    def __init__(self, input_params, model):
        """
        Wrap the underlying model proxy together with the input parameters.

        Parameters
        ----------
        input_params : dict
            Parameters the model was created with. The methods of this class
            read the keys ``'source_vid'`` and ``'edge_attr'`` from it.

        model : object
            Underlying model proxy; fields are read from it through its
            ``get`` method.
        """
        self.__proxy__ = model
        self._input = input_params
        self._fields = ['runtime', 'distance', 'graph']
        # Lazily built lookup table used by get_path (see
        # _generate_path_sframe); None until the first path query.
        self._path_query_table = None

    def list_fields(self):
        """
        List of fields stored in the model. Each of these fields can be queried
        using the ``get`` function.

        Returns
        -------
        out : list
            A list of fields that can be queried using the ``get`` method.
        """

        return self._fields

    def get(self, field):
        """
        Get the value of a given field. The list of all queryable fields is
        detailed below, and can be obtained programmatically using the
        :func:`~graphlab.graph_analytics.ShortestPathModel.list_fields`
        method.

        Each of these fields can be queried in one of two ways:

        >>> out = m['field']
        >>> out = m.get('field')  # equivalent to previous line

        +-----------+----------------------------------------------------------+
        |   Field   | Description                                              |
        +===========+==========================================================+
        | runtime   | Runtime information reported by the underlying model     |
        +-----------+----------------------------------------------------------+
        | distance  | SFrame with each vertex's distance to the queried vertex |
        +-----------+----------------------------------------------------------+
        | graph     | Input graph with the distance as a vertex property       |
        +-----------+----------------------------------------------------------+

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out : [various]
            The current value of the requested field.

        """

        if (field == 'distance'):
            # Keep only the 'distance' vertex field and return the vertices.
            return _SGraph(_proxy=self.__proxy__.get('__graph__')).select_fields(['distance']).get_vertices()
        elif (field == 'graph'):
            return _SGraph(_proxy=self.__proxy__.get('__graph__'))
        else:
            # Any other field is forwarded to the underlying model proxy.
            return self.__proxy__.get(field)

    def summary(self):
        """
        Return a summary of the model including hyperparameters, and runtime
        information.

        Returns
        -------
        out : dict
            Dictionary with the keys ``'hyperparams'`` (the input parameters
            the model was created with), ``'runtime'`` and ``'distance'``.
        """

        ret = {'hyperparams': self._input}
        for key in ['runtime', 'distance']:
            ret[key] = self.get(key)
        return ret

    def get_path(self, vid, show=False, highlight=None, **kwargs):
        """
        Return one of the shortest paths between two vertices in a graph. The
        source vertex is specified by the original call to shortest path, and
        the destination vertex is specified in this function. Optionally, plots
        the path by calling ``show`` on the sub-graph made of the path's edges.

        Parameters
        ----------
        vid : string
            ID of the destination vertex. The source vertex ID is specified when the
            shortest path result is first computed.

        show : boolean
            Indicates whether the path should be plotted. Default is False.
            Nothing is plotted when the path consists of a single vertex.

        highlight : list
            If the path is plotted, identifies the vertices (by vertex ID) that
            should be highlighted by plotting in a different color. Only the
            IDs that are actually on the path are highlighted.

        **kwargs
            Extra keyword arguments forwarded to the ``show`` call when the
            path is plotted.

        Returns
        -------
        path : list
            Containing pairs of (vertex_id, distance) on the path between the
            source and destination vertices, ordered from source to
            destination.

        Raises
        ------
        ValueError
            If ``vid`` is not a vertex of the graph, or if its distance to the
            source is greater than 1e5.

        TypeError
            If the path is plotted and ``highlight`` is given but is not a
            list.
        """
        # Build the lookup table on first use and cache it for later queries.
        if self._path_query_table is None:
            self._path_query_table = self._generate_path_sframe()

        source_vid = self._input['source_vid']
        path_query_table = self._path_query_table
        if vid not in path_query_table['vid']:
            raise ValueError('Vertex id ' + str(vid) + ' not found')

        record = path_query_table[path_query_table['vid'] == vid][0]
        dist = record['distance']
        if dist > 1e5:
            raise ValueError('The distance to {} is too large to show the path.'.format(vid))

        # Walk parent pointers from the destination back to the source (the
        # vertex at distance 0). The iteration cap guards against cycles in
        # the parent pointers.
        path = [(vid, dist)]
        max_iter = len(path_query_table)
        num_iter = 0
        while record['distance'] != 0 and num_iter < max_iter:
            parent_id = record['parent_row_id']
            assert parent_id < len(path_query_table)
            assert parent_id >= 0
            record = path_query_table[parent_id]
            path.append((record['vid'], record['distance']))
            num_iter += 1
        assert record['vid'] == source_vid
        assert num_iter < max_iter
        path.reverse()

        if show is True and len(path) > 1:
            graph = self['graph']
            sub_g = _SGraph()
            # Each path entry is a (vertex_id, distance) pair; only the vertex
            # IDs identify the edge between consecutive path vertices.
            for src, dst in zip(path, path[1:]):
                sub_g = sub_g.add_edges(graph.get_edges(src_ids=[src[0]], dst_ids=[dst[0]]),
                                        src_field='__src_id',
                                        dst_field='__dst_id')

            path_highlight = []

            if highlight is not None:

                if not isinstance(highlight, list):
                    raise TypeError("Input 'highlight' must be a list.")

                path_names = set(x[0] for x in path)
                path_highlight = list(path_names & set(highlight))

            # Plot whether or not a highlight list was supplied.
            plot = sub_g.show(vlabel='id', highlight=path_highlight, **kwargs)
            if _HAS_IPYTHON:
                IPython.core.display.display(plot)

        return path

    def _get_wrapper(self):
        """
        Return a callable that wraps a model proxy in a ShortestPathModel
        sharing this instance's input parameters.
        """
        return lambda m: ShortestPathModel(self._input, m)

    def _generate_path_sframe(self):
        """
        Generates an sframe with columns: vid, parent_row_id, and distance.
        Used for speed up the path query.

        Returns
        -------
        out : SFrame
            One row per vertex, sorted by row id, so that a vertex's
            ``parent_row_id`` can be used directly as a row index into the
            table.
        """
        source_vid = self._input['source_vid']
        weight_field = self._input['edge_attr']

        query_table = copy.copy(self.get('distance'))
        query_table.add_row_number('row_id')

        g = self.get('graph').add_vertices(query_table)
        # Row id of each vertex's parent on a shortest path, initialized with
        # -1 meaning no parent has been found yet.
        g.vertices['__parent__'] = -1
        if (weight_field == ""):
            # No weight field given: fall back to a unit weight on every edge.
            weight_field = '__unit_weight__'
            g.edges[weight_field] = 1

        # Traverse the graph once and get the parent row id for each vertex
        def traverse_fun(src, edge, dst):
            # The source vertex is its own parent.
            if src['__id'] == source_vid:
                src['__parent__'] = src['row_id']
            # The edge lies on a shortest path when it accounts exactly for
            # the difference in distance; ties keep the largest row id.
            if dst['distance'] == src['distance'] + edge[weight_field]:
                dst['__parent__'] = max(dst['__parent__'], src['row_id'])
            return (src, edge, dst)

        g = g.triple_apply(traverse_fun, ['__parent__'])
        query_table = query_table.join(g.get_vertices()[['__id', '__parent__']], '__id').sort('row_id')
        query_table.rename({'__parent__': 'parent_row_id', '__id': 'vid'})
        return query_table


def create(graph, source_vid, weight_field="", verbose=True, plot=False):
    """
    Compute the single source shortest path distance from the source vertex to
    all vertices in the graph. Note that because SGraph is directed, shortest
    paths are also directed. To find undirected shortest paths add edges to the
    SGraph in both directions. Return a model object with the distance of each
    vertex in the graph.

    Parameters
    ----------
    graph : SGraph
        The graph on which to compute shortest paths.

    source_vid : vertex ID
        ID of the source vertex.

    weight_field : string, optional
        The edge field representing the edge weights. If empty, uses unit
        weights.

    verbose : bool, optional
        If True, print progress updates.

    plot : bool, optional
        If True, display the progress plot. Not yet implemented: a notice is
        printed instead.

    Returns
    -------
    out : Model
        A Model object that contains the shortest path distance for each vertex
        to the source vertex.

    Raises
    ------
    TypeError
        If ``graph`` is not an SGraph.

    Examples
    --------
    If given an :class:`~graphlab.SGraph` ``g``, we can create
    a :class:`~graphlab.shortest_path.ShortestPathModel` as follows:

    >>> sp = shortest_path.create(g, source_vid=123)

    We can obtain the shortest path from the source vertex to each vertex in
    the graph ``g`` as follows:

    >>> sp_sframe = sp['distance']   # SFrame

    We can obtain an auxiliary graph with additional information corresponding
    to the shortest path from the source vertex to each vertex in the graph
    ``g`` as follows:

    >>> sp_graph = sp['graph']       # SGraph
    """
    if not isinstance(graph, _SGraph):
        raise TypeError('graph input must be a SGraph object.')

    if plot is True:
        # Call form of print: same output under Python 2, valid under Python 3.
        print("The plot functionality for shortest path is not yet implemented.")

    # Run the 'sssp' toolkit on the graph's proxy.
    opts = {'source_vid': source_vid, 'edge_attr': weight_field,
            'graph': graph.__proxy__}
    params = main.run('sssp', opts, verbose, plot)
    # Keep only the user-facing parameters (not the graph proxy) on the model.
    return ShortestPathModel({'source_vid': source_vid,
                             'edge_attr': weight_field}, params['model'])