from __future__ import with_statement, absolute_import
import gc
from threading import local, RLock
from functools import wraps
from mext.const import const


# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'Controller',
    'SnapshotController',
    'LocalController',
    'DisabledTransaction',
]



class DisabledTransaction(object):
    """
        Stand-in assigned to a controller's .txn while real transactions
        must not be used (see LocalController._disabled_txn).

        Any attribute that normal lookup cannot find triggers
        AssertionError instead of AttributeError, so accidental use of the
        "current transaction" fails loudly.
    """

    def __getattr__(self, name):
        message = "Disabled transactions don't have ." + name
        raise AssertionError(message)


class DummyCtxManager(object):
    """
        No-op context manager: ``with DummyCtxManager(v) as x`` binds x to v.
        Nothing happens on exit and exceptions propagate unchanged.
    """

    def __init__(self, value=None):
        self.value = value

    def __enter__(self):
        return self.value

    def __exit__(self, *exc_info):
        # falsy result: never suppress an exception raised in the block
        return None


class Controller(object):
    """
        Base controller: holds the current transaction in .txn and the class
        used to create transactions in .transaction_cls.

        One-time setup belongs in init(), not __init__(): for subclasses that
        also derive from threading.local, __init__ is re-run in every thread
        that touches the instance, whereas init() runs once, from __new__.
    """

    def __new__(cls, *args, **kw):
        self = super(Controller, cls).__new__(cls)
        is_threadlocal = isinstance(self, local)
        if is_threadlocal:
            # thread-local subclass: do the one-time setup here, because
            # __init__ will run again per thread
            self.init(*args, **kw)
        return self

    def __init__(self, *args, **kw):
        # no transaction is active yet (per thread for threading.local subclasses)
        self.txn = None
        if isinstance(self, local):
            # init() already ran in __new__
            return
        self.init(*args, **kw)

    def init(self, transaction_cls):
        """
            One-time setup. Subclasses override this instead of __init__
            so that thread-local implementations do it only once.

            transaction_cls: factory for transactions; subclasses call it as
            transaction_cls(controller).
        """
        self.transaction_cls = transaction_cls

    def need_txn(self):
        """
            Context manager mandating a transaction to be present.
            Does nothing if a txn is already active, identical to .new_txn() otherwise.

                >>> with ctrl.need_txn() as txn:
                ...     assert ctrl.txn is txn

            NB: new_txn is not defined on Controller itself; subclasses
            (e.g. LocalController) must provide it.
        """
        current = self.txn
        if not current:
            return self.new_txn()
        return DummyCtxManager(current)

##     @contextmanager
##     def new_txn(self):
##         if self.txn:
##             raise RuntimeError("Can't nest transactions")
##
##         txn = self.transaction_cls(self)
##         try:
##             try:
##                 with txn:
##                     self.txn = txn
##                     yield txn
##             finally:
##                 self.txn = DisabledTransaction()
##             # on_success / on_failure are separate from __exit__
##             # so they are called after txn is complete, knowing
##             # if txn.__exit__ handled any possible errors or not
##             txn.on_success()
##         except:
##             txn.on_failure()
##             raise
##         finally:
##             self.txn = None
##
##



