import inspect
import logging

import pygame.color
import pygame.display
import pygame.draw
import pygame.font
import talljosh
from twisted.internet.defer import returnValue

from fibre.ui import events

log = logging.getLogger(__name__)

class ScreenError(Exception):
    '''
    Raised by ActiveScreen.draw() when it is called on a screen which is not
    active.
    '''
    pass

class DeclarationError(Exception):
    '''
    Raised by attach_widget() when a widget is constructed and no suitable
    enclosing frame is found on the call stack.
    '''
    pass

class ConstraintError(Exception):
    '''
    Raised when a widget has no position constraints to resolve, or when a
    constraint refers to an element which has not been placed yet.
    '''
    pass

def annotate(**kwargs):
    '''
    Decorator factory. The returned decorator sets every keyword argument as
    an attribute on the decorated function and returns that same function
    object (it is not wrapped).
    '''
    def wrapper(fn):
        # items() rather than iteritems(): identical result here, and it does
        # not depend on the Python 2-only dict API.
        for key, value in kwargs.items():
            setattr(fn, key, value)
        return fn
    return wrapper

class SysFont(object):
    '''
    Describes a system font by name and style without fixing its size. The
    actual pygame font object is only created when construct() is called.
    '''
    def __init__(self, names, bold=False, italic=False):
        self.names, self.bold, self.italic = names, bold, italic

    def construct(self, size):
        '''
        Returns the result of pygame.font.SysFont() for the stored names and
        style at the given size.
        '''
        style = (self.bold, self.italic)
        return pygame.font.SysFont(self.names, size, *style)

class Theme(object):
    '''
    Base theme. Groups named colours, named (font file name, size) pairs and
    named font files. Window.colour(), Window.font() and Window.font_file()
    look names up on these nested classes.
    '''
    class Colours(object):
        # Expose every pygame named colour as a class attribute. This writes
        # through locals(), so it relies on locals() in a class body being the
        # namespace the class is built from.
        # items() rather than iteritems() so the loop does not depend on the
        # Python 2-only dict API.
        for _name, _value in pygame.color.THECOLORS.items():
            locals()[_name] = _value
        transparent = (255, 255, 255, 0)

    class Fonts(object):
        # name -> (key into FontFiles, point size)
        default = ('sans', 20)

    class FontFiles(object):
        # name -> object with a construct(size) method, or a value accepted
        # by pygame.font.Font() (see Window.font()).
        sans = SysFont('Arial,Helvetica,FreeSans')

class DefaultTheme(Theme):
    '''
    Theme used by Screen unless overridden. Adds the semantic colour names
    which the widgets in this module use as their defaults.
    '''
    class Colours(Theme.Colours):
        # Each semantic name is an alias for one of the inherited named
        # colours.
        window = Theme.Colours.white
        text = Theme.Colours.black
        button_highlight = Theme.Colours.gray85
        button_shadow = Theme.Colours.gray25
        button_background = Theme.Colours.gray50
        button_hover = Theme.Colours.gray75

class BoundDescriptor(object):
    '''
    Pairs a descriptor with the object (base) that it has been bound to.
    '''
    def __init__(self, base, descriptor):
        self.base, self.descriptor = base, descriptor

def bound_descriptor(base, descriptor):
    '''
    Returns the binding of the given descriptor to the given base, creating
    it on first use. Bindings are remembered in a dict stored on the base as
    _binding_map, so repeated calls return the same object.

    If the descriptor has no (truthy) bound_descriptor attribute, a new class
    named 'Bound<DescriptorClassName>' deriving from the descriptor class's
    bound_base_class is created and stored on the descriptor's class.
    '''
    try:
        cache = base._binding_map
    except AttributeError:
        cache = {}
        base._binding_map = cache

    try:
        return cache[descriptor]
    except KeyError:
        pass

    if not getattr(descriptor, 'bound_descriptor', None):
        owner = descriptor.__class__
        generated_name = 'Bound{}'.format(owner.__name__)
        owner.bound_descriptor = type(generated_name, (owner.bound_base_class,), {})

    binding = descriptor.bound_descriptor(base, descriptor)
    cache[descriptor] = binding
    return binding

class BindableDescriptor(object):
    '''
    Descriptor whose value, when read from an instance, is a per-instance
    binding produced by bound_descriptor(). Assigning to the attribute
    forwards the value to that binding's set_value() method.
    '''
    # Base class used by bound_descriptor() when it has to generate the
    # binding class for a descriptor class.
    bound_base_class = BoundDescriptor
    bindable = True

    def __get__(self, instance, owner):
        # Class-level access is signalled by instance being None. Testing
        # truthiness instead would hand back the raw descriptor for any
        # instance that happens to evaluate as false.
        if instance is None:
            return self
        return bound_descriptor(instance, self)

    def __set__(self, base, value):
        bound_descriptor(base, self).set_value(value)

