__author__ = "ardevelop"

import tornado.ioloop
import time

import psycopg2
import psycopg2.extras
import psycopg2.extensions

from psycopg2.extras import NamedTupleCursor, DictCursor, RealDictCursor


def success_callback(callback):
    """Adapt *callback* so it receives a boolean instead of a cursor.

    Returns a one-argument function; when called with a cursor it invokes
    ``callback`` with the truthiness of that cursor.
    """
    def _report_truthiness(cur):
        ok = bool(cur)
        callback(ok)

    return _report_truthiness


def fetchone_callback(callback):
    """Return a cursor handler that passes ``cursor.fetchone()`` to *callback*."""
    def _deliver_first_row(cur):
        first_row = cur.fetchone()
        callback(first_row)

    return _deliver_first_row


def fetchall_callback(callback):
    """Return a cursor handler that passes ``cursor.fetchall()`` to *callback*."""
    def _deliver_all_rows(cur):
        all_rows = cur.fetchall()
        callback(all_rows)

    return _deliver_all_rows


def fetchone_attr_callback(attr_name, callback):
    """Return a cursor handler that passes one attribute of the first row to *callback*.

    The handler fetches a single row and calls ``callback(getattr(row, attr_name))``.
    When the result set is empty (``fetchone()`` returns ``None``) it calls
    ``callback(None)``. The alternative would be raising ``AttributeError``
    inside the IOLoop handler, and then the callback would never run.
    """
    def _callback(cursor):
        row = cursor.fetchone()
        callback(None if row is None else getattr(row, attr_name))

    return _callback


def fetchall_attr_callback(attr_name, callback):
    """Return a cursor handler that passes one attribute from every row to *callback*.

    The handler fetches all rows and calls *callback* with a list holding
    ``getattr(row, attr_name)`` for each row, in result order.
    """
    def _deliver_column(cur):
        column = []
        for record in cur.fetchall():
            column.append(getattr(record, attr_name))
        callback(column)

    return _deliver_column