class SnapshotController(Controller):
    """
        Controller that keeps shared state as reference-counted snapshot dicts.

        A snapshot maps cell keys to values and additionally carries a
        'refcount' entry. .state is the current snapshot (the controller
        holds one reference to it); ._snapshots lists every snapshot whose
        refcount is still non-zero. Mutations are serialized via the
        re-entrant ._lock.
    """

    def init(self, *args, **kw):
        self._lock = RLock()
        self._snapshots = []
        # keys queued here (through discard_key) are purged from every live
        # snapshot by _collect_garbage
        self._dead_keys = []
        self.discard_key = self._dead_keys.append
        self._set_state({})
        super(SnapshotController, self).init(*args, **kw)

    def make_snapshot(self):
        """
            Return the current state and take a reference to it.
            Release it later with .drop_snapshot().
        """
        with self._lock:
            snapshot = self.state
            self._incref(snapshot)
            return snapshot

    def drop_snapshot(self, snapshot):
        """
            Drops a snapshot because a txn has updated to a newer one
            (releases a reference taken by .make_snapshot()).
        """
        with self._lock:
            self._decref(snapshot)

    def upd_state(self, txn):
        """
            Update state with a snapshot and writes of a txn.
            The txn must be destroyed immediately afterwards.

            Returns False (changing nothing) if txn.snapshot is not the
            current state, True once the new state has been installed.
        """
        with self._lock:
            if txn.snapshot is not self.state:
                return False
            # drop the controller's own reference to the previous state
            self._decref(self.state)
            state = self.state.copy()
            state.update(txn.writes)
            self._set_state(state)
            return True

    def collect_garbage(self):
        """Purge queued dead keys from all live snapshots, under the lock."""
        with self._lock:
            self._collect_garbage()

    def _collect_garbage(self):
        # Expects the caller to hold ._lock. Pops in place: discard_key is
        # bound to this very list object.
        while self._dead_keys:
            key = self._dead_keys.pop()
            for snapshot in self._snapshots:
                snapshot.pop(key, None)

    def _prepare_collect(self):
        """
            Destroy state, saving cell values in the cells' _gc_value attributes.
            Return a list of the cell keys that were in the state.
        """
        assert not self.txn
        assert self.state['refcount'] == 1
        assert len(self._snapshots) == 1
        # make sure we don't encounter dead keys in the loop below
        self._collect_garbage()
        # drop the controller's reference, which removes state from ._snapshots
        self.drop_snapshot(self.state)
        assert not self._snapshots
        # after this only cell keys remain in the state
        del self.state['refcount']
        # a real list, so it keeps no reference to the state dict
        # (a Python 3 keys() view would)
        keys = list(self.state)
        for k, v in self.state.items():
            k.cell()._gc_value = v
        # make sure current state is not referenced anywhere but the controller itself
        #assert gc.get_referrers(self.state) == [self]
        # and remove that reference
        self.state = None
        return keys

    @staticmethod
    def _rebuild_state(keys):
        """
            Given a list of keys rebuild state using their cells' ._gc_value
            Return generator to be used in dict().
            Dead keys (k.cell() returns a falsy value) are excluded from the state.
        """
        for k in keys:
            c = k.cell()
            if c:
                yield k, c._gc_value
                c._gc_value = None

    def collect_cycles(self, generation=None, loops=1):
        """
            Collect circular garbage that has transactional state dict as one of the links.
            (This is impossible without calling this method and cannot be done automatically).
            Return number of cells that were garbage-collected.

            generation: passed to gc.collect(); None (default) means a full
                collection, i.e. gc.collect() without an argument.
            loops: maximum number of gc.collect() passes; stops early once
                a pass finds nothing.
        """
        with self._lock:
            keys = self._prepare_collect()
            try:
                # for some reason one pass doesn't collect everything
                # so we can either loop here or outside -- see tgarbage.py
                for _ in range(loops):
                    # gc.collect() requires an int, so None must not be forwarded
                    if generation is None:
                        found = gc.collect()
                    else:
                        found = gc.collect(generation)
                    if not found:
                        break
            finally:
                # always rebuild, so a failed collection never leaves
                # the controller with .state = None
                state = dict(self._rebuild_state(keys))
                self._set_state(state)
            return len(keys) - len(state)

    def _set_state(self, state):
        # this must be all-new state; the controller holds its first reference
        self.state = state
        self._snapshots.append(state)
        state['refcount'] = 1

    def _incref(self, snapshot):
        snapshot['refcount'] += 1

    def _decref(self, snapshot):
        snapshot['refcount'] -= 1
        rc = snapshot['refcount']
        assert rc >= 0
        if not rc:
            # Remove by identity. list.remove() would compare with ==, i.e.
            # compare whole snapshot dicts value by value (running arbitrary
            # cell-value __eq__) against every earlier snapshot.
            for i, live in enumerate(self._snapshots):
                if live is snapshot:
                    del self._snapshots[i]
                    break
            else:
                raise ValueError("snapshot not in _snapshots")






# Sentinel that TxnManager-wrapped generators yield as their final value;
# TxnManager.__exit__ asserts it received exactly this object.
EXIT = const()

