__author__ = "ardevelop"

import time

import psycopg2
import psycopg2.extensions
import psycopg2.extras
import tornado.ioloop

# Module-level alias for psycopg2's named-tuple cursor; the demo below passes it
# as ``cursor_factory`` so rows expose columns as attributes.
from psycopg2.extras import NamedTupleCursor


def fetchone_callback(callback):
    """Adapt *callback* to receive a single row instead of a cursor.

    The returned function takes a finished cursor, calls ``fetchone()`` on it
    and passes whatever that returns to *callback*.
    """
    def _deliver_first_row(cur):
        first = cur.fetchone()
        callback(first)

    return _deliver_first_row


def fetchall_callback(callback):
    """Adapt *callback* to receive every row instead of a cursor.

    The returned function takes a finished cursor, calls ``fetchall()`` on it
    and passes the resulting rows to *callback*.
    """
    def _deliver_all_rows(cur):
        every_row = cur.fetchall()
        callback(every_row)

    return _deliver_all_rows


def fetchone_attr_callback(attr_name, callback):
    """Adapt *callback* to receive one attribute of the first row.

    The returned function takes a finished cursor, fetches one row and passes
    ``getattr(row, attr_name)`` to *callback*. When ``fetchone()`` returns
    ``None`` (no row), *callback* receives ``None``. Previously ``getattr``
    raised inside the event loop and *callback* was never invoked.
    """
    def _callback(cursor):
        row = cursor.fetchone()
        callback(None if row is None else getattr(row, attr_name))

    return _callback


def fetchall_attr_callback(attr_name, callback):
    """Adapt *callback* to receive one attribute from every row.

    The returned function takes a finished cursor, fetches all rows and passes
    the list ``[getattr(row, attr_name), ...]`` (in row order) to *callback*.
    """
    def _deliver_column(cur):
        records = cur.fetchall()
        column = [getattr(record, attr_name) for record in records]
        callback(column)

    return _deliver_column