class PostgreAsyncConnection(object):
    """One asynchronous psycopg2 connection driven by a tornado IOLoop.

    The connection is opened in psycopg2's asynchronous mode and its socket is
    registered with the IOLoop. Every readiness event calls ``poll()`` to
    advance the connect/query state machine. When ``poll()`` reports
    ``POLL_OK``, the pending callback is invoked with its pending argument.

    Once connected, the tasks in *init_chain* run one after another. Each is
    called as ``task(connection, next_step)`` and must call ``next_step()`` to
    let the following task start.
    """

    def __init__(self, dsn, database=None, username=None, password=None, host=None, port=None, pool=None,
                 connection_factory=None, cursor_factory=None, init_chain=None, ioloop=None, **kwargs):
        """Open the connection and register it with the IOLoop.

        :param dsn: libpq connection string, or ``None`` to build one from the
            keyword parameters.
        :param username: passed to psycopg2 as ``user``.
        :param pool: owner notified through ``pool._connect_error(conn)`` when
            an init task or ``poll()`` fails; may be ``None``.
        :param cursor_factory: default cursor class for :meth:`execute` and
            :meth:`callproc` (``DictCursor`` when omitted). It is also passed
            to ``psycopg2.connect``.
        :param init_chain: callables run in order once the connection is up.
        :param ioloop: IOLoop to use; defaults to ``IOLoop.instance()``.
        :param kwargs: extra connection parameters for ``psycopg2.connect``.
        """
        self._pool = pool
        self._ioloop = ioloop or tornado.ioloop.IOLoop.instance()
        self._init_chain = list(init_chain) if init_chain else []
        # The first POLL_OK ends the connect phase and starts the init chain.
        self._callback = self._connect_callback
        self._callback_arg = self
        self._cursor_factory = cursor_factory or psycopg2.extras.DictCursor
        # Pass every argument by keyword. The positional layout of
        # psycopg2.connect() changed in psycopg2 2.7 (dsn, connection_factory,
        # cursor_factory, **kwargs), so a positional call breaks there.
        # "async" travels through a dict because it is a reserved word in
        # Python 3.7+.
        connect_kwargs = dict(kwargs)
        connect_kwargs["async"] = True
        self._conn = psycopg2.connect(dsn, database=database, user=username, password=password,
                                      host=host, port=port, connection_factory=connection_factory,
                                      cursor_factory=cursor_factory, **connect_kwargs)
        self._fd = self._conn.fileno()
        self._ioloop.add_handler(self._fd, self._io_callback, tornado.ioloop.IOLoop.WRITE)

    def _report_connect_error(self):
        """Call ``self._pool._connect_error(self)`` if the pool exists and provides that hook."""
        notify = getattr(self._pool, "_connect_error", None)
        if notify is not None:
            notify(self)

    def _connect_callback(self, conn):
        """Run the next pending init task, if any.

        This runs with the connection when the connect phase finishes. It runs
        again via the IOLoop each time an init task calls its ``next_step``.
        If a task raises, the pool is notified and the exception is re-raised.
        """
        def _next_step():
            # Defer to the IOLoop instead of calling the next task directly.
            self._ioloop.add_callback(self._connect_callback, conn)

        if not self._init_chain:
            return

        task = self._init_chain.pop(0)
        try:
            task(self, _next_step)
        except Exception:
            self._report_connect_error()
            raise

    def _io_callback(self, fd, events):
        """IOLoop handler: advance the connection with ``poll()``.

        - ``POLL_READ`` / ``POLL_WRITE``: re-arm the handler for that event.
        - ``POLL_OK``: unregister the handler and invoke the pending callback.
        - psycopg2 error: unregister the handler, notify the pool and re-raise
          the error into the IOLoop.
        - Any other poll state: raise ``OperationalError``.
        """
        try:
            state = self._conn.poll()
        except (psycopg2.Warning, psycopg2.Error):
            self._ioloop.remove_handler(fd)
            # NOTE(review): every poll() failure, including an ordinary SQL
            # error, is reported as a connect error. PostgreAsyncClient closes
            # the connection in response. Confirm this recovery policy.
            self._report_connect_error()
            raise

        if state == psycopg2.extensions.POLL_OK:
            self._ioloop.remove_handler(fd)
            self._callback(self._callback_arg)
        elif state == psycopg2.extensions.POLL_READ:
            self._ioloop.update_handler(fd, tornado.ioloop.IOLoop.READ)
        elif state == psycopg2.extensions.POLL_WRITE:
            self._ioloop.update_handler(fd, tornado.ioloop.IOLoop.WRITE)
        else:
            raise psycopg2.OperationalError("psycopg2.connection.poll() returned unhandled state \"%s\"" % state)

    def _wait_for_result(self, cursor, callback):
        """Arm the IOLoop handler so *callback* receives *cursor* when the query completes."""
        # Always replace the pending callback. Keeping the previous one would
        # re-run an earlier callback with this cursor when no callback is given.
        self._callback = callback or (lambda _cursor: None)
        self._callback_arg = cursor
        self._ioloop.add_handler(self._fd, self._io_callback, tornado.ioloop.IOLoop.WRITE)

    def execute(self, statement, args=(), cursor_factory=None, callback=None):
        """Start executing *statement* with *args* asynchronously.

        :param cursor_factory: cursor class for this query; defaults to the
            connection's default cursor factory.
        :param callback: called with the cursor once the query completes; when
            omitted, completion is ignored.
        """
        cursor = self._conn.cursor(cursor_factory=cursor_factory or self._cursor_factory)
        cursor.execute(statement, args)
        self._wait_for_result(cursor, callback)

    def callproc(self, name, args=(), cur_factory=None, callback=None):
        """Start calling stored procedure *name* asynchronously.

        Uses the same default cursor factory as :meth:`execute`. *callback* is
        called with the cursor on completion; when omitted, completion is ignored.
        """
        cursor = self._conn.cursor(cursor_factory=cur_factory or self._cursor_factory)
        cursor.callproc(name, args)
        self._wait_for_result(cursor, callback)

    def busy(self):
        """Return True while a query is executing on this connection."""
        return self._conn.isexecuting()

    def closed(self):
        """Return True once the underlying psycopg2 connection is closed."""
        return self._conn.closed > 0

    def close(self):
        """Close the underlying psycopg2 connection."""
        self._conn.close()