class TxnManager(object):
    """
        Wrap a generator to create reusable context manager.

        The generator yields twice: first the value bound by ``with ... as``,
        then, once the block is done, the EXIT sentinel. If the block raised,
        the exception is thrown into the generator at its first yield.
        __exit__ returns None, so exceptions from the block always propagate.
    """
    # Class-wide switch: while true, new managers cannot be created.
    # (Nothing in this class sets it.)
    disabled = False

    def __init__(self, gen):
        if self.disabled:
            raise RuntimeError("Transaction manager has previously failed with an unexpected error. Transactions are disabled")
        self.gen = gen

    def __enter__(self):
        # send(None) resumes a generator exactly like next()
        return self.gen.send(None)

    def __exit__(self, *exc_info):
        gen = self.gen
        outcome = gen.throw(*exc_info) if exc_info[0] else gen.send(None)
        assert outcome is EXIT

    @classmethod
    def wrap(cls, func):
        """Decorator: turn generator function func(ctrl) into a context-manager factory."""
        def wrapped(ctrl):
            return cls(func(ctrl))
        return wraps(func)(wrapped)


class LocalController(SnapshotController, local):
    """
        Snapshot controller usable from several threads.

        Attributes named in __slots__ live on the object itself and are shared
        by all threads. Everything else, notably .txn set in Controller.__init__
        (which threading.local re-runs per thread), lives in the per-thread
        dict, so each thread has its own current transaction.
    """
    # 'discard_key' must be shared too: SnapshotController.init assigns it
    # only once, in the creating thread, so as a per-thread attribute it would
    # be missing in every other thread.
    __slots__ = ('state', '_lock', '_snapshots', '_dead_keys', 'transaction_cls',
                 'discard_key')

    @TxnManager.wrap
    def new_txn(self):
        """
            Context manager: run a top-level transaction while holding ._lock.
            Raises RuntimeError if this thread already has a transaction.
            If the transaction completes without an exception, keys queued
            via discard_key are then purged from all snapshots.
        """
        if self.txn:
            raise RuntimeError("Can't nest transactions (current txn: %r)" % self.txn)
        with self._lock:
            with self._new_txn() as txn:
                yield txn
            if self._dead_keys:
                self._collect_garbage()
        yield EXIT

    @TxnManager.wrap
    def _new_txn(self):
        """
            Context manager: create a txn with transaction_cls(self) and bind
            the value returned by its __enter__ to .txn for the duration.

            If the block completes, including when the txn's __exit__ swallowed
            an error, call txn.on_success() then txn.post_txn(). Otherwise call
            txn.on_failure() and re-raise. The hooks run with .txn set to a
            DisabledTransaction, and .txn is None afterwards.
        """
        txn = self.transaction_cls(self)
        try:
            with txn as self.txn:
                yield txn
        except:  # bare on purpose: every failure must reach on_failure
            with self._disabled_txn():
                txn.on_failure()
            # re-raise while still inside the except clause; a bare raise
            # outside it relies on Python 2 lingering exception state
            raise
        else:
            with self._disabled_txn():
                txn.on_success()
                txn.post_txn() #@@ someday we'll move post_txn outside of the lock
        yield EXIT

    @TxnManager.wrap
    def init_txn(self):
        """
            Context manager that works with or without an active transaction.

            With no active txn it is the same as .new_txn(). Otherwise it runs a
            nested transaction via ._new_txn(). On success it hands the outer
            txn a fresh snapshot of the current state via replace_snapshot().
            It always restores the outer txn as .txn. The nested branch does
            not take ._lock itself; presumably the enclosing new_txn holds it.
        """
        cur_txn = self.txn
        if not cur_txn:
            with self.new_txn() as txn:
                yield txn
        else:
            try:
                with self._new_txn() as txn:
                    yield txn
                cur_txn.replace_snapshot(self.make_snapshot())
            finally:
                self.txn = cur_txn
        yield EXIT

    @TxnManager.wrap
    def _disabled_txn(self):
        """
            Context manager: set .txn to a DisabledTransaction for the block,
            then reset it to None.
        """
        self.txn = DisabledTransaction()
        try:
            yield
        finally:
            self.txn = None
            # yielded from finally so TxnManager.__exit__ receives EXIT on
            # both the normal and the exception path
            yield EXIT
