#!/usr/bin/python -u
#
# zmqdump - dump zmq messages on a socket
#
# @author: Heinrich Hartmann <derhein@gmail.com>
#
# https://github.com/HeinrichHartmann/zmqdump
#
# License: GPLv2
#
# REMARK:
# It is important to disable stdin/stdout buffers when using 
# this script in pipes to get the expected results.
# Therefore python is called with the -u option. 
# See: http://stackoverflow.com/questions/9756137/immediate-piping-with-python
#

import argparse
import time
import sys
import zmq

# Release metadata; VERSION is what the --version flag reports.
VERSION = "0.1"
DATE = "2014-02-09"

# Debug level: 0 keeps the script quiet; main() raises it to 10 when
# --verbose is given, which enables the diagnostic prints.
DEBUG = 0

def main():
    """
    zmqdump command line script

    Parses the command line, raises the module-level DEBUG level when
    --verbose is given, then runs the socket loop that matches the
    requested socket type.  The socket and context are released even
    when the loop terminates with an exception.
    """

    global DEBUG

    parser = setup_parser()
    conf = parser.parse_args()

    if conf.verbose:
        DEBUG = 10

    zdp = Zmqdump(conf)
    try:
        zdp.loop()
    finally:
        # Always release the socket and context, also on errors in loop().
        zdp.destroy()

class Zmqdump:
    """
    A single zmq socket wired to stdin/stdout.

    Depending on the socket type, lines read from stdin are sent as
    messages (PUSH, PUB), received messages are written to stdout
    (PULL, SUB), or both are alternated (REQ, REP).
    """

    # Socket types accepted by setup_socket().  Each name is also the name
    # of the matching socket-type constant in the zmq module (zmq.PUB, ...).
    SOCKET_TYPES = ("PUB", "SUB", "PUSH", "PULL", "REQ", "REP")

    conf = None         # parsed command line options, set by __init__
    context = None      # zmq context, set by setup_socket()
    socket = None       # zmq socket, set by setup_socket()
    socket_type = None  # one of SOCKET_TYPES, set by setup_socket()

    def __init__(self, conf):
        """
        Constructor

        Configuration options are translated into socket
        configuration.

        conf is expected to provide the attributes socket_type, hwm,
        pattern, bind, endpoint and delay (as produced by the argument
        parser of this script).
        """
        self.conf = conf

        self.setup_socket(conf.socket_type)
        self.set_hwm(conf.hwm)
        self.set_pattern(conf.pattern)

        method = "bind" if conf.bind else "connect"
        self.init_connection(method, conf.endpoint)

        # conf.delay is given in milliseconds, time.sleep() takes seconds.
        time.sleep(conf.delay * 0.001)

    @staticmethod
    def _debug(message):
        """
        Print message if the module-level DEBUG level is non-zero.
        """
        if DEBUG:
            print(message)

    @staticmethod
    def _strip_newline(line):
        """
        Return line without its trailing newline character, if any.

        Only a single trailing "\\n" is removed.  A line that does not
        end in a newline (e.g. the last line of input before EOF) is
        returned unchanged instead of losing its last character.
        """
        if line.endswith("\n"):
            return line[:-1]
        return line

    def _read_message(self):
        """
        Read the next line from stdin.

        Returns the line without its trailing newline, or None when
        stdin is exhausted.  An empty input line yields "" (not None).
        """
        line = sys.stdin.readline()
        if not line:
            return None
        return self._strip_newline(line)

    def setup_socket(self, socket_type):
        """
        Create context and socket objects.

        Raises Exception if socket_type is not one of SOCKET_TYPES.
        """
        self._debug("Setting up %s-socket" % (socket_type))

        # Validate before creating the context so that nothing is left
        # allocated when the type is rejected.
        if socket_type not in self.SOCKET_TYPES:
            raise Exception("socket type not supported: %s" % (socket_type,))

        self.context = zmq.Context()
        self.socket_type = socket_type
        self.socket = self.context.socket(getattr(zmq, socket_type))

    def set_hwm(self, hwm):
        """
        Set the receive and send high water mark of the socket to hwm.
        """
        assert(type(hwm) is int)
        self._debug("Setting hwm to " + str(hwm))

        self.socket.setsockopt(zmq.RCVHWM, hwm)
        self.socket.setsockopt(zmq.SNDHWM, hwm)

    def set_pattern(self, pattern):
        """
        Subscribe to pattern.  Has no effect unless the socket is a SUB
        socket.
        """
        self._debug("Setting subscription pattern '%s'" % pattern)
        if self.socket_type == "SUB":
            self.socket.setsockopt(zmq.SUBSCRIBE, pattern)

    def init_connection(self, method, endpoint):
        """
        Bind/Connect socket

        method is either "bind" or "connect"; endpoint is a string
        containing "://", e.g. 'tcp://127.0.0.1:5000'.
        """
        assert(method in ["bind", "connect"])
        assert(type(endpoint) is str and "://" in endpoint)
        self._debug("%s socket to %s" % (method, endpoint))

        if method == "bind":
            self.socket.bind(endpoint)
        elif method == "connect":
            self.socket.connect(endpoint)

    def loop(self):
        """
        Start listening/writing to socket.

        Dispatches to the loop that matches the socket type; does
        nothing if the socket type is not a known one.
        """
        handlers = {
            "PUSH": self.send_loop,
            "PUB": self.send_loop,
            "PULL": self.print_loop,
            "SUB": self.print_loop,
            "REQ": self.req_loop,
            "REP": self.rep_loop,
        }
        handler = handlers.get(self.socket_type)
        if handler is not None:
            handler()

    def req_loop(self):
        """
        Send each stdin line as a request and write the reply to stdout.

        Stops at end of input or on KeyboardInterrupt.
        """
        self._debug("Req loop.")
        while True:
            try:
                message = self._read_message()
                if message is None:
                    break
                self.socket.send(message)

                sys.stdout.write(self.socket.recv_string())
            except KeyboardInterrupt:
                break

    def rep_loop(self):
        """
        Write each received request to stdout and answer it with the
        next stdin line.

        Stops at end of input or on KeyboardInterrupt.
        """
        self._debug("Rep loop.")
        while True:
            try:
                sys.stdout.write(self.socket.recv_string())
                message = self._read_message()
                if message is None:
                    break
                self.socket.send(message)
            except KeyboardInterrupt:
                break

    def print_loop(self):
        """
        Write every received message to stdout until KeyboardInterrupt.
        """
        self._debug("Listening on socket.")
        while True:
            try:
                # NOTE(review): messages are written without a separator,
                # so consecutive messages run together -- confirm whether
                # a trailing newline is intended here.
                sys.stdout.write(self.socket.recv_string())
            except KeyboardInterrupt:
                sys.stdout.flush()
                break

    def send_loop(self):
        """
        Send every stdin line as one message.

        Stops at end of input or on KeyboardInterrupt.
        """
        self._debug("Writing to socket.")
        while True:
            try:
                message = self._read_message()
                if message is None:
                    break
                self.socket.send(message)
            except KeyboardInterrupt:
                break

    def destroy(self):
        """
        Close the socket, then terminate and destroy the context.
        """
        self.socket.close()
        self.context.term()
        self.context.destroy()

        
