#!/usr/bin/env python

from rootpy.extern.argparse import ArgumentParser

# Command-line interface: the only option is the TCP port on which the
# manager server listens (40000 unless overridden with -p/--port).
parser = ArgumentParser()
parser.add_argument(
    '-p', "--port",
    dest="port",
    type=int,
    default=40000,
    help="port that I will listening on",
)
args = parser.parse_args()

from multiprocessing.managers import BaseManager
from multiprocessing import Process, Queue, Manager
import time
import sys
import traceback
import logging
import socket

# Console logging for the "Server" logger: every record from DEBUG upwards is
# written through a stream handler using a timestamped format.
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)-8s %(message)s')

handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
handler.setFormatter(formatter)

logger = logging.getLogger("Server")
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

logger.info("initializing...")

def broadcast():
    """Announce this server over UDP multicast, once per second, forever.

    Each datagram carries the text ``"<hostname>:<port>"`` where the host name
    comes from ``socket.gethostname()`` and the port is the ``--port``
    command-line value. Datagrams are sent to group 224.1.1.1, port 5007.
    This function never returns.
    """
    destination = ('224.1.1.1', 5007)
    local_name = socket.gethostname()

    announcer = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    # Outgoing multicast datagrams are sent with a TTL of 2.
    announcer.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2)

    while True:
        announcement = "%s:%s"% (local_name, args.port)
        announcer.sendto(announcement, destination)
        time.sleep(1)

class WorkerManager(BaseManager):
    """Manager used to connect to a worker and obtain its queue.

    ``get_queue`` is registered without a callable, so this class can only be
    used to connect to a manager that is served elsewhere.
    """

WorkerManager.register('get_queue')

class Server(Process):
    """Background process that registers workers and forwards jobs to them.

    Once per second the process:

    1. drains ``connect_queue``, whose items are ``(host, port)`` tuples.
       A ``port`` of ``None`` means the worker on ``host`` has terminated;
       otherwise the server connects to the worker's manager at that address
       and replies ``True`` (accepted) or ``False`` (refused, because a
       worker on that host is already connected) on the worker's queue.
    2. drains ``incoming_queue``. Each job is paired with a host name read
       from ``request_queue`` and forwarded to that host's worker queue.
       A job of ``None`` is the poison pill that stops the server.

    On exit every connected worker is sent ``None``.

    Parameters:
        connect_queue: queue of ``(host, port)`` connection requests.
        incoming_queue: queue of jobs (``None`` terminates the server).
        request_queue: queue of host names, one per job in ``incoming_queue``.
        workers: shared list kept in sync with the connected host names.
    """

    def __init__(self, connect_queue, incoming_queue, request_queue, workers):

        self.connect_queue = connect_queue
        self.incoming_queue = incoming_queue
        self.request_queue = request_queue
        # host name -> (worker queue proxy, WorkerManager that owns it)
        self.outgoing_queues = {}
        self.workers = workers
        super(Server, self).__init__()

    def run(self):
        """Poll the queues until the poison pill arrives, then notify workers."""
        # Do not let this process block on the queues' feeder threads at exit.
        self.connect_queue.cancel_join_thread()
        self.incoming_queue.cancel_join_thread()
        self.request_queue.cancel_join_thread()
        try:
            while True:
                self._process_connection_requests()
                if self._dispatch_jobs():
                    break
                time.sleep(1)
        except (KeyboardInterrupt, SystemExit):
            pass
        except Exception:
            # Report through the server's logger (with traceback) and fall
            # through to the orderly shutdown below.
            logger.exception("server stopped because of an unexpected error")
        logger.info("server is now terminating...")
        self._notify_workers()
        self.connect_queue.close()
        self.incoming_queue.close()
        self.request_queue.close()

    def _process_connection_requests(self):
        """Drain ``connect_queue``, registering new workers and dropping dead ones."""
        while not self.connect_queue.empty():
            logger.info("checking for connection requests")
            host, port = self.connect_queue.get()
            # is this worker telling me that it has been terminated?
            if port is None:
                if host not in self.outgoing_queues:
                    logger.warning("received connection request from host %s but did not receive the port"% host)
                    logger.warning("refusing connection from %s"% host)
                    continue
                logger.info("disconnecting from worker at host %s which has terminated"% host)
                self.workers.remove(host)
                del self.outgoing_queues[host]
                continue
            try:
                manager = WorkerManager(address=(host, port), authkey='abracadabra')
                manager.connect()
                worker_queue = manager.get_queue()
            except Exception:
                # A single unreachable worker must not bring the server down.
                logger.exception("unable to connect to worker at %s:%s"% (host, port))
                continue
            if host in self.outgoing_queues:
                logger.warning("refusing connection request from host %s since I am already connected to a worker there"% host)
                try:
                    worker_queue.put(False)
                except Exception:
                    # Best effort only: the refused worker may already be gone.
                    logger.debug("could not deliver refusal to worker at host %s"% host)
            else:
                logger.info("connected with worker at %s which is listening on port %i"% (host, port))
                worker_queue.put(True)
                self.outgoing_queues[host] = (worker_queue, manager)
                self.workers.append(host)

    def _dispatch_jobs(self):
        """Forward queued jobs to their workers; return True on the poison pill."""
        while not self.incoming_queue.empty():
            logger.info("checking for incoming job requests")
            job = self.incoming_queue.get()
            # poison pill
            if job is None:
                return True
            requesting_host = self.request_queue.get()
            if requesting_host not in self.outgoing_queues:
                logger.warning("ignoring job request from host that was not previously connected")
                # Leave the remaining jobs for the next polling pass, which
                # handles pending connection requests first.
                break
            # Entries are (queue, manager) tuples; the job goes on the queue.
            worker_queue, _ = self.outgoing_queues[requesting_host]
            worker_queue.put(job)
        return False

    def _notify_workers(self):
        """Send ``None`` to every connected worker to signal that the server is exiting."""
        for host, (queue, manager) in self.outgoing_queues.items():
            logger.info("notifying worker at host %s..."% host)
            try:
                queue.put(None)
            except Exception:
                # Keep notifying the other workers even if this one is gone.
                logger.warning("could not notify worker at host %s"% host)

