import functools
import select
import uuid
import pika
from qpid import messaging
from qpid.messaging.exceptions import LinkClosed, Empty

import minsol
import hens


class QPIDServer(minsol.Object):
    """RPC server that reads requests from a QPID queue.

    Each fetched message is parsed with ``request_class``, executed against
    a fresh ``acceptor_class()`` instance and, when the message carries a
    ``reply_to`` address, answered with a ``response_class`` payload.
    """
    request_class = hens.Request
    response_class = hens.Response

    def init(self, acceptor_class, connection, queue_name, **other):
        """Store the server configuration.

        :param acceptor_class: zero-argument callable; its result is handed
            to ``request.execute`` for every request
        :param connection: QPID connection (opened by ``init_broker`` if it
            is not opened yet)
        :param queue_name: address of the queue requests are fetched from
        :param other: extra keyword arguments, accepted and ignored
        """
        self.acceptor_class = acceptor_class
        self.connection = connection
        self.queue_name = queue_name
        # Set by listen(); kept as None until then so stop() is always safe.
        self.receiver = None

    def start(self):
        """Prepare the broker side and run the blocking listen loop."""
        self.init_broker()
        self.listen()

    def stop(self):
        """Close the request receiver; a no-op if ``listen`` never ran."""
        if self.receiver is not None:
            self.receiver.close()

    def init_broker(self):
        """Open the connection if needed and set up the request queue.

        A throw-away session creates a sender on an address carrying
        ``create:always`` plus non-durable / auto-delete node options.
        """
        if not self.connection.opened():
            self.connection.open()

        session = self.connection.session("rpc init")

        q = "%s; {"\
            "create:always, " \
            "node: {durable:False, x-declare:{auto-delete:True}}"\
        "}"
        q = q % self.queue_name
        session.sender(q)

        session.close()

    def listen(self):
        """Fetch and process requests until the receiver link goes away.

        The loop ends on ``LinkClosed``, ``KeyboardInterrupt`` or
        ``select.error`` raised by ``fetch``.  The session is closed on every
        exit path, including an unexpected exception raised while a request
        is being handled.
        """
        session = self.connection.session("rpc listen")
        try:
            self.receiver = session.receiver(self.queue_name)

            while True:
                try:
                    msg = self.receiver.fetch()

                except (LinkClosed, KeyboardInterrupt, select.error):
                    break

                # Acknowledged before execution, as in the original flow.
                session.acknowledge(msg)
                self._handle_message(session, msg)

        finally:
            session.close()

    def _handle_message(self, session, msg):
        """Parse one fetched message and execute it against a new acceptor.

        A ``hens.Error`` raised while parsing or executing is sent back to
        the caller as the result (``request`` is ``None`` if parsing failed).
        """
        acceptor = self.acceptor_class()
        request = None

        try:
            request = self.request_class.get_from_raw(msg.content)

            callback = functools.partial(
                self._send_reply, session, msg.reply_to, request)
            request.execute(acceptor, callback)

        except hens.Error as error:
            self._send_reply(session, msg.reply_to, request, error)

    def _send_reply(self, session, reply_to, request, result):
        """Send ``result`` to ``reply_to``; do nothing without an address."""
        if not reply_to:
            return

        response = self.response_class(request=request, result=result)
        raw_response = response.as_raw()

        reply_sender = session.sender(reply_to)
        try:
            reply_sender.send(raw_response)
        finally:
            # Release the temporary sender even when the send fails.
            reply_sender.close()


class QPIDClient(hens.Client):
    """Blocking RPC client that talks to a server over a QPID queue."""

    def init(self, connection, queue_name, timeout=None, **kwargs):
        """Store the configuration and open the connection if needed.

        :param connection: QPID connection; opened here when not yet opened
        :param queue_name: address of the server's request queue
        :param timeout: value passed to ``fetch(timeout=...)`` when waiting
            for the response
        :param kwargs: extra keyword arguments, accepted and ignored
        """
        self.connection = connection
        if not self.connection.opened():
            self.connection.open()

        self.queue_name = queue_name
        self.timeout = timeout

    def _fetch_call(self, raw_request):
        """Send ``raw_request`` and return the raw response content.

        A per-call response queue address is generated and used as the
        message's ``reply_to``.

        :raises hens.InvalidResponse: when ``fetch`` raises ``Empty``
        :raises LinkClosed: when ``fetch`` raises ``LinkClosed``,
            ``KeyboardInterrupt`` or ``select.error``
        """
        session = self.connection.session("rpc call")

        try:
            response_queue_name = self.gen_response_queue_name()
            response_receiver = session.receiver(response_queue_name)

            sender = session.sender(self.queue_name)

            message = messaging.Message(raw_request,
                reply_to=response_queue_name)

            sender.send(message)

            try:
                msg = response_receiver.fetch(timeout=self.timeout)

            except Empty:
                raise hens.InvalidResponse()

            except (LinkClosed, KeyboardInterrupt, select.error):
                raise LinkClosed()

            return msg.content

        finally:
            # Close on every path: previously a timeout or link error left
            # the session (and its response receiver) open.
            session.close()

    def gen_response_queue_name(self):
        """Build a unique response queue address for one call.

        The name is ``<queue_name>.tmp-<uuid4 hex>`` followed by address
        options (``create:always``, non-durable, auto-delete, exclusive).
        """
        tmp_ident = uuid.uuid4().hex

        q = "%s.tmp-%s; {"\
            "create:always, " \
            "node: {"\
                "durable:False, "\
                "x-declare:{auto-delete:True, exclusive:True}"\
            "}}"

        q = q % (self.queue_name, tmp_ident)
        return q


