from peak.util.assembler import *
import new
from copy import copy

def make_adapter(func, conditional, cost):
    """Wrap ``func`` in the adapter object matching its kind.

    * ``None``                      -> ``Nop(cost)``
    * an existing ``Break`` instance -> returned unchanged
    * the ``Break`` class itself     -> ``Break(cost)``
    * anything else                  -> ``Link(func, conditional, cost)``
    """
    if func is None:
        return Nop(cost)
    if isbreak(func):
        # Already a Break instance; hand it back as-is (cost is not applied).
        return func
    if func is Break:
        return Break(cost)
    return Link(func, conditional, cost)

class Link(object):
    """A single adaptation step wrapping one callable.

    Slots:
        adapt       -- the callable given as ``func``
        conditional -- flag supplied by the caller
        cost        -- cost supplied by the caller
    """
    __slots__ = ['adapt', 'conditional', 'cost']

    def __init__(self, func, conditional, cost):
        self.adapt = func
        self.conditional = conditional
        self.cost = cost

    def __repr__(self):
        return self.__class__.__name__ + repr((self.adapt, self.conditional, self.cost))

    def merge(self, other):
        """Merge with ``other`` by promoting this link to a one-item Chain."""
        return Chain([self], self.cost).merge(other)

    def __add__(self, other):
        """Compose with ``other``.

        Two links become a two-item Chain whose cost is the sum of both;
        anything else is asked to handle the composition via ``__radd__``.
        """
        if islink(other):
            return Chain((self, other), self.cost + other.cost)
        else:
            return other.__radd__(self)

    def overload(self, other):
        """Return the adapter that wins against ``other``: always ``self``.

        NOTE: the original code had separate conditional/unconditional
        branches, but both returned ``self``; the ``Conflict`` construction
        that followed was unreachable and has been dropped.
        """
        return self

    def __eq__(self, other):
        """Links are equal when they are the same class, wrap the identical
        callable and have equal cost (``conditional`` is not compared)."""
        return (self.__class__ is other.__class__
                and self.adapt is other.adapt
                and self.cost == other.cost)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them consistent.
        return not self.__eq__(other)


class NopBase(object):
    """Common base for adapters that carry only a cost.

    Instances are always unconditional (``conditional`` is set to False).
    """
    __slots__ = ['conditional', 'cost']

    def __init__(self, cost):
        self.conditional = False
        self.cost = cost

    def __repr__(self):
        return '%s(%d)' % (self.__class__.__name__, self.cost)

    def overload(self, other):
        """Always win against ``other``: return ``self``."""
        return self

    def __eq__(self, other):
        """Equal when both are the exact same class and have equal cost."""
        return (self.__class__ is other.__class__ and self.cost == other.cost)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them consistent.
        return not self.__eq__(other)


class Nop(NopBase):
    """Identity adapter: ``adapt`` hands its argument back unchanged."""
    __slots__ = []

    @staticmethod
    def adapt(ob):
        return ob

    def __add__(self, other):
        # Composing with a no-op yields a copy of the other adapter that
        # carries the summed cost.
        combined_cost = self.cost + other.cost
        return change_cost(other, combined_cost)

    # Composition with a no-op behaves the same on either side.
    __radd__ = __add__

class Break(NopBase):
    """Adapter whose ``adapt`` always yields ``None``."""
    __slots__ = []

    def __add__(self, other):
        # Composing with a Break yields a Break copy carrying the summed cost.
        combined_cost = self.cost + other.cost
        return change_cost(self, combined_cost)

    # Composition with a Break behaves the same on either side.
    __radd__ = __add__

    @staticmethod
    def adapt(ob):
        return None


class Combo(tuple):
    """Tuple of adapters annotated with ``conditional`` and ``cost``.

    Subclasses may override ``__getargs__`` to match their own constructor
    signature (it drives ``__eq__``, ``__repr__`` and ``__copy__``), and are
    expected to supply ``create_func`` if the ``adapt`` property is used.
    """

    def __new__(cls, links, conditional, cost):
        self = super(Combo, cls).__new__(cls, links)
        self.conditional = conditional
        self.cost = cost
        return self

    def __eq__(self, other):
        """Equal when both are the exact same class with equal constructor args."""
        return (self.__class__ is other.__class__
                and self.__getargs__() == other.__getargs__())

    def __ne__(self, other):
        # Without this, ``!=`` falls through to tuple.__ne__, which compares
        # only the items and ignores class, cost and conditional -- so two
        # combos could be neither ``==`` nor ``!=``.
        return not self.__eq__(other)

    def __getargs__(self):
        """Return the positional arguments that would rebuild this instance."""
        return list(self), self.conditional, self.cost

    def __repr__(self):
        return self.__class__.__name__ + repr(self.__getargs__())

    def __copy__(self):
        return self.__class__(*self.__getargs__())

    # Per-instance cache for the function built by ``create_func``.
    __adapt = None

    @property
    def adapt(self):
        """Function built by ``self.create_func()``, created on first access."""
        if self.__adapt is None:
            self.__adapt = self.create_func()
        return self.__adapt




