#!/usr/bin/env python

from rootpy.extern.argparse import ArgumentParser
import multiprocessing

# Command-line interface: job parallelism cap and the port this worker listens on.
parser = ArgumentParser()
parser.add_argument(
    '-n', "--nproc",
    dest="nproc",
    type=int,
    default=multiprocessing.cpu_count(),  # defaults to one job per CPU
    help="maximum number of parallel jobs to run by this worker",
)
parser.add_argument(
    '-p', "--port",
    dest="worker_port",
    type=int,
    default=50001,
    help="port that I will listen on",
)
# Parsed once at import time; the rest of the script reads `args`.
args = parser.parse_args()

import socket
from multiprocessing.managers import BaseManager
from multiprocessing import Process, Queue
import sys
import traceback
import logging
import socket
import struct
import time

# Console logging for this script: a single stream handler attached to the
# "Worker" logger, with both the logger and the handler passing DEBUG and up.
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)-8s %(message)s')

handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
handler.setFormatter(formatter)

logger = logging.getLogger("Worker")
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

logger.info("initializing...")

class ServerManager(BaseManager):
    """Manager used to connect to the job server and reach its queues."""


# Only the type ids are registered (no callable): this side merely connects
# to a manager that is served elsewhere and asks it for these objects.
for _typeid in ('get_connect_queue', 'get_request_queue'):
    ServerManager.register(_typeid)
del _typeid

class Worker(Process):
    """Child process that registers with a job server and polls it for jobs.

    The server is discovered by waiting for a ``host:port`` datagram on a UDP
    multicast group. The worker then connects to the server's manager,
    announces itself on the server's connect queue and repeatedly asks for a
    job on the request queue. The connection verdict and the jobs are read
    from ``queue``; a ``None`` job is treated as "server has terminated".
    """

    # Multicast group and port on which the server's announcement is expected.
    MCAST_GRP = '224.1.1.1'
    MCAST_PORT = 5007

    def __init__(self, queue):
        """Remember the queue this worker reads the server's replies from.

        queue -- queue delivering the connection verdict and then the jobs
        """
        self.queue = queue
        super(Worker, self).__init__()

    @staticmethod
    def _parse_announcement(response):
        """Split a ``host:port`` announcement into ``(host, int(port))``.

        Raises ValueError when the payload is not exactly ``host:port`` or
        the port is not an integer.
        """
        if not isinstance(response, str):
            # Payload is not a native str (e.g. bytes where recv() returns
            # bytes); on Python 2 recv() already returns str and this is a
            # no-op.
            response = response.decode('ascii')
        server, server_port = response.split(':')
        return server, int(server_port)

    def _discover_server(self):
        """Block until a server announcement arrives; return ``(host, port)``.

        The discovery socket is always closed before returning, whether the
        receive succeeded or raised.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(('', self.MCAST_PORT))
            # "=" selects standard sizes without alignment padding, so the
            # packed value is exactly the 8-byte ip_mreq (group address
            # followed by interface address). Native "4sl" pads to 16 bytes
            # on platforms where a C long is 8 bytes.
            mreq = struct.pack("=4sl", socket.inet_aton(self.MCAST_GRP),
                               socket.INADDR_ANY)
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
            logger.info("looking for server on local network...")
            # No timeout is set: this waits until a datagram arrives.
            response = sock.recv(10240)
        finally:
            sock.close()
        # NOTE(review): whatever datagram arrives first is trusted as the
        # server address -- confirm this is acceptable on the target network.
        return self._parse_announcement(response)

    def _notify_termination(self, connect_queue, hostname):
        """Best-effort notice to the server that this worker is going away."""
        logger.info("notifying the server...")
        try:
            connect_queue.put((hostname, None))
            # Wait until the connect queue is empty again (presumably once
            # the server has consumed the notice).
            while not connect_queue.empty():
                time.sleep(.5)
        except (Exception, KeyboardInterrupt):
            # Deliberately broad: an interrupt or a dead server must not
            # prevent the rest of the shutdown from running.
            logger.error("failed to notify the server")

    def run(self):
        """Discover the server, register with it and poll for jobs.

        Runs until the server reports termination (a ``None`` job), the
        connection is refused, the process is interrupted, or an error
        occurs. In every case the termination is logged, the server is
        notified if the worker had been accepted and the server is not known
        to be gone, and ``self.queue`` is closed.
        """
        # Keep process exit from blocking on this queue's feeder thread.
        self.queue.cancel_join_thread()
        server_terminated = False
        connected = False
        connect_queue = None
        hostname = None
        try:
            hostname = socket.gethostname()
            server, server_port = self._discover_server()
            logger.info("found server at %s:%s", server, server_port)
            # NOTE(review): the authkey is a hard-coded shared secret --
            # confirm this is acceptable for the deployment.
            manager = ServerManager(address=(server, server_port), authkey='abracadabra')
            logger.info("connecting with server manager...")
            manager.connect()
            logger.info("connected with server manager")
            connect_queue = manager.get_connect_queue()
            request_queue = manager.get_request_queue()
            # announce my existence to the main job queue server
            logger.info("requesting connection with server")
            try:
                connect_queue.put((hostname, args.worker_port))
            except Exception:
                logger.error("failed to request connection with server")
                raise SystemExit
            if not self.queue.get():  # connection refused
                logger.error("connection refused by server")
                return
            connected = True
            logger.info("connected successfully with server")
            while True:
                # ask for a job
                request_queue.put(hostname)
                job = self.queue.get()
                if job is None:  # server has terminated
                    logger.critical("server has terminated")
                    server_terminated = True
                    break
                # TODO: actually execute the job (job.start(); job.join());
                # for now it is only printed.
                print(job)
        except (KeyboardInterrupt, SystemExit):
            pass
        except Exception:
            logger.exception("worker stopped on an unexpected error")
        finally:
            # Runs on every exit path, including the early return above.
            logger.info("worker is now terminating...")
            if connected and not server_terminated:
                # tell the server that I am terminating
                self._notify_termination(connect_queue, hostname)
            self.queue.close()

# Queue shared between the manager server below and the Worker process.
job_queue = Queue()


class WorkerManager(BaseManager):
    """Manager that exposes this worker's job queue under ``get_queue``."""


# Every ``get_queue`` request is answered with the one module-level queue.
WorkerManager.register(
    'get_queue',
    callable=lambda: job_queue,
)

# Serve job_queue on this worker's port and run the Worker child process
# alongside the manager server.
# NOTE(review): the authkey is a hard-coded shared secret -- confirm this is
# acceptable for the deployment.
worker_manager = WorkerManager(address=('', args.worker_port), authkey='abracadabra')
try:
    worker_manager_server = worker_manager.get_server()
except Exception:
    # Log the underlying cause (e.g. the port being unavailable) with the
    # message instead of hiding it.
    logger.exception("failed to initialize worker using port %i", args.worker_port)
else:
    logger.info("worker initialized")
    worker = Worker(job_queue)
    worker.start()
    try:
        worker_manager_server.serve_forever()
    finally:
        # Always release the child, even when serve_forever() exits by
        # raising: None is what the Worker treats as the stop signal.
        job_queue.put(None)
        worker.join()
job_queue.close()