class JSONQPIDServer(QPIDServer):
    """``QPIDServer`` wired to the JSON request and response classes."""
    response_class = hens.JSONResponse
    request_class = hens.JSONRequest


class JSONQPIDClient(QPIDClient, hens.JSONClient):
    """``QPIDClient`` combined with ``hens.JSONClient``."""


class RMQServer(minsol.Object):
    """
    RabbitMQ server

    Consumes requests from ``queue_name``, executes each one against a fresh
    ``acceptor_class()`` instance and publishes the result to the request's
    ``reply_to`` routing key, when one is given.
    """
    request_class = hens.Request
    response_class = hens.Response

    def init(self, acceptor_class, connection, queue_name, **other):
        """Store the server configuration.

        :param acceptor_class: zero-argument callable; its result is handed
            to ``request.execute`` for every request
        :param connection: connection object providing ``channel()``
        :param queue_name: name of the queue requests are consumed from
        :param other: extra keyword arguments, accepted and ignored
        """
        self.acceptor_class = acceptor_class
        self.connection = connection
        self.queue_name = queue_name

        # Created by init_broker().
        self.channel = None

    def start(self):
        """Declare the queue, register the consumer and start consuming.

        A ``KeyboardInterrupt`` during consumption stops the server.
        """
        self.init_broker()
        self.listen()

        try:
            self.channel.start_consuming()

        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        """Close the channel; a no-op if no channel was created yet."""
        if self.channel is None:
            return

        self.channel.close()

    def init_broker(self):
        """Open a channel and declare the durable request queue."""
        self.channel = self.connection.channel()

        self.channel.queue_declare(
            queue=self.queue_name,
            durable=True, auto_delete=False)

    def listen(self):
        """Register ``message_received`` as consumer with prefetch of one."""
        self.channel.basic_qos(prefetch_count=1)
        self.channel.basic_consume(
            self.message_received,
            queue=self.queue_name)

    def message_received(self, ch, method, properties, body):
        """Consumer callback: acknowledge, parse and execute one request.

        The delivery is acknowledged before the request is executed.  A
        ``hens.Error`` raised while parsing or executing is published back as
        the result (``request`` is ``None`` if parsing failed).
        """
        self.channel.basic_ack(method.delivery_tag)

        raw_request = body
        acceptor = self.acceptor_class()
        request = None

        try:
            request = self.request_class.get_from_raw(raw_request)

            callback = functools.partial(
                self.executed,
                properties.reply_to, request)

            request.execute(acceptor, callback)

        except hens.Error as error:
            self.executed(properties.reply_to, request, error)

    def executed(self, reply_to, request, result):
        """Publish ``result`` to ``reply_to``; do nothing without it."""
        if not reply_to:
            return

        response = self.response_class(request=request, result=result)

        raw_response = response.as_raw()
        self.channel.basic_publish(
            exchange='',
            routing_key=reply_to,
            body=raw_response)


