from __future__ import absolute_import, division, print_function

import networkx as nx
from datashape import discover
from .utils import expand_tuples, cls_name
from contextlib import contextmanager


# Registry of Out-of-Core types.  ``_transform`` uses this set as the default
# value of its ``ooc_types`` parameter; ``path``, when handed a non-empty
# ``ooc_types``, restricts its search to the subgraph of those types if both
# endpoints are subclasses of them.  NOTE(review): the set is presumably
# populated by code outside this section -- confirm where types get added.
ooc_types = set()  # Out-of-Core types


class NetworkDispatcher(object):
    """ Dispatch conversions through a directed graph of conversion functions

    Nodes of ``self.graph`` are types.  An edge ``source -> target`` carries
    the function that performs that conversion (``func``) and the weight used
    when searching for the cheapest route (``cost``).
    """

    def __init__(self, name):
        self.name = name
        self.graph = nx.DiGraph()

    def register(self, a, b, cost=1.0):
        """ Decorator registering a function that converts ``b`` into ``a``

        Every pair produced by ``expand_tuples([a, b])`` (presumably each
        combination when ``a``/``b`` are tuples of types -- see
        ``expand_tuples``) becomes an edge ``b -> a`` weighted by ``cost``.
        The decorated function is returned unchanged.
        """
        signatures = expand_tuples([a, b])

        def decorator(func):
            for target_type, source_type in signatures:
                self.graph.add_edge(source_type, target_type,
                                    cost=cost, func=func)
            return func

        return decorator

    def path(self, *positional, **options):
        """ Cheapest conversion path in this dispatcher's graph

        Thin wrapper over the module-level ``path`` with ``self.graph``
        supplied as the graph.
        """
        return path(self.graph, *positional, **options)

    def __call__(self, *positional, **options):
        """ Convert data: ``dispatcher(target, source, **kwargs)``

        Thin wrapper over ``_transform`` with ``self.graph`` supplied as the
        graph.
        """
        return _transform(self.graph, *positional, **options)


def _transform(graph, target, source, excluded_edges=None, ooc_types=ooc_types, **kwargs):
    """ Transform source to target type using graph of transformations

    Finds the cheapest chain of conversion functions from ``type(source)`` to
    ``target`` (via ``path``) and applies them in order.  If any step raises,
    a message is printed, the failing edge ``(A, B)`` is added to the
    exclusions and the whole conversion is retried from ``source`` along the
    next-cheapest path.  Retrying stops when ``path`` can no longer find a
    route; its exception then propagates.

    Parameters
    ----------
    graph : networkx.DiGraph
        Conversion graph whose edges carry ``func`` and ``cost`` attributes.
    target : type or object
        Desired output type (``path`` replaces a non-type by its type).
    source : object
        The data to convert.
    excluded_edges : set of (type, type), optional
        Edges that must not be used.
    ooc_types : set of types
        Out-of-core types, forwarded to ``path`` so that conversions between
        two out-of-core types stay within the out-of-core subgraph.
    **kwargs
        Forwarded to every conversion function.  ``dshape`` is filled in with
        ``discover(source)`` when not supplied.
    """
    excluded_edges = excluded_edges or set()
    if 'dshape' not in kwargs:
        kwargs['dshape'] = discover(source)
    # ooc_types was previously accepted but never forwarded, so the
    # out-of-core restriction in ``path`` could never take effect.
    pth = path(graph, type(source), target,
               excluded_edges=excluded_edges, ooc_types=ooc_types)
    x = source
    try:
        for (A, B, f) in pth:
            x = f(x, **kwargs)
        return x
    except Exception:
        # Deliberate best-effort: route around the broken edge and restart
        # from the original source.  A and B are bound, since only f() raises.
        print("Failed on %s -> %s. Working around" %
                    (A.__name__,  B.__name__))
        new_exclusions = excluded_edges | set([(A, B)])
        # Pass ooc_types explicitly; otherwise the retry would silently fall
        # back to the module-level default.
        return _transform(graph, target, source, excluded_edges=new_exclusions,
                          ooc_types=ooc_types, **kwargs)


def path(graph, source, target, excluded_edges=None, ooc_types=None):
    """ Path of functions between two types

    Parameters
    ----------
    graph : networkx.DiGraph
        Conversion graph; edges carry a ``func`` attribute and are weighted by
        their ``cost`` attribute.
    source, target : type or object
        Endpoints of the conversion; a non-type value is replaced by its type.
    excluded_edges : iterable of (type, type), optional
        Edges hidden from the search; they are restored afterwards.
    ooc_types : iterable of types, optional
        If both ``source`` and ``target`` are subclasses of these types, only
        nodes that are themselves subclasses of them are searched.

    Returns
    -------
    list of (from_type, to_type, func)
        One entry per edge along the cheapest (lowest total ``cost``) path,
        in order.

    Raises whatever ``nx.shortest_path`` raises when ``target`` is
    unreachable (e.g. networkx's NetworkXNoPath).
    """
    if not isinstance(source, type):
        source = type(source)
    if not isinstance(target, type):
        target = type(target)

    # If both source and target are Out-Of-Core types then restrict ourselves
    # to the graph of out-of-core types
    if ooc_types:
        oocs = tuple(ooc_types)
        if issubclass(source, oocs) and issubclass(target, oocs):
            graph = graph.subgraph([n for n in graph.nodes()
                                    if issubclass(n, oocs)])
    with without_edges(graph, excluded_edges) as g:
        pth = nx.shortest_path(g, source=source, target=target, weight='cost')
        # Distinct loop names: on Python 2, comprehension variables leak into
        # the enclosing scope and would clobber source/target.
        return [(frm, to, g.edge[frm][to]['func'])
                for frm, to in zip(pth, pth[1:])]


@contextmanager
def without_edges(g, edges):
    """ Temporarily remove ``edges`` from graph ``g``

    Yields ``g`` itself with the given edges removed.  The edges are re-added
    with their original attributes when the ``with`` block exits, whether it
    exits normally or through an exception.

    Parameters
    ----------
    g : graph
        Modified in place.  Must provide ``edge`` (adjacency mapping),
        ``has_edge``, ``remove_edge`` and ``add_edge`` -- networkx 1.x graphs
        do.
    edges : iterable of (node, node) or None
        Edges to hide.  ``None`` hides nothing.  Edges not present in ``g``
        (e.g. because ``g`` is a restricted subgraph) are skipped.
    """
    held = dict()
    try:
        # Removal happens inside the try so that a failure part-way through
        # still restores the edges removed so far.
        for a, b in edges or []:
            if not g.has_edge(a, b):
                # Nothing to hide; g.edge[a][b] would raise KeyError.
                continue
            held[(a, b)] = g.edge[a][b]
            g.remove_edge(a, b)
        yield g
    finally:
        for (a, b), attrs in held.items():
            g.add_edge(a, b, **attrs)