class PostgreAsyncConnection(object):
    """A single asynchronous psycopg2 connection driven by a tornado IOLoop.

    The connection is opened in async mode and its socket is registered with
    the IOLoop. Whenever ``poll()`` reports ``POLL_OK`` the pending callback is
    invoked with its argument. Until the connection is established that
    callback is ``_connect_callback``, which walks the init chain. Afterwards
    it is the callback supplied to the most recent ``execute``/``callproc``.
    """

    def __init__(self, dsn, database=None, username=None, password=None, host=None, port=None,
                 connection_factory=None, cursor_factory=None, init_chain=None, ioloop=None, **kwargs):
        """Start an async connect and register its socket with the IOLoop.

        :param init_chain: callables invoked one after another as
            ``step(connection, callback)`` once connected. Each must call
            ``callback()`` to let the next one run.
        :param cursor_factory: default cursor class for ``execute`` and
            ``callproc``. If omitted, ``psycopg2.extras.DictCursor`` is used.
        :param ioloop: IOLoop to register with; defaults to the global instance.
        """
        self._ioloop = ioloop or tornado.ioloop.IOLoop.instance()
        # Private copy: the same list may be shared by several connections and
        # _connect_callback consumes it with pop(0).
        self._init_chain = list(init_chain) if init_chain else []
        self._callback = self._connect_callback
        self._callback_arg = self
        self._cursor_factory = cursor_factory or psycopg2.extras.DictCursor
        # Pass everything by keyword: psycopg2 >= 2.7 changed connect()'s
        # positional parameters. "async" goes through a dict because it is a
        # reserved word on Python 3.7+.
        connect_kwargs = dict(kwargs)
        connect_kwargs["async"] = True
        self._conn = psycopg2.connect(dsn, database=database, user=username, password=password, host=host, port=port,
                                      connection_factory=connection_factory, cursor_factory=cursor_factory,
                                      **connect_kwargs)
        self._fd = self._conn.fileno()
        self._watch()

    def _watch(self):
        """Register the connection socket with the IOLoop, waiting for writability."""
        self._ioloop.add_handler(self._fd, self._io_callback, self._ioloop.WRITE)

    def _connect_callback(self, conn):
        """Run the next init-chain step, passing a continuation that re-enters here."""
        def _callback():
            # Pass conn itself. The original passed the tuple (conn,) as the argument.
            self._ioloop.add_callback(self._connect_callback, conn)

        if self._init_chain:
            self._init_chain.pop(0)(self, _callback)

    def _io_callback(self, fd, events):
        """IOLoop handler: advance ``poll()`` and dispatch on the returned state."""
        try:
            state = self._conn.poll()
        except psycopg2.OperationalError:
            # Stop watching the failed socket.
            # NOTE(review): the pending callback is dropped silently and the
            # owner is not notified of the failure.
            self._ioloop.remove_handler(fd)
        except (psycopg2.Warning, psycopg2.Error):
            self._ioloop.remove_handler(fd)
            raise  # bare raise preserves the original traceback
        else:
            if psycopg2.extensions.POLL_OK == state:
                self._ioloop.remove_handler(fd)
                callback, arg = self._callback, self._callback_arg
                if callback is not None:
                    callback(arg)
            elif psycopg2.extensions.POLL_READ == state:
                self._ioloop.update_handler(fd, self._ioloop.READ)
            elif psycopg2.extensions.POLL_WRITE == state:
                self._ioloop.update_handler(fd, self._ioloop.WRITE)
            else:
                raise psycopg2.OperationalError("psycopg2.connection.poll() returned unhandled state \"%s\"" % state)

    def execute(self, statement, args=(), cursor_factory=None, callback=None):
        """Send *statement* asynchronously. On completion, call ``callback(cursor)`` if given."""
        cursor = self._conn.cursor(cursor_factory=cursor_factory or self._cursor_factory)
        cursor.execute(statement, args)
        # Always overwrite. Keeping a previous query's callback would re-fire it
        # (with a stale cursor) when this query completes.
        self._callback = callback
        self._callback_arg = cursor
        self._watch()

    def callproc(self, name, args=(), cur_factory=None, callback=None):
        """Call stored procedure *name* asynchronously. On completion, call ``callback(cursor)`` if given."""
        # Same default cursor factory as execute().
        cursor = self._conn.cursor(cursor_factory=cur_factory or self._cursor_factory)
        cursor.callproc(name, args)
        self._callback = callback
        self._callback_arg = cursor
        self._watch()

    def busy(self):
        """Return True while an asynchronous command is in progress."""
        return self._conn.isexecuting()

    def closed(self):
        """Return True once the underlying psycopg2 connection is closed."""
        return self._conn.closed > 0