class RectAttrDescriptor(BindableDescriptor):
    '''
    Maps to a pygame.Rect attribute.
    '''
    def __init__(self, rect_attribute):
        # Name of the rect attribute (e.g. 'top', 'centerx') that bindings of
        # this descriptor read when they are resolved.
        self.rect_attribute = rect_attribute

class Bound1DAnchor(BoundDescriptor):
    '''
    Binding of a one-dimensional anchor. Holds the constraint assigned to the
    anchor (None until one is set) and can read the matching attribute from a
    placed element's rect.
    '''
    value = None

    def set_value(self, value):
        '''Records the constraint assigned to this anchor.'''
        self.value = value

    def resolve(self, element):
        '''Returns the descriptor's rect attribute read from element.rect.'''
        rect = element.rect
        return getattr(rect, self.descriptor.rect_attribute)

class Anchor1D(RectAttrDescriptor):
    '''
    Rect attribute descriptor whose bindings derive from Bound1DAnchor.
    '''
    bound_base_class = Bound1DAnchor

class XAnchor(Anchor1D):
    '''
    Anchor1D used by Positioned for the horizontal attributes.
    '''
    pass

class YAnchor(Anchor1D):
    '''
    Anchor1D used by Positioned for the vertical attributes.
    '''
    pass

class ScaledValue(object):
    '''
    A point part-way between the two ends of a scale on a given base. r is
    the interpolation weight: 0 gives the scale's min, 1 gives its max.
    '''
    def __init__(self, base, scale, r):
        self.base, self.scale, self.r = base, scale, r

    def resolve(self, element):
        '''
        Resolves both ends of the scale against the element and returns the
        value interpolated between them.
        '''
        lower, upper = [
            bound_descriptor(self.base, end).resolve(element)
            for end in (self.scale.min, self.scale.max)
        ]
        weight = self.r
        return lower * (1 - weight) + upper * weight

class BoundScale(object):
    '''
    Binding of a Scale to a base. Calling it with a ratio yields the
    ScaledValue for that ratio.
    '''
    def __init__(self, base, scale):
        self.base, self.scale = base, scale

    def __call__(self, r):
        base, scale = self.base, self.scale
        return ScaledValue(base, scale, r)

class Scale(BindableDescriptor):
    '''
    Descriptor spanning the range between two other descriptors (min and
    max). Its bindings derive from BoundScale.
    '''
    bound_base_class = BoundScale

    def __init__(self, min_, max_):
        self.min, self.max = min_, max_

class XScale(Scale):
    '''
    Scale used by Positioned for the horizontal range (left to right).
    '''
    pass

class YScale(Scale):
    '''
    Scale used by Positioned for the vertical range (top to bottom).
    '''
    pass

class BoundAnchor(BoundDescriptor):
    '''
    Binding of a two-dimensional Anchor. Its x and y properties are the
    bindings of the anchor's component descriptors to the same base.
    '''
    def set_value(self, value):
        '''
        Assigns value.x to the x component and then value.y to the y
        component.
        '''
        for axis in ('x', 'y'):
            getattr(self, axis).set_value(getattr(value, axis))

    @property
    def x(self):
        component = self.descriptor.x
        return bound_descriptor(self.base, component)

    @property
    def y(self):
        component = self.descriptor.y
        return bound_descriptor(self.base, component)

class Anchor(BindableDescriptor):
    '''
    Two-dimensional anchor made of an x descriptor and a y descriptor.
    '''
    # Set directly, so bound_descriptor() uses BoundAnchor itself instead of
    # generating a binding class.
    bound_descriptor = BoundAnchor

    def __init__(self, x, y):
        self.x, self.y = x, y

class Positioned(object):
    '''
    Mixin which declares the position and size descriptors shared by Widget
    and Window. Each one-dimensional anchor names the pygame.Rect attribute
    that it resolves to.
    '''
    # Vertical anchors.
    top = YAnchor('top')
    centre_y = YAnchor('centery')
    bottom = YAnchor('bottom')
    height = YAnchor('height')

    # Horizontal anchors.
    left = XAnchor('left')
    centre_x = XAnchor('centerx')
    right = XAnchor('right')
    width = XAnchor('width')

    # Two-dimensional anchors, each built from one horizontal and one
    # vertical anchor declared above.
    top_left = Anchor(left, top)
    mid_top = Anchor(centre_x, top)
    top_right = Anchor(right, top)
    mid_left = Anchor(left, centre_y)
    centre = Anchor(centre_x, centre_y)
    mid_right = Anchor(right, centre_y)
    bottom_left = Anchor(left, bottom)
    mid_bottom = Anchor(centre_x, bottom)
    bottom_right = Anchor(right, bottom)

    # Scales: calling the bound scale with a ratio r gives a value
    # interpolated between the two anchors (see ScaledValue).
    x_position = XScale(left, right)
    y_position = YScale(top, bottom)