class Conflict(Combo):
    """Combo of alternatives; using it as an adapter raises ``ValueError``.

    The ``conditional`` slot of a Conflict is always ``Ellipsis``.
    """

    def __new__(cls, items, cost):
        return super(Conflict, cls).__new__(cls, items, Ellipsis, cost)

    def __getargs__(self):
        # Constructor takes (items, cost) only, unlike the base class.
        members = list(self)
        return members, self.cost

    def adapt(self, ob):
        raise ValueError(self)

    def merge(self, other):
        """Return a new Conflict extended with ``other``.

        A link or chain is appended as one item; another conflict
        contributes all of its items.  Anything else raises ``TypeError``.
        The result keeps this conflict's cost.
        """
        if isconflict(other):
            extra = list(other)
        elif islink(other) or ischain(other):
            extra = [other]
        else:
            raise TypeError(other)
        return Conflict(list(self) + extra, self.cost)

    def overload(self, other):
        """A Nop or Break replaces the conflict (re-costed to this conflict's
        cost); any other adapter is merged into it."""
        if isnop(other) or isbreak(other):
            return change_cost(other, self.cost)
        return self.merge(other)

    def __add__(self, other):
        """Compose every alternative with ``other`` on the right."""
        if isbreak(other):
            return other + self
        composed = [member + other for member in self]
        return Conflict(composed, self.cost)

    def __radd__(self, other):
        """Compose every alternative with ``other`` on the left."""
        composed = [other + member for member in self]
        return Conflict(composed, self.cost)


class Chain(Combo):
    """Combo whose items are applied one after another to an object.

    The chain is conditional when at least one of its links has a truthy
    ``conditional`` attribute.
    """

    def __new__(cls, links, cost):
        # Derive ``conditional`` from the links: True as soon as one link is
        # conditional, False when none is (including an empty sequence).
        for link in links:
            if link.conditional:
                conditional = True
                break
        else:
            conditional = False
        return super(Chain, cls).__new__(cls, links, conditional, cost)

    def __getargs__(self):
        # Constructor takes (links, cost) only; ``conditional`` is derived.
        return list(self), self.cost

    def overload(self, other):
        """Pick the adapter to keep when this chain competes with ``other``.

        An unconditional chain always wins.  A conditional chain yields to a
        Nop (re-costed to this chain's cost) and otherwise currently returns
        itself.  Raises ``TypeError`` for an unrecognised ``other``.
        """
        if not self.conditional:
            return self
        elif isnop(other):
            #@@@ ????? (original author's marker: behaviour here is unsettled)
            return change_cost(other, self.cost)
        elif isbreak(other) or isconflict(other):
            return self
        elif islink(other) or ischain(other):
            return self
            # NOTE: unreachable -- conflict creation is disabled by the
            # ``return self`` above (original marker: @@@).
            return Conflict((self, other), other.cost)
        else:
            raise TypeError(other)

    def merge(self, other):
        """Merge with ``other``: defer to a Conflict's own ``merge``;
        otherwise currently return this chain unchanged."""
        if isconflict(other):
            return other.merge(self)
        else:
            return self
            # NOTE: unreachable -- conflict creation is disabled by the
            # ``return self`` above (original marker: @@@).
            return Conflict([self, other], self.cost)
        #raise TypeError(self, other)

    @staticmethod
    def _make_ext(other):
        """Return the list of links ``other`` contributes to a chain.

        Nop -> no links, Link -> itself, Chain -> all of its links.
        Raises ``TypeError`` for anything else.
        """
        if isnop(other):
            return []
        elif islink(other):
            return [other]
        elif ischain(other):
            return list(other)
        else:
            raise TypeError(other)

    def __add__(self, other):
        """Compose with ``other`` on the right; costs are summed."""
        if isbreak(other):
            # Break.__add__ produces a re-costed Break.
            return other + self
        elif isconflict(other):
            #@@@ (original marker) placeholder: yields an *empty* Conflict
            return Conflict([], self.cost + other.cost)
        return Chain(list(self) + self._make_ext(other), self.cost + other.cost)

    def __radd__(self, other):
        """Compose with ``other`` on the left; costs are summed."""
        if isconflict(other):
            #@@@ (original marker) placeholder: yields an *empty* Conflict
            return Conflict([], self.cost + other.cost)
        return Chain(self._make_ext(other) + list(self), self.cost + other.cost)


    def create_func(self):
        """Assemble and return a one-argument function (argument ``ob``)
        that applies the chain's links in order.

        The function is generated as bytecode through ``Code`` and wrapped
        with ``new.function`` using an empty globals dict.
        """
        c = Code()

        if not self:
            # Empty chain: the generated function returns ``ob`` unchanged.
            c.LOAD_FAST('ob')
        elif self.conditional:
            # Conditional chain: call each link's ``adapt`` on ``ob``; after
            # every link except the last, store the result back into ``ob``
            # and return None early if that result is None.
            last = self[-1]
            for link in self:
                c.LOAD_CONST(link.adapt)
                c.LOAD_FAST('ob')
                c.CALL_FUNCTION(1)
                if link is not last:
                    c.DUP_TOP()
                    c.STORE_FAST('ob')
                    #c.LOAD_FAST('ob')
                    c.LOAD_CONST(None)
                    c.COMPARE_OP('is')
                    # Jump past the early return when the result is not None.
                    skip_ref = c.JUMP_IF_FALSE()
                    c.LOAD_CONST(None)
                    c.RETURN_VALUE()
                    # presumably resolves the forward jump to this point --
                    # confirm against the assembler's forward-reference API
                    skip_ref()
        else:
            # Unconditional chain: push the callables last-to-first, then
            # ``ob``, then emit one call per link so the first link is
            # applied first and each result feeds the next link.
            for link in reversed(self):
                c.LOAD_CONST(link.adapt)
            c.LOAD_FAST('ob')
            for _ in self:
                c.CALL_FUNCTION(1)

        c.RETURN_VALUE()
        c.co_argcount = 1
        return new.function(c.code(), {})