class PostgreAsyncClient(object):
    """Callback-style PostgreSQL client backed by a pool of PostgreAsyncConnection objects.

    Connections are opened lazily up to *max_pool_size*; ``min_pool_size`` of
    them are opened immediately. A request that finds no idle connection is
    re-queued on the IOLoop with ``add_callback`` until one is available.
    """

    def __init__(self, dsn=None, database=None, username=None, password=None, host=None, port=None,
                 min_pool_size=0, max_pool_size=1, ioloop=None, connection_factory=None, cursor_factory=None,
                 hstore=False, composite_types=None, init=None, **kwargs):
        """Configure the pool and open *min_pool_size* connections immediately.

        :param hstore: register the hstore typecaster on every new connection.
        :param composite_types: composite type names (``"type"`` or
            ``"schema.type"``) whose typecasters are registered on every new
            connection.
        :param init: one extra init task, or a list/tuple/set of them. Each is
            called as ``task(connection, next_step)`` and must call
            ``next_step()`` when done.
        :param kwargs: forwarded to each PostgreAsyncConnection.
        """
        self._ioloop = ioloop or tornado.ioloop.IOLoop.instance()
        self._pool = set()               # connections that finished their init chain
        self._wait_connections = set()   # connections still connecting / running init tasks
        self._min_pool_size = min_pool_size
        self._max_pool_size = max_pool_size
        self._last_connect_error = None  # time.time() of the last failure; drives the 5 s back-off

        # Tasks run in order on every new connection; _connect_success must stay last.
        init_chain = []

        if hstore:
            init_chain.append(self._register_hstore)

        if composite_types:
            # TODO: register all composite types with 2 requests
            for composite_type in composite_types:
                parts = composite_type.split(".")
                if len(parts) == 1:
                    schema, name = "public", parts[0]
                else:
                    schema, name = parts[0], parts[1]
                init_chain.append(self._register_composite(name, schema))

        if init:
            if isinstance(init, (list, tuple, set)):
                init_chain.extend(init)
            else:
                init_chain.append(init)

        init_chain.append(self._connect_success)

        def connect():
            conn = PostgreAsyncConnection(dsn, database, username, password, host, port, self, connection_factory,
                                          cursor_factory, init_chain, self._ioloop, **kwargs)
            self._wait_connections.add(conn)
            return conn

        for _ in range(min_pool_size):
            connect()

        self._connect = connect

    def _register_hstore(self, conn, callback):
        """Init task: look up the hstore OIDs, register the hstore typecaster, then call *callback*.

        NOTE(review): the connection argument to ``register_hstore`` is
        ``None``, which looks like global rather than per-connection
        registration. Confirm this is intended.
        """
        def _register_hstore_callback(cursor):
            oid, array_oid = cursor.fetchone()
            # Positional args: (conn_or_curs, globally, unicode, oid, array_oid)
            psycopg2.extras.register_hstore(None, False, True, oid, array_oid)
            callback()

        conn.execute("SELECT 'hstore'::regtype::oid, 'hstore[]'::regtype::oid",
                     cursor_factory=psycopg2.extras.DictCursor, callback=_register_hstore_callback)

    def _register_composite(self, name, schema="public"):
        """Build an init task that registers typecasters for composite type ``schema.name``.

        The task runs two queries. The first fetches the type's OID and array
        OID; the second fetches its attribute names and type OIDs. It then
        registers the type and array casters on the connection and calls
        ``callback()``.
        """
        def _register_composite_task(conn, callback):
            full_name = "%s.%s" % (schema, name)
            # The OIDs are stashed on the connection between the two queries.
            conn.oid = None
            conn.array_oid = None

            def _got_composite_attr_types(cursor):
                attr_types = cursor.fetchall()
                caster = psycopg2.extras.CompositeCaster(name, conn.oid, attr_types, conn.array_oid, schema)
                psycopg2.extensions.register_type(caster.typecaster, conn._conn)
                psycopg2.extensions.register_type(caster.array_typecaster, conn._conn)
                callback()

            def _got_oid_and_array_oid(cursor):
                conn.oid, conn.array_oid = cursor.fetchone()
                conn.execute("""
                        SELECT attname, atttypid
                        FROM pg_type t
                        JOIN pg_namespace ns ON typnamespace = ns.oid
                        JOIN pg_attribute a ON attrelid = typrelid
                        WHERE typname = %s AND nspname = %s AND attnum > 0 AND NOT attisdropped
                        ORDER BY attnum
                    """, (name, schema), cursor_factory=psycopg2.extras.DictCursor, callback=_got_composite_attr_types)

            # Bind the type name as a query parameter rather than interpolating it into the SQL text.
            conn.execute("SELECT %s::regtype::oid, %s::regtype::oid", (full_name, full_name + "[]"),
                         cursor_factory=psycopg2.extras.DictCursor, callback=_got_oid_and_array_oid)

        return _register_composite_task

    def _connect_success(self, conn, callback):
        """Final init task: move *conn* from the waiting set into the pool.

        *callback* is not called because this task is always last in the chain.
        """
        self._pool.add(conn)
        self._wait_connections.remove(conn)

    def _connect_error(self, conn):
        """Close a failed connection, forget it, and start the reconnect back-off."""
        self._last_connect_error = time.time()
        conn.close()
        self._wait_connections.discard(conn)
        # A pooled connection whose query failed also lands here. Drop it now
        # rather than leaving a closed connection in the pool.
        self._pool.discard(conn)

    def _acquire_connection(self):
        """Return an idle pooled connection, or ``None`` if none is available right now.

        Closed connections found along the way are removed. When no
        connection is idle, none is being established, and the pool has room,
        a new connection is started; ``None`` is still returned.

        :raises Exception: when a new connection is needed less than 5 seconds
            after the last connect error.
        """
        # Iterate over a snapshot: _disconnect() removes entries from self._pool.
        for conn in list(self._pool):
            if conn.closed():
                self._disconnect(conn)
            elif not conn.busy():
                return conn

        if not self._wait_connections and len(self._pool) < self._max_pool_size:
            if self._last_connect_error and time.time() - self._last_connect_error < 5:
                raise Exception("Connect error retry in 5 seconds")

            self._connect()
        return None

    def _disconnect(self, conn):
        """Close *conn* and remove it from the pool."""
        conn.close()
        self._pool.discard(conn)

    def execute(self, statement, args=(), cur_factory=None, callback=None):
        """Run *statement* on a pooled connection and pass the cursor to *callback*.

        If no idle connection is available, the call is re-queued on the
        IOLoop. *callback* may be ``None`` when the result is not needed.
        """
        conn = self._acquire_connection()
        if conn is None:
            self._ioloop.add_callback(self.execute, statement, args, cur_factory, callback)
            return

        def _on_result(response):
            if isinstance(response, (psycopg2.Warning, psycopg2.Error)):
                # An error delivered in place of a cursor: retry the statement.
                self._ioloop.add_callback(self.execute, statement, args, cur_factory, callback)
            elif callback is not None:
                callback(response)

        conn.execute(statement, args, cur_factory, _on_result)

    def callproc(self, name, args=(), cur_factory=None, callback=None):
        """Call stored procedure *name* on a pooled connection and pass the cursor to *callback*.

        If no idle connection is available, the call is re-queued on the IOLoop.
        """
        conn = self._acquire_connection()
        if conn is None:
            # Retry the procedure call itself, not execute().
            self._ioloop.add_callback(self.callproc, name, args, cur_factory, callback)
        else:
            conn.callproc(name, args, cur_factory, callback)

    def mogrify(self, statement, args=()):
        """Return *statement* with *args* bound, using an idle pooled connection.

        :raises psycopg2.InterfaceError: if no idle connection is available.
        """
        conn = self._acquire_connection()
        if conn is None:
            raise psycopg2.InterfaceError("mogrify() requires an idle connection")
        return psycopg2.extensions.cursor(conn._conn).mogrify(statement, args)


if "__main__" == __name__:
    import time

    db = None

    def init():
        global db
        db = PostgreAsyncClient(None, "commerce", "commerce", "H2Rekuh5ma38cHAS", "ar.md", 5433,
                                hstore=True, pool_size=4, cursor_factory=NamedTupleCursor)

        db.execute("SELECT pg_sleep(1)", callback=done)
        db.execute("SELECT datname FROM pg_database", callback=done)

    def done(cursor):
        for row in cursor.fetchall():
            print row

        time.sleep(1)
        try:
            db.execute("SELECT datname FROM pg_database", callback=done)
        except Exception, ex:
            print ex, type(ex)

    ioloop = tornado.ioloop.IOLoop().instance()
    ioloop.add_callback(init)
    ioloop.start()