# Widgets constructed since the last Screen class was created, in
# construction order. ScreenMeta consumes and empties this list.
_widgets_to_attach = []
def attach_widget(widget):
    '''
    Called during widget definition construction. Attaches the widget to the
    next Screen class that is completed. This is not pretty, but there's no
    nice way of doing it.

    Raises DeclarationError if no matching frame is found on the call stack.
    '''
    # Start two frames up: past this function and past its immediate caller.
    frame = inspect.currentframe().f_back.f_back
    while frame:
        # True when code flag bit 1 is clear and bit 2 is set. In the inspect
        # module those bits are CO_OPTIMIZED and CO_NEWLOCALS.
        # NOTE(review): presumably this combination identifies a class body
        # frame on the interpreter this was written for - confirm.
        if (frame.f_code.co_flags^1)&3 == 3:
            break
        frame = frame.f_back
    else:
        # The loop ran off the top of the stack without finding such a frame.
        raise DeclarationError('Widgets must be defined in a class to take effect')
    _widgets_to_attach.append(widget)

# Widgets registered as children of another widget since the last Screen
# class was created. ScreenMeta excludes these and then clears the set.
_child_widgets = set()
def seen_child_widget(widget):
    '''
    Called during widget definition construction when child widgets are
    attached. This ensures that the child widgets are not attached to the
    Screen because they're already attached to their parent widget.
    '''
    _child_widgets.add(widget)

class ScreenMeta(type):
    '''
    Metaclass for Screen. When a class is created, gives it an elements
    attribute listing the widgets collected since the previous class was
    created, in definition order, leaving out registered child widgets. Both
    module-level collections are emptied afterwards.
    '''
    def __new__(metaclass, name, bases, dict_):
        top_level = []
        for widget in _widgets_to_attach:
            if widget not in _child_widgets:
                top_level.append(widget)
        dict_['elements'] = top_level

        # Reset in place so the next class declaration starts empty.
        del _widgets_to_attach[:]
        _child_widgets.clear()

        return type.__new__(metaclass, name, bases, dict_)

class Screen(object):
    '''
    Base class for screen declarations. Widgets constructed in the body of a
    subclass are collected into its elements attribute by ScreenMeta.
    '''
    # Python 2 metaclass hook.
    __metaclass__ = ScreenMeta
    theme = DefaultTheme
    # Consulted by ActiveScreen.get_keyboard_status() to find the names
    # associated with a key code.
    keyboard_mapping = talljosh.MultiMap()
    caption = 'fibre'

    @classmethod
    def show(cls):
        '''
        Shows this screen on the module-level window and returns the
        resulting ActiveScreen.
        '''
        return window.show_screen(cls)

# Event kinds defined by this module.
# ActionTrigger: pushed by a clicked button which has an action set.
ActionTrigger = events.EventKind()
# ButtonPress: pushed by a clicked button.
ButtonPress = events.EventKind()
# MouseIn / MouseOut: sent by ActiveScreen to an element when the mouse
# pointer enters or leaves it.
MouseIn = events.EventKind()
MouseOut = events.EventKind()

class LiveWidget(object):
    '''
    Runtime counterpart of a declared widget on an active screen. Holds the
    widget's resolved rectangle and knows how to position and draw it.
    '''
    rect = None
    # Fallback position as (left, top) and size as (width, height), used for
    # whichever of them the widget's constraints do not determine.
    default_position = (0, 0)
    default_size = (16, 16)
    # When true the active screen forwards mouse events to this element
    # regardless of where the pointer is.
    monitor_mouse = False

    def __init__(self, screen, widget):
        self.screen = screen
        self.widget = widget
        self.window = self.screen.window
        self.display = self.window.display

    def draw(self):
        self.widget.draw(self.display, self.rect)

    def update_position(self):
        '''
        Resolves the widget's horizontal and vertical constraints and stores
        the resulting pygame.Rect in self.rect.
        '''
        left, width = self.resolve_constraints(
            self.widget.left.value,
            self.widget.centre_x.value,
            self.widget.right.value,
            self.widget.width.value,
            self.get_default_left,
            self.get_default_width,
        )
        top, height = self.resolve_constraints(
            self.widget.top.value,
            self.widget.centre_y.value,
            self.widget.bottom.value,
            self.widget.height.value,
            self.get_default_top,
            self.get_default_height,
        )
        self.rect = pygame.Rect(left, top, width, height)

    def get_default_top(self):
        # default_position is (left, top), matching default_size which is
        # (width, height).
        return self.default_position[1]

    def get_default_left(self):
        return self.default_position[0]

    def get_default_width(self):
        return self.default_size[0]

    def get_default_height(self):
        return self.default_size[1]

    def resolve_constraints(self, a, b, c, span, min_getter, span_getter):
        '''
        Takes min, mid, max and range, some of which should be None, and returns
        min, range.

        A constraint counts as given when it is not None, so an explicit 0 is
        honoured. When more than two constraints are given, the first matching
        pair in the order below wins. min_getter and span_getter are called
        (with no arguments) only when the corresponding value cannot be
        derived from the constraints.

        Raises ConstraintError if every constraint is None.
        '''
        get_value = self.screen.get_value
        has_a = a is not None
        has_b = b is not None
        has_c = c is not None
        has_span = span is not None

        if has_a and has_b:
            u = get_value(a)
            v = 2 * (get_value(b) - u)
        elif has_a and has_c:
            u = get_value(a)
            v = get_value(c) - u
        elif has_a and has_span:
            u = get_value(a)
            v = get_value(span)
        elif has_b and has_c:
            w = get_value(c)
            v = 2 * (w - get_value(b))
            u = w - v
        elif has_b and has_span:
            v = get_value(span)
            u = get_value(b) - v // 2
        elif has_c and has_span:
            v = get_value(span)
            u = get_value(c) - v
        elif has_a:
            u = get_value(a)
            v = span_getter()
        elif has_b:
            v = span_getter()
            u = get_value(b) - v // 2
        elif has_c:
            v = span_getter()
            u = get_value(c) - v
        elif has_span:
            u = min_getter()
            v = get_value(span)
        else:
            raise ConstraintError('{} object could not be positioned'.format(self.widget))

        return u, v

    def handle_event(self, event):
        pass