def setup_parser():
    """
    Build the command line parser of zmqdump.

    Positional arguments are the socket type and the endpoint; the
    options configure the initial delay, the high water mark, the
    subscription pattern, bind vs. connect and verbose output.

    Returns the configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        prog="zmqdump",
        description="dump zmq messages on a socket"
    )

    parser.add_argument(
        "socket_type",
        help="type of zmq socket.",
        type=str,
        choices=["SUB", "PUB", "PUSH", "PULL", "REQ", "REP"]
    )

    parser.add_argument(
        "endpoint",
        help="endpoint to listen on messages, e.g. 'tcp://127.0.0.1:5000'",
        type=str
    )

    # The delay is multiplied by 0.001 before being passed to time.sleep(),
    # i.e. it is given in milliseconds.
    parser.add_argument(
        "-d", "--delay",
        help="initial delay in milliseconds before sending out messages",
        dest="delay", type=int, default=100
    )

    parser.add_argument(
        "-hwm",
        help="High water mark.",
        dest="hwm", type=int, default=1000
    )

    parser.add_argument(
        "-s", "--subscribe",
        help="subscription pattern. Only relevant for SUB sockets.",
        dest="pattern", type=str, default=""
    )

    parser.add_argument(
        "-b", "--bind",
        help="bind socket instead of connect",
        dest="bind", default=False,
        action="store_true"
    )

    parser.add_argument(
        "-v", "--verbose",
        help="print additional logging information",
        dest="verbose", default=False,
        action="store_true"
    )

    parser.add_argument(
        "--version",
        action='version',
        version='%(prog)s v.' + VERSION
    )

    return parser
   
# Run the command line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()