# Queues consumed by the Server process:
#   connect_queue  - (host, port) connection requests from workers
#   incoming_queue - jobs to hand out (None stops the server)
#   request_queue  - host name of the worker each job is destined for
connect_queue, incoming_queue, request_queue = Queue(), Queue(), Queue()

# List of connected worker host names, shared between processes.
shared_manager = Manager()
workers = shared_manager.list()
    
class ServerManager(BaseManager):
    """Manager that exposes this module's queues and worker list to clients."""

# Every typeid returns the corresponding module-level object; the lambdas look
# the name up at call time.
ServerManager.register('get_connect_queue', callable=lambda: connect_queue)
ServerManager.register('get_request_queue', callable=lambda: request_queue)
ServerManager.register('get_queue', callable=lambda: incoming_queue)
ServerManager.register('workers', callable=lambda: workers)

# Serve the queues and worker list on every interface at the requested port.
server_manager = ServerManager(address=('', args.port), authkey='abracadabra')
try:
    server_manager_server = server_manager.get_server()
except Exception:
    # Log the underlying cause together with the message, instead of hiding it
    # behind a bare except.
    logger.exception("failed to initialize server using port %i"% args.port)
else:
    logger.info("server initialized")
    server = Server(connect_queue, incoming_queue, request_queue, workers)
    server.start()
    # The announcer is a daemon process so it never keeps the program alive.
    broadcaster = Process(target = broadcast)
    broadcaster.daemon = True
    broadcaster.start()
    try:
        server_manager_server.serve_forever()
    finally:
        # However serve_forever() ends (normal return or exception), tell the
        # Server process to stop via the poison pill and wait for it.
        incoming_queue.put(None)
        server.join()
# Flush and release the queues' feeder threads before exiting.
connect_queue.close()
incoming_queue.close()
request_queue.close()
connect_queue.join_thread()
incoming_queue.join_thread()
request_queue.join_thread()