class ProxyProperty(object):
    '''
    Descriptor for live widgets which mirrors an attribute of the declared
    widget. The first read copies the value from instance.widget (passing it
    through resolve()); later reads return the stored value. Every assignment
    marks the instance as changed.
    '''
    def __init__(self, widget_attribute=None):
        # When omitted, a widget_attribute defined on the class is used.
        if widget_attribute is not None:
            self.widget_attribute = widget_attribute

    def __get__(self, instance, owner):
        if instance is None:
            return self

        # The value lives on the instance under a name with a leading space,
        # which cannot collide with an ordinary attribute name.
        slot = ' ' + self.widget_attribute
        try:
            stored = getattr(instance, slot)
        except AttributeError:
            initial = getattr(instance.widget, self.widget_attribute)
            self.__set__(instance, initial)
            stored = getattr(instance, slot)
        return stored

    def __set__(self, instance, value):
        slot = ' ' + self.widget_attribute
        setattr(instance, slot, self.resolve(instance, value))
        instance.changed = True

    def resolve(self, instance, value):
        '''Hook for subclasses; the base implementation stores values as-is.'''
        return value

class ResolvedProperty(ProxyProperty):
    '''
    ProxyProperty whose values are passed through the method of
    instance.window named by resolver_function before being stored.
    '''
    def resolve(self, instance, value):
        return getattr(instance.window, self.resolver_function)(value)

class ColourProperty(ResolvedProperty):
    '''
    Property resolved through the window's colour() method. Mirrors the
    widget's 'colour' attribute unless another attribute name is given.
    '''
    widget_attribute = 'colour'
    resolver_function = 'colour'

class FontProperty(ResolvedProperty):
    '''
    Property resolved through the window's font() method. Mirrors the
    widget's 'font' attribute unless another attribute name is given.
    '''
    widget_attribute = 'font'
    resolver_function = 'font'

def render_font(font, text, antialias, colour, background=None):
    '''
    Calls font.render(), leaving the background argument out altogether when
    no (truthy) background is given instead of passing it through.
    '''
    arguments = [text, antialias, colour]
    if background:
        arguments.append(background)
    return font.render(*arguments)

class Widget(Positioned):
    '''
    Base class for widget declarations. Keyword arguments given to the
    constructor are set as attributes on the widget, and the widget is then
    registered through attach_widget().
    '''
    # Class used to create this widget's live counterpart on an active
    # screen.
    live_widget_factory = LiveWidget

    def __init__(self, **kwargs):
        # items() rather than iteritems(): identical result here, and it does
        # not depend on the Python 2-only dict API.
        for key, value in kwargs.items():
            setattr(self, key, value)

        attach_widget(self)

    def draw(self, display, rect):
        raise NotImplementedError

class Box(Widget):
    '''
    Rectangle with an optional fill colour and an optional border.
    '''
    fill = None
    border = None
    border_width = 1

    class live_widget_factory(LiveWidget):
        fill = ColourProperty('fill')
        border = ColourProperty('border')
        border_width = ProxyProperty('border_width')

        def draw(self):
            # Each part is only drawn when its colour is set.
            fill_colour = self.fill
            if fill_colour:
                self.display.fill(fill_colour, rect=self.rect)

            border_colour = self.border
            if border_colour:
                pygame.draw.rect(self.display, border_colour, self.rect, self.border_width)