## class Branches(Combo):
##     def create_func(self):
##         c = Code()
##         last = self[-1]
##         for link in self:
##             c.LOAD_CONST(func.adapt)
##             c.LOAD_FAST('ob')
##             c.CALL_FUNCTION(1)
##             if link is not last:
##                 c.STORE_FAST('result')
##                 c(Compare(Local('result'), [('is', None)]))
##                 skip_ref = c.JUMP_IF_TRUE()
##                 c(Return(Local('result')))
##                 skip_ref()
##         c.RETURN_VALUE()
##         c.co_argcount = 1
##         return new.function(c.code(), {})
##


def change_cost(x, cost):
    """Return a shallow copy of ``x`` with ``cost`` replaced.

    ``x`` itself is left untouched.
    """
    duplicate = copy(x)
    setattr(duplicate, 'cost', cost)
    return duplicate

def isnop(x):
    """Tell whether ``x`` is exactly a ``Nop`` (subclasses do not count)."""
    return Nop is type(x)

def isbreak(x):
    """Tell whether ``x`` is exactly a ``Break`` (subclasses do not count)."""
    return Break is type(x)

def islink(x):
    """Tell whether ``x`` is exactly a ``Link`` (subclasses do not count)."""
    return Link is type(x)

def ischain(x):
    """Tell whether ``x`` is exactly a ``Chain`` (subclasses do not count)."""
    return Chain is type(x)

def isbranches(x):
    """Tell whether ``x`` is exactly a ``Branches`` instance.

    ``Branches`` is not defined in this module (its class definition is
    commented out), so a direct reference raised ``NameError`` on every
    call.  The class is looked up lazily instead: when no such global
    exists nothing can be a ``Branches`` and the answer is False.
    """
    branches_cls = globals().get('Branches')
    return branches_cls is not None and type(x) is branches_cls

def isconflict(x):
    """Tell whether ``x`` is exactly a ``Conflict`` (subclasses do not count)."""
    return Conflict is type(x)


def chain(a, b):
    """Compose adapter ``a`` followed by ``b``.

    Composition is delegated entirely to the operands' ``__add__`` /
    ``__radd__``.  The earlier hand-written dispatch that used to follow the
    ``return`` was unreachable (and referenced undefined names), so it has
    been removed.
    """
    return a + b

def merge(a, b):
    """Combine two adapters, preferring whichever is not a placeholder.

    ``a`` is kept when the two are equal, when ``a`` is a Nop, or when ``b``
    is a Break; ``b`` is kept when ``a`` is a Break or ``b`` is a Nop.
    Otherwise the decision is left to ``a.merge(b)``.
    """
    if a == b:
        return a
    if isnop(a) or isbreak(b):
        return a
    if isbreak(a) or isnop(b):
        return b
    return a.merge(b)

def overload(a, b):
    """Choose between two competing adapters.

    * equal adapters          -> ``a``
    * equal cost              -> ``merge(a, b)``
    * otherwise the cheaper adapter decides, via its ``overload`` method
      called with the more expensive one.

    Everything that used to follow this dispatch was unreachable (every
    branch returns) and referenced undefined names, so it has been removed.
    """
    if a == b:
        return a
    elif a.cost == b.cost:
        return merge(a, b)
    elif a.cost > b.cost:
        return b.overload(a)
    else:
        return a.overload(b)




if __name__ == '__main__':
    # Smoke test: chain an unconditional identity adapter with itself and
    # overload the result against a cheaper Nop.
    def identity(o):
        return o

    a = make_adapter(identity, False, 1)
    # Use a distinct name so the module-level ``chain`` function is not
    # rebound to its own result.
    chained = chain(a, a)
    # Parenthesised single argument prints identically under the Python 2
    # print statement.
    print(overload(chained, Nop(1)))