class PostgreAsyncClient(object):
    """Pool of asynchronous ``PostgreAsyncConnection`` objects sharing one IOLoop.

    Connections are opened eagerly and join the pool only after their init
    chain finishes. The chain is optional hstore registration, then any
    user-supplied ``init`` steps, then ``_initialized_callback``.
    ``execute``/``callproc`` dispatch to the first open, idle pooled
    connection. When none is available the call is re-queued on the IOLoop.
    """

    def __init__(self, dsn=None, database=None, username=None, password=None, host=None, port=None, pool_size=1,
                 ioloop=None, connection_factory=None, cursor_factory=None, hstore=False, init=None, reconnect_after=5,
                 **kwargs):
        """Open *pool_size* connections.

        :param hstore: if true, register the hstore typecaster before a
            connection joins the pool.
        :param init: a callable, or a list/tuple/set of callables, each called
            as ``init(conn, callback)`` and expected to call ``callback()``
            when done.
        :param reconnect_after: delay in seconds used by ``_disconnect``
            before opening a replacement connection.
        """
        self._ioloop = ioloop or tornado.ioloop.IOLoop.instance()
        self._pool = set()
        self._reconnect_after = reconnect_after

        init_chain = []

        if hstore:
            init_chain.append(self._register_hstore)

        if init:
            if isinstance(init, (list, tuple, set)):
                init_chain.extend(init)
            else:
                init_chain.append(init)

        # Must stay last: this step is what adds the connection to the pool.
        init_chain.append(self._initialized_callback)

        def connect():
            PostgreAsyncConnection(dsn, database, username, password, host, port, connection_factory, cursor_factory,
                                   init_chain, self._ioloop, **kwargs)

        self._connect = connect

        for _ in range(pool_size):
            connect()

    def _register_hstore(self, conn, callback):
        """Init step: query the hstore OIDs, register the typecaster, then continue."""
        def _register_hstore_callback(cursor):
            oid, array_oid = cursor.fetchone()
            # NOTE(review): no connection is passed, so the typecaster is
            # presumably registered process-wide. Confirm against psycopg2 docs.
            psycopg2.extras.register_hstore(None, globally=False, unicode=True, oid=oid, array_oid=array_oid)
            callback()

        conn.execute("SELECT 'hstore'::regtype::oid, 'hstore[]'::regtype::oid",
                     cursor_factory=psycopg2.extras.DictCursor, callback=_register_hstore_callback)

    def _initialized_callback(self, conn, callback):
        """Final init step: make *conn* available for queries."""
        self._pool.add(conn)

    def _aquire_connection(self):
        """Return an open, idle pooled connection, or None if there is none."""
        for conn in self._pool:
            # Check closed() first so a dead connection is never handed out.
            if not conn.closed() and not conn.busy():
                return conn
        return None

    def _disconnect(self, conn):
        """Drop *conn* from the pool and schedule a replacement connection."""
        self._pool.discard(conn)
        self._ioloop.add_timeout(time.time() + self._reconnect_after, self._connect)

    def execute(self, statement, args=(), cur_factory=None, callback=None):
        """Run *statement* on a free connection, or retry on the next IOLoop iteration."""
        conn = self._aquire_connection()
        if conn is not None:
            conn.execute(statement, args, cur_factory, callback)
        else:
            # NOTE(review): retrying via add_callback re-polls on every loop
            # iteration, which likely busy-spins while the pool is exhausted.
            self._ioloop.add_callback(self.execute, statement, args, cur_factory, callback)

    def callproc(self, name, args=(), cur_factory=None, callback=None):
        """Call stored procedure *name* on a free connection, or retry later."""
        conn = self._aquire_connection()
        if conn is not None:
            conn.callproc(name, args, cur_factory, callback)
        else:
            # Retry callproc itself. Re-queuing execute would run the
            # procedure name as a SQL statement.
            self._ioloop.add_callback(self.callproc, name, args, cur_factory, callback)

    def mogrify(self, statement, args=()):
        """Return *statement* with *args* bound, as produced by ``cursor.mogrify``.

        NOTE(review): raises AttributeError when no idle connection is available.
        """
        conn = self._aquire_connection()
        return conn._conn.cursor().mogrify(statement, args)


if "__main__" == __name__:
    # Demo: repeatedly list databases through a 4-connection pool.
    import os
    import time

    # IOLoop.instance() is a classmethod. The original IOLoop().instance()
    # first constructed (and leaked) an extra IOLoop.
    ioloop = tornado.ioloop.IOLoop.instance()

    db = None

    def init():
        global db
        # NOTE(review): the password used to be hard-coded here. It is now read
        # from PGPASSWORD, and the previously committed credential should be rotated.
        db = PostgreAsyncClient(None, "commerce", "commerce", os.environ.get("PGPASSWORD"), "ar.md", 5433,
                                hstore=True, pool_size=4, cursor_factory=NamedTupleCursor)

        db.execute("SELECT pg_sleep(1)", callback=done)
        db.execute("SELECT datname FROM pg_database", callback=done)

    def requery():
        try:
            db.execute("SELECT datname FROM pg_database", callback=done)
        except Exception as ex:
            print("%s %s" % (ex, type(ex)))

    def done(cursor):
        for row in cursor.fetchall():
            print(row)

        # Schedule the next query instead of calling time.sleep(), which would
        # block the IOLoop (and every other in-flight query) for a second.
        ioloop.add_timeout(time.time() + 1, requery)

    ioloop.add_callback(init)
    ioloop.start()