class Button(Widget):
    '''
    Push button with a text label. The colour and font attributes hold theme
    names which the live widget resolves through the window.
    '''
    font = 'default'
    background_colour = 'button_background'
    border_highlight = 'button_highlight'
    border_shadow = 'button_shadow'
    text_colour = 'text'
    background_hover = 'button_hover'
    # None means the hover state reuses text_colour (see redraw()).
    text_hover = None

    text = 'Button'
    border_width = 2
    # Space added on each side of the text when working out the default size.
    default_padding = 5

    # When set, a click pushes an ActionTrigger event carrying this value.
    action = None

    def __init__(self, text=None, **kwargs):
        # A falsy text leaves the class-level default in place.
        if text:
            self.text = text
        super(Button, self).__init__(**kwargs)

    class live_widget_factory(LiveWidget):
        # The three cached surfaces are rebuilt whenever changed is true.
        changed = True
        mouse_over = False
        mouse_down = False

        font = FontProperty()
        background_colour = ColourProperty('background_colour')
        border_highlight = ColourProperty('border_highlight')
        border_shadow = ColourProperty('border_shadow')
        text_colour = ColourProperty('text_colour')
        background_hover = ColourProperty('background_hover')
        text_hover = ColourProperty('text_hover')
        text = ProxyProperty('text')
        default_padding = ProxyProperty('default_padding')
        border_width = ProxyProperty('border_width')
        action = ProxyProperty('action')

        @property
        def canvas(self):
            '''Surface for the normal state.'''
            if self.changed:
                self.redraw()
            return self._canvas

        @property
        def hover_canvas(self):
            '''Surface for the state with the mouse over the button.'''
            if self.changed:
                self.redraw()
            return self._hover_canvas

        @property
        def click_canvas(self):
            '''Surface for the state with the button held down.'''
            if self.changed:
                self.redraw()
            return self._click_canvas

        def redraw(self):
            '''
            Rebuilds the normal, hover and click surfaces and clears the
            changed flag.
            '''
            self._canvas = self.draw_button(
                text_colour = self.text_colour,
                background_colour = self.background_colour,
                border_highlight = self.border_highlight,
                border_shadow = self.border_shadow,
            )
            self._hover_canvas = self.draw_button(
                text_colour = self.text_hover or self.text_colour,
                background_colour = self.background_hover or self.background_colour,
                border_highlight = self.border_highlight,
                border_shadow = self.border_shadow,
            )
            # The click state swaps the highlight and shadow colours and
            # shifts the text.
            self._click_canvas = self.draw_button(
                text_colour = self.text_hover or self.text_colour,
                background_colour = self.background_hover or self.background_colour,
                border_highlight = self.border_shadow,
                border_shadow = self.border_highlight,
                text_offset = (self.border_width + 1) // 2,
            )
            self.changed = False

        def draw_button(self, text_colour, background_colour, border_highlight, border_shadow, text_offset=0):
            '''
            Returns a new surface the size of self.rect showing the button
            with the given colours. text_offset shifts the centred text right
            and down by that many pixels.
            '''
            canvas = pygame.Surface(self.rect.size, pygame.SRCALPHA)
            text = render_font(
                font=self.font,
                text=self.text,
                antialias=True,
                colour=text_colour,
                background=background_colour,
            )
            if background_colour:
                canvas.fill(background_colour)
            width = canvas.get_width()
            height = canvas.get_height()
            # Centre the text on the canvas.
            x = (width - text.get_width()) // 2 + text_offset
            y = (height - text.get_height()) // 2 + text_offset
            canvas.blit(text, (x, y))

            canvas.lock()
            # Highlight along the top and left edges.
            canvas.fill(border_highlight, pygame.Rect(
                (0, 0),
                (width, self.border_width),
            ))
            canvas.fill(border_highlight, pygame.Rect(
                (0, 0),
                (self.border_width, height),
            ))
            # Shadow along the right and bottom edges, drawn as one polygon.
            pygame.draw.polygon(canvas, border_shadow, [
                (width, 0),
                (width, height),
                (0, height),
                (self.border_width, height - self.border_width),
                (width - self.border_width, height - self.border_width),
                (width - self.border_width, self.border_width),
            ])
            canvas.unlock()

            return canvas

        def draw(self):
            # Pick the surface matching the current mouse state.
            if self.mouse_down and self.mouse_over:
                canvas = self.click_canvas
            elif self.mouse_over:
                canvas = self.hover_canvas
            else:
                canvas = self.canvas

            self.display.blit(canvas, self.rect)

        def get_default_width(self):
            font_size = self.font.size(self.text)
            return font_size[0] + self.default_padding * 2

        def get_default_height(self):
            font_size = self.font.size(self.text)
            return font_size[1] + self.default_padding * 2

        def clicked(self):
            '''
            Pushes an ActionTrigger event if an action is set, then a
            ButtonPress event unless the action event has stopped
            propagating.
            '''
            if self.action:
                event = events.Event(ActionTrigger, action=self.action)
                events.queue.push(event)
                if not event.propagating:
                    return
            events.queue.push(events.Event(ButtonPress, button=self))

        # Handlers for the mouse events which the active screen sends to this
        # element.
        # NOTE(review): the handlers read and write mouse_over / mouse_down
        # through self; presumably events.EventHandler binds them to the live
        # widget - confirm against fibre.ui.events.
        class handle_event(events.EventHandler):
            @annotate(handles=[MouseIn])
            def mouse_in(self, event):
                self.mouse_over = True
                event.caught()
            @annotate(handles=[MouseOut])
            def mouse_out(self, event):
                self.mouse_over = False
                event.caught()
            @annotate(handles=[events.MouseDown])
            def mouse_down(self, event):
                self.mouse_down = True
                event.caught()
            @annotate(handles=[events.MouseUp])
            def mouse_up(self, event):
                self.mouse_down = False
                event.caught()
                # Releasing while still over the button counts as a click.
                if self.mouse_over:
                    self.clicked()

