# -----------------------------------------------------------------------------
# pyinput.py -- Input information from Python programs.
#
# August 2014, Phil Connell
#
# Copyright 2014, Ensoft Ltd.
# -----------------------------------------------------------------------------

from __future__ import absolute_import, print_function

__all__ = (
    "trace",
)


import contextlib
import sys

from . import errors
from . import db_api


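# entrails is an optional dependency: API entry points that require it must
# check _entrails_available first. If it is missing, a stub stands in for the
# module so that _SextantOutput below can still be defined (and yes, this
# handling is a bit ugly).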
# Assume entrails is present; downgrade the flag if the import fails.
_entrails_available = True
try:
    import entrails
except ImportError:
    _entrails_available = False

    # Minimal stand-in for the entrails module: just enough for the
    # `class _SextantOutput(entrails.EntrailsOutput)` definition to succeed.
    class entrails:
        EntrailsOutput = object


class _SextantOutput(entrails.EntrailsOutput):
    """
    Record calls traced by entrails in a sextant database.

    Events received through enter(), exception() and exit() are forwarded to
    an internal coroutine (see _trace). That coroutine maintains a stack of
    function names and records each function, plus each caller -> callee
    edge, in the sextant program.

    """

    # Internal attributes:
    #
    # _conn:
    #   Sextant connection.
    # _fns:
    #   Stack of function names (implemented as a list), reflecting the current
    #   call stack, based on enter, exception and exit events.
    # _prog:
    #   Sextant program representation.
    # _tracer:
    #   Generator-based coroutine created by _trace(). It is primed in
    #   __init__ and receives (what, event) tuples via send().
    _conn = None
    _fns = None
    _prog = None
    _tracer = None

    def __init__(self, conn, program_name):
        """
        Initialise this output.

        conn:
            Connection to the Sextant database.
        program_name:
            String used to refer to the traced program in sextant.

        """
        self._conn = conn
        self._fns = []
        self._prog = self._conn.new_program(program_name)
        self._tracer = self._trace()
        # Prime the coroutine: advance it to its first yield so that it is
        # ready to receive events via send().
        next(self._tracer)

    def _add_frame(self, event):
        """
        Push a function call onto the internal stack and record it.

        The function is added to the sextant program. If a caller is on the
        stack, a call from the caller to this function is also added.

        """
        name = event.qualname()
        self._fns.append(name)
        self._prog.add_function(name)

        try:
            prev_name = self._fns[-2]
        except IndexError:
            # Outermost traced call: there is no caller to link from.
            pass
        else:
            self._prog.add_function_call(prev_name, name)

    def _remove_frame(self, event):
        """Pop a function call (which must be top of stack) from the stack."""
        assert event.qualname() == self._fns[-1], \
                "Unexpected event for {}".format(event.qualname())
        self._fns.pop()

    def _handle_simple_event(self, what, event):
        """
        Handle a single trace event that needs no look-ahead.

        Returns True if the event ("enter" or "exit") was handled, and False
        for any other event type.

        """
        if what == "enter":
            self._add_frame(event)
        elif what == "exit":
            self._remove_frame(event)
        else:
            return False
        return True

    def _trace(self):
        """
        Coroutine that processes the (what, event) tuples it's sent.

        Raises NotImplementedError if sent an unsupported event type.

        """
        while True:
            what, event = yield

            if self._handle_simple_event(what, event):
                continue

            if what != "exception":
                raise NotImplementedError(
                    "Unsupported trace event: {}".format(what))

            # An exception doesn't necessarily mean the current stack frame
            # is exiting. Look ahead at subsequent events: an exception
            # event in a different function implies the exception has
            # propagated out of the previous frame, so pop that frame.
            # The first non-exception event ends the look-ahead.
            while True:
                prev_event = event
                prev_name = event.qualname()
                what, event = yield
                if what == "exception":
                    if event.qualname() != prev_name:
                        self._remove_frame(prev_event)
                    continue

                if not self._handle_simple_event(what, event):
                    raise NotImplementedError(
                        "Unsupported trace event: {}".format(what))
                break

    def close(self):
        """Commit the recorded program to the sextant database."""
        self._prog.commit()

    def enter(self, event):
        """Forward a function-entry event to the tracing coroutine."""
        self._tracer.send(("enter", event))

    def exception(self, event):
        """Forward an exception event to the tracing coroutine."""
        self._tracer.send(("exception", event))

    def exit(self, event):
        """Forward a function-exit event to the tracing coroutine."""
        self._tracer.send(("exit", event))


# @@@ Config parsing shouldn't be done in __main__: we want to be able to get
# the neo4j URL from the parsed config here too.
@contextlib.contextmanager
def trace(conn, program_name=None, filters=None):
    """
    Context manager that records function calls in its context block.

    e.g. given this code, where conn is a SextantConnection:

        with sextant.trace(conn):
            foo()
            bar()

    The calls to foo() and bar() (and their callees, at any depth) will be
    recorded in the sextant database.

    conn:
        Instance of SextantConnection that will be used to record calls.
    program_name:
        String used to refer to the traced program in sextant. Defaults to
        sys.argv[0].
    filters:
        Optional iterable of entrails filters to apply.

    Raises errors.MissingDependencyError on entry if entrails isn't
    installed. Tracing is ended even if the context block raises.

    """
    if not _entrails_available:
        raise errors.MissingDependencyError(
            "Entrails is required to trace execution")

    if program_name is None:
        program_name = sys.argv[0]

    tracer = entrails.Entrails(filters=filters)
    tracer.add_output(_SextantOutput(conn, program_name))

    tracer.start_trace()
    try:
        yield
    finally:
        # Flush traced data, whether or not the block raised.
        tracer.end_trace()