class RMQClient(hens.Client):
    """Blocking RPC client that talks to a server over a RabbitMQ queue."""

    def init(self, connection, queue_name, **kwargs):
        """Open the request channel and a private response queue.

        :param connection: connection object providing ``channel()``
        :param queue_name: routing key requests are published to
        :param kwargs: extra keyword arguments, accepted and ignored
        """
        self.connection = connection
        self.queue_name = queue_name

        request_channel = self.connection.channel()
        request_channel.basic_qos(prefetch_count=1)
        self.channel = request_channel

        gateway = self.create_response_queue()
        self.response_channel, self.response_queue_name = gateway

    def create_response_queue(self):
        """Declare an exclusive auto-delete queue on a new channel.

        :returns: ``(channel, queue_name)`` where the name is read from the
            declare result
        """
        response_channel = self.connection.channel()
        response_channel.basic_qos(prefetch_count=1)

        declared = response_channel.queue_declare(
            exclusive=True, auto_delete=True)

        return response_channel, declared.method.queue

    def _fetch_call(self, raw_request):
        """Publish ``raw_request`` and return the raw response body."""
        self.send_raw_request(raw_request)
        return self.get_raw_response()

    def send_raw_request(self, raw_request):
        """Publish the request with ``reply_to`` set to the response queue."""
        reply_props = pika.BasicProperties(reply_to=self.response_queue_name)

        self.channel.basic_publish(
            exchange='',
            routing_key=self.queue_name,
            properties=reply_props,
            body=raw_request)

    def get_raw_response(self):
        """Poll the response queue with ``basic_get`` until a message comes.

        There is no timeout: the loop repeats until ``basic_get`` returns a
        truthy method frame, then the message body is returned.
        """
        delivered = None
        payload = None

        while not delivered:
            delivered, _props, payload = self.response_channel.basic_get(
                queue=self.response_queue_name, no_ack=True)

        return payload


class JSONRMQServer(RMQServer):
    """``RMQServer`` wired to the JSON request and response classes."""
    response_class = hens.JSONResponse
    request_class = hens.JSONRequest


class JSONRMQClient(RMQClient, hens.JSONClient):
    """``RMQClient`` combined with ``hens.JSONClient``."""


class RMQAsyncServer(minsol.Object):
    """
    RabbitMQ asyncRPC server

    Set-up is callback driven: connection opened -> channel opened ->
    queue declared -> consuming.  Requests are consumed with ``no_ack=True``.
    """
    request_class = hens.Request
    response_class = hens.Response

    def init(self, acceptor_class, connection, queue_name, **other):
        """Store the server configuration.

        :param acceptor_class: zero-argument callable; its result is handed
            to ``request.execute`` for every request
        :param connection: connection object supporting
            ``add_on_open_callback`` and callback-style ``channel``
        :param queue_name: name of the queue requests are consumed from
        :param other: extra keyword arguments, accepted and ignored
        """
        self.acceptor_class = acceptor_class
        self.connection = connection
        self.queue_name = queue_name

        # Assigned in channel_opened(); None until then so stop() is safe.
        self.channel = None

    def start(self):
        """Register ``connection_opened`` as the connection's open callback."""
        self.connection.add_on_open_callback(self.connection_opened)

    def stop(self):
        """Close the channel; a no-op if the channel has not opened yet."""
        if self.channel is None:
            return

        self.channel.close()

    def connection_opened(self, connection):
        """Request a channel once the connection is open."""
        self.connection.channel(self.channel_opened)

    def channel_opened(self, channel):
        """Keep the channel and declare the durable request queue."""
        self.channel = channel

        self.channel.queue_declare(
            self.queue_declared,
            queue=self.queue_name,
            durable=True, auto_delete=False)

    def queue_declared(self, frame):
        """Start consuming requests with ``request_accepted``."""
        self.channel.basic_consume(
            self.request_accepted,
            queue=self.queue_name, no_ack=True)

    def request_accepted(self, channel, frame, properties, body):
        """Consumer callback: parse and execute one request.

        A ``hens.Error`` raised while parsing or executing is published back
        as the result (``request`` is ``None`` if parsing failed).
        """
        raw_request = body
        acceptor = self.acceptor_class()
        request = None

        try:
            request = self.request_class.get_from_raw(raw_request)

            callback = functools.partial(
                self.executed,
                properties.reply_to, request)

            request.execute(acceptor, callback)

        except hens.Error as error:
            self.executed(properties.reply_to, request, error)

    def executed(self, reply_to, request, result):
        """Publish ``result`` to ``reply_to``; do nothing without it."""
        if not reply_to:
            return

        response = self.response_class(request=request, result=result)

        raw_response = response.as_raw()
        self.channel.basic_publish(
            exchange='',
            routing_key=reply_to,
            body=raw_response)