class Text(Widget):
    '''
    Static text label. Font and colours hold theme names which the live
    widget resolves through the window.
    '''
    font = 'default'
    colour = 'text'
    background_colour = None

    def __init__(self, text='Button', **kwargs):
        self.text = text
        super(Text, self).__init__(**kwargs)

    class live_widget_factory(LiveWidget):
        changed = True

        background_colour = ColourProperty('background_colour')
        colour = ColourProperty()
        font = FontProperty()
        text = ProxyProperty('text')

        @property
        def canvas(self):
            '''Rendered text surface, re-rendered only after a change.'''
            if not self.changed:
                return self._canvas

            self._canvas = render_font(
                self.font,
                self.text,
                True,
                self.colour,
                self.background_colour,
            )
            self.changed = False
            return self._canvas

        def draw(self):
            rendered = self.canvas
            self.display.blit(rendered, self.rect)

        # The default size is the size of the rendered text.
        def get_default_width(self):
            rendered = self.canvas
            return rendered.get_width()

        def get_default_height(self):
            rendered = self.canvas
            return rendered.get_height()

class Canvas(Widget):
    '''
    Free-form drawing surface. The live widget exposes drawing methods which
    resolve colours through the window and draw onto its own surface.
    '''
    background_colour = 'transparent'

    class live_widget_factory(LiveWidget):
        changed = True

        background_colour = ColourProperty('background_colour')

        @property
        def canvas(self):
            '''
            The drawing surface. While changed is set, a new surface the size
            of self.rect is created and filled with the background colour.
            '''
            if not self.changed:
                return self._canvas

            self._canvas = pygame.Surface(self.rect.size, pygame.SRCALPHA)
            self._canvas.fill(self.background_colour)
            self.changed = False
            return self._canvas

        def draw(self):
            surface = self.canvas
            self.display.blit(surface, self.rect)

        def clear(self):
            surface = self.canvas
            surface.fill(self.background_colour)

        def draw_circle(self, colour, pos, radius, width=0):
            resolved = self.window.colour(colour)
            pygame.draw.circle(self.canvas, resolved, pos, radius, width)

        def draw_rect(self, colour, rect, width=0):
            resolved = self.window.colour(colour)
            pygame.draw.rect(self.canvas, resolved, rect, width)

        def draw_polygon(self, colour, points, width=0):
            resolved = self.window.colour(colour)
            pygame.draw.polygon(self.canvas, resolved, points, width)

        def draw_ellipse(self, colour, rect, width=0):
            resolved = self.window.colour(colour)
            pygame.draw.ellipse(self.canvas, resolved, rect, width)

        def draw_arc(self, colour, rect, start_angle, stop_angle, width=1):
            resolved = self.window.colour(colour)
            pygame.draw.arc(self.canvas, resolved, rect, start_angle, stop_angle, width)

        def draw_line(self, colour, start_pos, end_pos, width=1, antialias=False, blend=True):
            '''
            Draws a single line on the canvas.
            @param width: line width. Ignored when antialias is true.
            @param blend: whether to blend shades with background. Ignored
                when antialias is false.
            '''
            resolved = self.window.colour(colour)
            if not antialias:
                pygame.draw.line(self.canvas, resolved, start_pos, end_pos, width)
            else:
                pygame.draw.aaline(self.canvas, resolved, start_pos, end_pos, blend)

        def draw_lines(self, colour, points, width=1, closed=False, antialias=False, blend=True):
            '''
            Draws a sequence of connected lines on the canvas.
            @param width: line width. Ignored when antialias is true.
            @param blend: whether to blend shades with background. Ignored
                when antialias is false.
            '''
            resolved = self.window.colour(colour)
            if not antialias:
                pygame.draw.lines(self.canvas, resolved, closed, points, width)
            else:
                pygame.draw.aalines(self.canvas, resolved, closed, points, blend)

