'''
    Copyright (c) Supamonks Studio and individual contributors.
    All rights reserved.

    This file is part of kabaret, a python Digital Creation Framework.

    Kabaret is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Redistributions of source code must retain the above copyright notice, 
    this list of conditions and the following disclaimer.
        
    Redistributions in binary form must reproduce the above copyright 
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
    
    Kabaret is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.
    
    You should have received a copy of the GNU Lesser General Public License
    along with kabaret.  If not, see <http://www.gnu.org/licenses/>

--

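    The kabaret.core.mq.async package:
        Defines the AsyncService and AsyncClient.
        The service answers remote procedure calls received on a zmq
        ROUTER socket and publishes messages on a PUB socket; the client
        sends sync, one-way or async (Future based) calls on a DEALER
        socket and can subscribe to the service publications.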
        
'''
import logging
import sys
import traceback
import uuid
import time
try:
    import cPickle as pickle
except ImportError:
    import pickle

import zmq
from zmq.eventloop.zmqstream import ZMQStream
from zmq.eventloop.ioloop import IOLoop, DelayedCallback
from zmq.utils import jsonapi

from kabaret.core.utils.callbacks import weak_ref

class AsyncError(Exception):
    '''Base class for the errors raised by this module.'''
    pass

class RemoteError(AsyncError):
    '''
    Raised on the client side when the service reports that a sync
    rpc call failed.

    ename: name of the exception class raised on the service side.
    evalue: string value of that exception.
    tb: formatted traceback sent by the service (may be empty).
    '''
    def __init__(self, ename, evalue, tb):
        # Give the values to the base class too: this fills self.args,
        # which pickle/copy rely on to rebuild the exception
        # (they call RemoteError(*args)).
        super(RemoteError, self).__init__(ename, evalue, tb)
        self.ename = ename
        self.evalue = evalue
        self.tb = tb

    def __str__(self):
        # Prefer the remote traceback, it is the most useful message.
        return self.tb or "%s(%s)" % (self.ename, self.evalue)

class FutureError(AsyncError):
    '''Raised when a Future callback is invalid or set too late.'''
    pass

class Serializer(object):
    """
    Converts python objects to pickled strings and back.

    An rpc request carries its arguments as two frames (args, then
    kwargs) and a reply carries its result as a single frame.

    WARNING: loads() unpickles whatever it receives, and unpickling
    can run arbitrary code: only exchange data with trusted peers.
    """

    def loads(self, s):
        """Return the object pickled in `s`."""
        return pickle.loads(s)

    def dumps(self, o):
        """Return the pickled representation of `o`."""
        return pickle.dumps(o)

    def serialize_args_kwargs(self, args, kwargs):
        """Return the (args_frame, kwargs_frame) pair of a request."""
        args_frame = self.dumps(args)
        kwargs_frame = self.dumps(kwargs)
        return args_frame, kwargs_frame

    def deserialize_args_kwargs(self, msg_list):
        """Return (args, kwargs) decoded from the first two frames."""
        args = self.loads(msg_list[0])
        kwargs = self.loads(msg_list[1])
        return args, kwargs

    def serialize_result(self, result):
        """Return the single-frame list holding the pickled result."""
        frame = self.dumps(result)
        return [frame]

    def deserialize_result(self, msg_list):
        """Return the result decoded from the first frame."""
        frame = msg_list[0]
        return self.loads(frame)



class AsyncBase(object):
    '''
    Shared state and message constants of AsyncService and AsyncClient.

    Subclasses implement _create_socket() (which must set rpc_socket,
    rpc_stream, ps_socket and optionally ps_stream) and _receive_rpc().
    '''
    class MSG_TYPE:
        # First frame of a request, telling how the reply is handled:
        SYNC = b'|'      # the caller blocks until the reply arrives
        ONEWAY = b'>'    # the service sends no reply
        ASYNC = b'<>'    # the reply is resolved into a Future

    class REPLY_TYPE:
        # First frame of every reply sent by the service.
        REPLY = b'<'

    class MSG_STATUS:
        SUCCESS = b'SUCCESS'
        FAIL = b'FAIL'

    def __init__(self, context=None):
        """Base class for async client and server.

        context : zmq.Context instance, if None zmq.Context.instance()
        is used.

        The sockets are created at the end of the constructor (reset()
        calls _create_socket()).
        """
        super(AsyncBase, self).__init__()

        self.context = context if context is not None else zmq.Context.instance()
        self.protocol = None
        self.address = None
        self.rpc_port = None
        self.rpc_socket = None
        self.rpc_stream = None
        self.ps_port = None
        self.ps_socket = None
        self.ps_stream = None
        self._serializer = Serializer()
        self.reset()

    def _create_socket(self):
        raise NotImplementedError

    def reset(self):
        '''
        Closes the current streams and sockets (if any), creates new
        ones with _create_socket() and forgets the bound/connected
        address and ports.
        '''
        # Close the streams first: closing a ZMQStream also removes its
        # socket from the io loop, while closing only the socket would
        # leave a dead socket registered there.
        for stream in (self.rpc_stream, self.ps_stream):
            if stream is not None:
                stream.close()
        # Closing an already closed zmq socket is harmless.
        for sock in (self.rpc_socket, self.ps_socket):
            if sock is not None:
                sock.close()
        self.rpc_stream = None
        self.ps_stream = None
        self.rpc_socket = None
        self.ps_socket = None

        self._create_socket()
        self.address = None
        self.rpc_port = None
        self.ps_port = None

    def _receive_rpc(self, msg_list):
        raise NotImplementedError
    
class AsyncService(AsyncBase):
    '''
    Server side of the rpc.

    Requests from AsyncClient instances are received on a zmq ROUTER
    socket (driven by self.loop) and dispatched to the callable returned
    by _get_rpc_handler(op), which subclasses must implement.
    Messages are published to subscribed clients with publish().
    '''
    def __init__(self, context=None, loop=None):
        # setup loop before calling the base since
        # we need it in _create_socket (called by the
        # base)
        self.loop = loop if loop is not None else IOLoop.instance()
        super(AsyncService, self).__init__(context)

    def _create_socket(self):
        '''Creates the ROUTER rpc socket and stream, and the PUB socket.'''
        self.rpc_socket = self.context.socket(zmq.ROUTER)
        self.rpc_stream = ZMQStream(self.rpc_socket, self.loop)
        self.rpc_stream.on_recv(self._receive_rpc)

        self.ps_socket = self.context.socket(zmq.PUB)

    def bind(self, address, port, protocol='tcp'):
        '''
        Binds the rpc socket to this address and port, and binds
        the pub/sub socket to a random port on this address.
        '''
        self.protocol = protocol
        self.address = address
        self.rpc_port = port

        self.rpc_socket.bind('%s://%s:%i'%(protocol, address, port,))
        self.ps_port = self.ps_socket.bind_to_random_port('%s://%s'%(protocol, address))

    def bind_to_random(self, address, protocol='tcp'):
        '''
        Binds both the rpc and the pub/sub sockets to random ports
        on this address.
        '''
        self.protocol = protocol
        self.address = address

        self.rpc_port = self.rpc_socket.bind_to_random_port('%s://%s'%(protocol, address,))
        self.ps_port = self.ps_socket.bind_to_random_port('%s://%s'%(protocol, address))

    def _build_rpc_reply(self, dealer_id, msg_id, msg_status, frames):
        '''
        Build a rpc reply message for status and data.
        '''
        return [
            dealer_id, self.REPLY_TYPE.REPLY,
            msg_id, msg_status
        ] + frames

    def _get_rpc_handler(self, op):
        '''
        Returns the callable handling the rpc operation `op`, or None
        if it is unknown. Must be implemented by subclasses.
        '''
        raise NotImplementedError

    def _receive_rpc(self, msg_list):
        '''
        Handles a rpc call from a client.

        msg_list frames: dealer_id, msg_type, msg_id, op, then the
        serialized args and kwargs.

        Unless msg_type is ONEWAY, exactly one reply is sent: SUCCESS
        with the serialized result, or FAIL describing the error.
        ONEWAY calls never get a reply (a stray reply would be read by
        the client as the answer to its next sync call); their errors
        are logged instead.
        '''
        logging.debug('_receive_rpc: %r', msg_list)
        # We don't support pipelined calls so the
        # dealer id is always a single value:
        try:
            dealer_id, msg_type, msg_id, op = msg_list[:4]
        except ValueError:
            raise AsyncError('Unable to decode rpc message msg_list: %r'%(msg_list,))
        data = msg_list[4:]
        reply_expected = msg_type != self.MSG_TYPE.ONEWAY

        try:
            args, kwargs = self._serializer.deserialize_args_kwargs(data)
        except Exception:
            logging.error(
                'Unable to decode rpc arguments for %r:\n%s', op, traceback.format_exc()
            )
            # Reply anyway, a sync client would otherwise wait forever.
            if reply_expected:
                self._send_error(dealer_id, msg_id)
            return

        # Find the op handler. A failing lookup means "unknown op".
        try:
            handler = self._get_rpc_handler(op)
        except Exception:
            logging.debug('rpc handler lookup failed for %r', op, exc_info=True)
            handler = None
        if handler is None:
            logging.error('Unknown RPC operation: %s'%(op,))
            if reply_expected:
                try:
                    # raised only to fill sys.exc_info() for _send_error
                    raise AttributeError('No such operation %r'%(op,))
                except AttributeError:
                    self._send_error(dealer_id, msg_id)
            return

        try:
            result = handler(*args, **kwargs)
        except Exception:
            if reply_expected:
                self._send_error(dealer_id, msg_id)
            else:
                logging.error(
                    'One way rpc %r failed:\n%s', op, traceback.format_exc()
                )
            return

        if not reply_expected:
            return

        try:
            data_list = self._serializer.serialize_result(result)
        except Exception:
            self._send_error(dealer_id, msg_id)
            return

        reply = self._build_rpc_reply(
            dealer_id, msg_id,
            self.MSG_STATUS.SUCCESS, data_list
        )
        logging.debug('reply: %r', reply)
        self.rpc_stream.send_multipart(reply)

    def _send_error(self, dealer_id, msg_id):
        """
        Sends a FAIL reply describing the exception currently being
        handled. Must be called from an except block.
        """
        etype, evalue = sys.exc_info()[:2]
        error_dict = {
            'ename' : str(etype.__name__),
            'evalue' : str(evalue),
            # format_exc() formats the exception being handled; its
            # only argument is a depth limit, not a traceback.
            'tb' : traceback.format_exc(),
        }
        data_list = [jsonapi.dumps(error_dict)]
        reply = self._build_rpc_reply(
            dealer_id, msg_id,
            self.MSG_STATUS.FAIL, data_list
        )
        self.rpc_stream.send_multipart(reply)

    def _get_ps_port(self):
        '''
        Returns the pub/sub port. AsyncClient.subscribe() asks for it
        with a sync call to '_get_ps_port' (this presumably requires
        _get_rpc_handler to expose it - confirm in subclasses).
        '''
        return self.ps_port

    def publish(self, topic, message):
        '''
        Publishes the serialized message under this topic to the
        subscribed clients.
        '''
        if not self.ps_port:
            raise RuntimeError('bind must be called first')
        data = self._serializer.dumps(message)
        self.ps_socket.send_multipart([topic, data])

    def start(self):
        """Start the event loop for this RPC service."""
        logging.info('Server Starting')
        self.loop.start()

class Future(object):
    '''
    Holds the pending result of an async rpc call.

    AsyncClient creates one per async() call. It fills the future with
    _set_result() or _set_err() when the reply arrives, and polls
    _check() in tick() to drop accomplished futures and detect
    timeouts.

    timeout is in seconds, counted from creation (or from the last
    reset_time()); 0 means the future never times out.
    '''
    def __init__(self, op, args, kwargs):
        self.op = op
        self.args = args
        self.kwargs = kwargs
        self.timeout = 0

        self.result = None
        self.done = False
        self.timedout = False
        self.failed = False
        self.err = None  # (ename, evalue, tb) once failed

        self._result_cb = None
        self._err_cb = None
        self._to_cb = None

        self._cretime = time.time()

    @staticmethod
    def _call_cb(cb, *args):
        '''
        Calls cb(*args) if cb is set. Errors raised by the callback are
        logged, not propagated, so a faulty callback cannot break the
        code that resolves or polls the future.
        '''
        if cb is None:
            return
        try:
            cb(*args)
        except Exception:
            logging.error(traceback.format_exc())

    def _check(self):
        '''
        Returns True if this future is accomplished (done, failed or
        timed out), False if it is still to come.

        When the timeout has expired, the future is flagged as timed
        out and the timeout callback (if any) is called.
        '''
        if self.done or self.timedout or self.failed:
            return True
        if not self.timeout:
            return False
        if time.time() - self._cretime > self.timeout:
            self.timedout = True
            self._call_cb(self._to_cb)
        return self.timedout or self.done

    def _set_result(self, result):
        self.result = result
        self.done = True
        self._call_cb(self._result_cb, result)

    def _set_err(self, ename, evalue, tb):
        self.failed = True
        self.err = (ename, evalue, tb)
        self._call_cb(self._err_cb, ename, evalue, tb)

    def _assert_cb(self, cb):
        if self.done:
            raise FutureError('Too late to set a callback')
        if not callable(cb):
            raise FutureError('A callback must be callable')

    def set_result_callback(self, cb):
        '''Sets (or clears with None) the callback called with the result.'''
        if cb is None:
            self._result_cb = None
            return
        self._assert_cb(cb)
        self._result_cb = weak_ref(cb)

    def set_error_callback(self, cb):
        '''Sets (or clears with None) the callback called with (ename, evalue, tb).'''
        if cb is None:
            self._err_cb = None
            return
        self._assert_cb(cb)
        self._err_cb = weak_ref(cb)

    def set_timeout_callback(self, cb):
        '''Sets (or clears with None) the callback called on timeout.'''
        if cb is None:
            self._to_cb = None
            return
        self._assert_cb(cb)
        self._to_cb = weak_ref(cb)

    def reset_time(self):
        '''Restarts the timeout countdown from now.'''
        self._cretime = time.time()
        
class AsyncClient(AsyncBase):
    def __init__(self, context=None):
        super(AsyncClient, self).__init__(
            context=context
        )        
        self.default_result_callback = None
        self.default_err_callback = None
        self.default_timeout_callback = None
        self.default_future_timeout = 10
        
        self._futures = {}

    def _create_socket(self):
        self.rpc_socket = self.context.socket(zmq.DEALER)
        self.rpc_socket.setsockopt(zmq.IDENTITY, bytes(uuid.uuid4()))

        self.rpc_stream = ZMQStream(self.rpc_socket, None)
        self.rpc_stream.on_recv(self._receive_rpc_future)
        self.rpc_stream.stop_on_send()

        self.ps_socket = self.context.socket(zmq.SUB)

        self.ps_stream = ZMQStream(self.ps_socket, None)
        self.ps_stream.on_recv(self._receive_pub)
        self.ps_stream.stop_on_send()

    def connect(self, address, port, protocol='tcp'):
        '''
        Connects to the rpc on address:port
        '''
        self.protocol = protocol
        self.address = address
        self.rpc_port = port
        self.rpc_socket.connect('%s://%s:%i'%(protocol, address, port,))

    def subscribe(self, topic=""):
        if self.address is None:
            raise RuntimeError('connect must be called first')
        
        if self.ps_port is None:
            self.ps_port = self.sync('_get_ps_port')
            self.ps_socket.connect('%s://%s:%s'%(self.protocol,self.address,self.ps_port))
        
        self.ps_socket.setsockopt(zmq.SUBSCRIBE, topic)
    
    def unsubscribe(self, topic=""):
        if None in (self.address, self.ps_port):
            raise RuntimeError('connect and subscribe must be called first')
        
        self.ps_socket.setsockopt(zmq.UNSUBSCRIBE, topic)

    def _build_request(self, method, msg_type, args, kwargs):
        msg_id = bytes(uuid.uuid4())
        method = bytes(method)
        msg_list = [msg_type, msg_id, method]
        data_list = self._serializer.serialize_args_kwargs(args, kwargs)
        msg_list.extend(data_list)
        return msg_id, msg_list

    def _receive_rpc(self, msg_list):
        '''
        Handle a rpc reply from the server.
        '''
        if msg_list[0] != self.REPLY_TYPE.REPLY:
            logging.error('Unexpected message received from server: %r'%(msg_list,))
            return
        
        try:
            #msg_id, msg_status, *data = msg_list[1:]
            msg_id, msg_status = msg_list[1:3]
            data = msg_list[3:]
        except ValueError:
            raise AsyncError('Unable to decode rpc reply msg_list: %r'%(msg_list,))
        
        return msg_id, msg_status, data
    
    def _receive_rpc_future(self, msg_list):
        '''
        Handle a async rpc reply from the server.
        '''
        msg_id, msg_status, data = self._receive_rpc(msg_list)
        
        future = self._futures.pop(msg_id, None)
        if future is None:
            # silently pass?
            logging.warn('Did not find the future for this rpc reply: %r'%(msg_list,))
            return
        
        if msg_status == self.MSG_STATUS.SUCCESS:
            result = self._serializer.deserialize_result(data)
            future._set_result(result)
        else:
            future._set_err(**jsonapi.loads(data[0]))
            
    def _send(self, op, msg_type, args, kwargs):
        '''
        Sends a rpc order on the connected server.

        op: the name of the command to send
        msg_type: one of AsyncBase.MSG_TYPE
        args, kwargs: the arguments and keyword arguments
        of the order.
        '''
        if not self.address:
            raise RuntimeError('connect must be called first')

        # if async, prepare the future before sending
        # the order.
        # (or the reply may arrive before the future is
        # initialized and registred
        future = None
        if msg_type == self.MSG_TYPE.ASYNC:
            future = Future(op, args, kwargs)
            future.timeout = self.default_future_timeout
            future.set_result_callback(self.default_result_callback)

        # send the order message
        send_id, msg_list = self._build_request(op, msg_type, args, kwargs)
        if future is not None:
            self._futures[send_id] = future
        self.rpc_socket.send_multipart(msg_list)

        # if sync, get the response:
        if msg_type == self.MSG_TYPE.SYNC:
            msg_list = self.rpc_socket.recv_multipart()
            msg_id, msg_status, data = self._receive_rpc(msg_list)
            if send_id != msg_id:
                raise AsyncError(
                    'Rpc reply does not match request?!? (%r != %r):\n%r'%(
                        send_id, msg_id, msg_list
                    )
                )
            if msg_status == self.MSG_STATUS.FAIL:
                err_descr = jsonapi.loads(data[0])
                raise RemoteError(**err_descr)
            result = self._serializer.deserialize_result(data)
            return result

        # if one way, nothing more to do:
        if msg_type == self.MSG_TYPE.ONEWAY:
            return None
        
        # Finalize the future and return it
        future.reset_time()
        return future
    
    def one_way(self, op, *args, **kwargs):
        return self._send(op, self.MSG_TYPE.ONEWAY, args, kwargs)
        
    def sync(self, op, *args, **kwargs):
        return self._send(op, self.MSG_TYPE.SYNC, args, kwargs)
        
    def async(self, op, *args, **kwargs):
        return self._send(op, self.MSG_TYPE.ASYNC, args, kwargs)

    def _receive_pub(self, msg_list):
        print 'RECEIVE PUB', msg_list
        
    def tick(self):
        '''
        Updates the rpc stream, dispatches results to 
        futures, updates the subscription stream, and
        returns the number of futures left to solve.
        
        '''
        
        self.rpc_stream.flush(flag=zmq.POLLIN, limit=100)
        check = Future._check # bring to local scope for speed
        self._futures = dict([
            (fid, f) for fid, f in self._futures.items()
            if not check(f)
        ])
        self.ps_stream.flush(flag=zmq.POLLIN, limit=100)
        return len(self._futures)
        

 
if __name__ == '__main__':
    args = sys.argv[1:]
    port = 1001
    if args[0] == 's':
        service = AsyncService()
        service.bind('*', port)
        service.start()
    else:
        client = AsyncClient()
        client.connect('192.168.1.108', port)
        alive = True
        context = {'c':client}
        result = None
        while alive:
            i = raw_input('Client>')
            if not i.strip():
                print 'update' 
                client.tick()
                continue
            if i.strip() == 'q':
                alive = False
            else:
                try:
                    result = eval(i, context)
                    print ' ', i
                    context['_'] = result
                except SyntaxError:
                    exec(i, context)
                except:
                    traceback.print_exc()
                else:
                    if result is not None:
                        print '>', result
                        
            