class RMQAsyncClient(hens.AsyncClient):
    """Callback-based RPC client over RabbitMQ with pooled response queues.

    Every in-flight call borrows one exclusive response queue from a pool of
    fixed size; calls made while the pool is empty wait until a queue is
    returned.
    """

    def init(self, connection, queue_name,
            response_queue_pool_size=10,
            **kwargs):
        """
        :param connection: not yet connected ``TornadoConnection`` object
        :type connection: pika.adapters.TornadoConnection
        :param queue_name: listen queue name
        :type queue_name: str
        :param response_queue_pool_size: number of response queues to create
        :type response_queue_pool_size: int
        """
        self.connection = connection
        self.queue_name = queue_name
        self.response_queue_pool_size = response_queue_pool_size
        self.response_queues = {}        # free:  queue name -> channel
        self.response_queues_busy = {}   # taken: queue name -> channel
        self.response_callbacks = {}     # queue name -> pending response handler
        self.wait_callbacks = []         # callers waiting for a free queue
        # Request channel, assigned in channel_opened().
        self.channel = None

        self.connection.add_on_open_callback(self.connection_opened)

    def connection_opened(self, connection):
        """Request the publishing channel once the connection is open."""
        self.connection.channel(self.channel_opened)

    def channel_opened(self, channel):
        """Keep the publishing channel and start building the queue pool."""
        self.channel = channel

        self.init_response_queue_pool(self.response_queue_pool_inited)

    def init_response_queue_pool(self, callback):
        """Create ``response_queue_pool_size`` response gateways.

        ``callback`` is invoked with no arguments once every gateway has been
        created, regardless of the order in which they complete.  For a
        non-positive pool size nothing is created and ``callback`` is invoked
        immediately.
        """
        size = self.response_queue_pool_size

        if size <= 0:
            callback()
            return

        # One-element list used as a mutable counter shared with the closure.
        pending = [size]

        def _gateway_created(rchannel, rqueue_name):
            self.consume_response_gateway(rchannel, rqueue_name)

            # Register as busy first: return_response_queue() either hands
            # the queue to a waiting caller or moves it to the free pool.
            self.response_queues_busy[rqueue_name] = rchannel
            self.return_response_queue(rqueue_name)

            pending[0] -= 1
            if not pending[0]:
                callback()

        for _ in range(size):
            self.create_response_gateway(_gateway_created)

    def create_response_gateway(self, callback):
        """Open a channel and declare an exclusive auto-delete queue on it.

        ``callback`` is called as ``callback(channel, queue_name)`` once the
        queue has been declared.
        """
        def _channel_opened(channel):
            def _declared(result):
                callback(channel, result.method.queue)

            channel.queue_declare(
                exclusive=True, auto_delete=True,
                callback=_declared)

        self.connection.channel(_channel_opened)

    def consume_response_gateway(self, rchannel, rqueue_name):
        """Consume ``rqueue_name`` and route each message to its handler.

        The handler registered in ``response_callbacks`` for the queue is
        removed and called; a message without a registered handler raises
        ``KeyError``.
        """
        def _response_getted(channel, frame, props, body):
            callback = self.response_callbacks.pop(rqueue_name)
            callback(channel, frame, props, body)

        rchannel.basic_consume(
            _response_getted,
            queue=rqueue_name, no_ack=True)

    def response_queue_pool_inited(self):
        """Hook called when the response queue pool is ready; does nothing."""
        pass

    def _fetch_call(self, callback, raw_request):
        """Send ``raw_request`` and pass the response body to ``callback``.

        The request is published as soon as a response queue is available.
        """
        def _response_getted(channel, frame, props, body):
            try:
                callback(body)
            finally:
                # Always give the queue back, otherwise a failing callback
                # would shrink the pool permanently.
                self.return_response_queue(frame.routing_key)

        def _response_queue_getted(rqueue_name):
            self.response_callbacks[rqueue_name] = _response_getted

            self.channel.basic_publish(
                exchange='',
                routing_key=self.queue_name,
                properties=pika.BasicProperties(reply_to=rqueue_name),
                body=raw_request)

        self.get_response_queue(_response_queue_getted)

    def get_response_queue(self, callback):
        """Call ``callback(queue_name)`` with a free response queue.

        When no queue is free the callback is queued and invoked later by
        ``return_response_queue``.
        """
        try:
            rqueue_name, rchannel = self.response_queues.popitem()

        except KeyError:
            # Pool exhausted: park the caller until a queue is returned.
            self.wait_callbacks.append(callback)
            return

        # The callback runs outside the try block so that a KeyError raised
        # by the callback itself is not mistaken for an empty pool.
        self.response_queues_busy[rqueue_name] = rchannel
        callback(rqueue_name)

    def return_response_queue(self, rqueue_name):
        """Release ``rqueue_name``: hand it to a waiter or mark it free.

        Waiters are served oldest first.  When a waiter takes the queue it
        stays registered as busy.
        """
        if self.wait_callbacks:
            waiter = self.wait_callbacks.pop(0)
            waiter(rqueue_name)
            return

        rchannel = self.response_queues_busy.pop(rqueue_name)
        self.response_queues[rqueue_name] = rchannel


class JSONRMQAsyncServer(RMQAsyncServer):
    """``RMQAsyncServer`` wired to the JSON request and response classes."""
    response_class = hens.JSONResponse
    request_class = hens.JSONRequest


class JSONRMQAsyncClient(RMQAsyncClient, hens.JSONClient):
    """``RMQAsyncClient`` combined with ``hens.JSONClient``."""