class Window(Positioned):
    '''
    The application window. Owns the pygame display, tracks the screen
    currently shown and resolves theme names (colours, fonts, font files)
    against that screen's theme.
    '''
    _display = None
    size = (800, 600)
    current_screen = None
    # Set to the window rectangle once the display has been created.
    rect = None
    caption = 'fibre'

    @property
    def display(self):
        '''
        The pygame display surface. Created on first access, which also sets
        the window caption and self.rect.
        '''
        if self._display is None:
            pygame.display.set_caption(self.caption)
            self._display = pygame.display.set_mode(self.size)
            self.rect = pygame.Rect((0, 0), self.size)
        return self._display

    def show_screen(self, screen):
        '''
        Deactivates the current screen, if any, and makes an ActiveScreen for
        the given screen the current one. Returns the new ActiveScreen.
        '''
        if self.current_screen:
            self.current_screen.active = False
        self.current_screen = ActiveScreen(self, screen)
        self.current_screen.activated()
        return self.current_screen

    def tick(self):
        if self.current_screen:
            self.current_screen.draw()

    # Constructed fonts keyed by (resolved font file, size). This is a class
    # attribute, so it is shared by every Window.
    _font_cache = {}
    def font(self, value):
        '''
        Resolves the given value into a pygame font object.

        value may be an object with a render attribute (returned unchanged),
        a (font file, size) tuple, or the name of an entry in the current
        screen theme's Fonts. The font file part is resolved by font_file().

        Constructed fonts are cached by the resolved (font file, size) pair
        only, never by a theme-relative name, so a name always reflects the
        current theme and different names for the same font share one
        object.
        '''
        if hasattr(value, 'render'):
            return value
        if isinstance(value, tuple):
            font_name, size = value
        else:
            font_name, size = getattr(self.current_screen.theme.Fonts, value)

        font_file = self.font_file(font_name)

        # Build the key before constructing anything, so that a later lookup
        # for the same font file and size finds this entry.
        key = (font_file, size)
        try:
            return self._font_cache[key]
        except KeyError:
            pass

        if hasattr(font_file, 'construct'):
            font = font_file.construct(size)
        else:
            font = pygame.font.Font(font_file, size)
        self._font_cache[key] = font
        return font

    def colour(self, value):
        '''
        Resolves a colour. Tuples and None are returned unchanged; anything
        else is looked up by name on the current screen theme's Colours.
        '''
        if isinstance(value, tuple):
            return value
        if value is None:
            return value
        return getattr(self.current_screen.theme.Colours, value)

    def font_file(self, value):
        '''
        Resolves a font file. Objects with a construct attribute are returned
        unchanged; anything else is looked up by name on the current screen
        theme's FontFiles.
        '''
        if hasattr(value, 'construct'):
            return value
        return getattr(self.current_screen.theme.FontFiles, value)

    def handle_event(self, event):
        if self.current_screen:
            self.current_screen.handle_event(event)

# Module-level Window instance; Screen.show() shows screens on it.
window = Window()

class ActiveScreen(object):
    '''
    A Screen declaration being shown on a window. Creates the live widgets
    for the screen's elements, draws them and routes mouse events to them.
    '''
    def __init__(self, window, screen):
        # window must be assigned first: __getattr__ reads self.window.
        self.window = window
        self.screen = screen
        self.active = False
        self.elements_need_updating = True
        self.element_by_widget = {}
        self.theme = self.screen.theme

    def __getattr__(self, name):
        '''
        Fallback attribute lookup. Tries the window's rect first, then the
        screen declaration. If the value found on the screen is a widget
        with a live counterpart here, the live widget is returned instead.
        '''
        try:
            return getattr(self.window.rect, name)
        except AttributeError:
            pass

        try:
            item = getattr(self.screen, name)
        except AttributeError:
            raise AttributeError('{} object has no attribute {!r}'.format(self.__class__.__name__, name))
        try:
            item = self.get_live_widget(item)
        except (KeyError, TypeError):
            # Not a known widget, or not hashable: return the value itself.
            pass
        return item

    def activated(self):
        self.active = True
        pygame.display.set_caption(self.caption)

    def draw(self):
        '''
        Draws every element and flips the display. An error raised while
        drawing one element is logged and does not stop the others.

        Raises ScreenError if the screen is not active.
        '''
        if not self.active:
            raise ScreenError('draw() method called on inactive screen')

        for element in self.elements:
            try:
                element.draw()
            except Exception:
                # Deliberately not a bare except: KeyboardInterrupt and
                # SystemExit must still propagate.
                log.error('Error while drawing %s', element, exc_info=True)
        pygame.display.flip()

    @property
    def elements(self):
        '''Live widgets for the screen's elements, created on first use.'''
        if self.elements_need_updating:
            self.perform_elements_update()
        return self._elements

    def perform_elements_update(self):
        '''
        Creates a live widget for each declared element, records it in
        element_by_widget and then positions all of them.
        '''
        self._elements = []
        for e in self.screen.elements:
            live_widget = e.live_widget_factory(self, e)
            self._elements.append(live_widget)
            self.element_by_widget[e] = live_widget

        # Clear the flag before positioning: update_element_positions() reads
        # self.elements, which would otherwise start another update.
        self.elements_need_updating = False
        self.update_element_positions()

    def get_live_widget(self, widget):
        '''
        Returns the live widget for a declared widget. Raises KeyError if the
        widget is not one of this screen's elements.
        '''
        if self.elements_need_updating:
            self.perform_elements_update()
        return self.element_by_widget[widget]

    def update_element_positions(self):
        for element in self.elements:
            element.update_position()

    # NOTE(review): the two methods below are generators which finish through
    # twisted's returnValue(), but no decorator is applied to them here;
    # presumably the caller or the events framework drives them - confirm.
    def wait_for_button_press(self, catch=False):
        '''
        Waits for the next ButtonPress event and produces the widget of the
        button which was pressed. If catch is true the event is marked as
        caught.
        '''
        event = yield events.queue.next(kinds=[ButtonPress])
        if catch:
            event.caught()
        returnValue(event.button.widget)

    def wait_for_action(self, catch=False):
        '''
        Waits for the next ActionTrigger event and produces its action. If
        catch is true the event is marked as caught.
        '''
        event = yield events.queue.next(kinds=[ActionTrigger])
        if catch:
            event.caught()
        returnValue(event.action)

    def get_value(self, value):
        '''
        Resolves the given position/size declaration to an actual position or
        size.

        Numbers are returned unchanged. Anything else must have a base and a
        resolve() method; it is resolved against the window or against the
        live widget of its base. Raises ConstraintError if that element has
        no rect yet.
        '''
        if isinstance(value, (int, float, long)):
            return value

        if value.base == self.window:
            element = value.base
        else:
            element = self.get_live_widget(value.base)
        if not element.rect:
            raise ConstraintError('Reference to unplaced {}'.format(value.base))
        return value.resolve(element)

    def get_keyboard_status(self, key_code):
        '''
        Returns true if events.keyboard_status holds a true value for the key
        code itself, for its entry in events.key_names, or for any name that
        the screen's keyboard mapping associates with it.
        '''
        keys = list(self.keyboard_mapping.get(key_code))
        if key_code in events.key_names:
            keys.append(events.key_names[key_code])
        keys.append(key_code)

        return any(events.keyboard_status.get(key) for key in keys)

    # Element currently under the mouse pointer, and element on which the
    # mouse button was pressed (None when there is none).
    _mouse_in = None
    _mouse_down = None
    # NOTE(review): the handlers below use self as the ActiveScreen;
    # presumably events.EventHandler binds them to it - confirm against
    # fibre.ui.events.
    class handle_event(events.EventHandler):
        @annotate(handles=[events.MouseMotion])
        def mouse_move(self, event):
            # Elements which monitor the mouse see every motion event,
            # topmost (last declared) first, until one stops propagation.
            for element in reversed(self.elements):
                if element.monitor_mouse:
                    element.handle_event(event)
                if not event.propagating:
                    break

            # Find the topmost element under the pointer.
            mouse_in = None
            for element in reversed(self.elements):
                if element.rect.collidepoint(event.pos):
                    mouse_in = element
                    break
            if mouse_in != self._mouse_in:
                if self._mouse_down:
                    # MouseIn and MouseOut events are only sent to the clicked
                    # element while the mouse is being held down.
                    if self._mouse_in == self._mouse_down:
                        self._mouse_in.handle_event(events.Event(MouseOut))
                    elif mouse_in == self._mouse_down:
                        mouse_in.handle_event(events.Event(MouseIn))
                else:
                    if self._mouse_in:
                        self._mouse_in.handle_event(events.Event(MouseOut))
                    if mouse_in:
                        mouse_in.handle_event(events.Event(MouseIn))
            self._mouse_in = mouse_in

        @annotate(handles=[events.MouseDown])
        def mouse_down(self, event):
            # handled records whether the hovered element already received
            # the event through the monitoring loop.
            handled = False
            for element in reversed(self.elements):
                if element.monitor_mouse:
                    element.handle_event(event)
                    if element == self._mouse_in:
                        handled = True
                if not event.propagating:
                    return

            self._mouse_down = self._mouse_in
            if self._mouse_down and not handled:
                self._mouse_down.handle_event(event)

        @annotate(handles=[events.MouseUp])
        def mouse_up(self, event):
            handled = False
            target = self._mouse_down
            for element in reversed(self.elements):
                if element.monitor_mouse:
                    element.handle_event(event)
                    if element == target:
                        handled = True
                if not event.propagating:
                    break

            if self._mouse_down:
                self._mouse_down = None
                if not handled:
                    target.handle_event(event)

                if self._mouse_in and self._mouse_in != target:
                    # The element we're hovering over hasn't received its
                    # MouseIn because the mouse button was being held.
                    self._mouse_in.handle_event(events.Event(MouseIn))

