#!/usr/bin/env python

#
# This file is part of phantom_scheduler.
#
# phantom_scheduler is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# phantom_scheduler is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with phantom_scheduler.  If not, see <http://www.gnu.org/licenses/>.
#
# DTU UQ Library
# Copyright (C) 2014 The Technical University of Denmark
# Scientific Computing Section
# Department of Applied Mathematics and Computer Science
#
# Author: Daniele Bigoni
#

import sys
import socket
import getopt
import errno
import time
import threading
import SocketServer
import phantom_scheduler as ps

class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    """TCP server that handles each incoming connection in its own thread.

    This is just the combination of SocketServer.ThreadingMixIn and
    SocketServer.TCPServer; no extra behaviour is added.
    """

class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
    """Request handler that receives the termination message of a job."""

    def handle(self):
        """Record the return code sent over the connection.

        Reads up to 1024 bytes, parses them as an integer return code and
        stores it on ``self.server.terminationFlag``. Only then is the flag
        marked as terminated.
        """
        payload = self.request.recv(1024)
        flag = self.server.terminationFlag
        # Store the code before raising the flag: the waiting loop in main()
        # reads returncode only once it sees terminated set.
        flag.returncode = int(payload)
        flag.terminated = 1

class TerminationFlag:
    """Mutable record of a submitted job's completion state.

    Attributes (plain ints):
      terminated -- 0 while the job has not terminated, 1 once it has
      returncode -- the job's return code; -1 until one has been stored
    """
    def __init__(self):
        self.returncode = -1
        self.terminated = 0

def usage():
    """Print the command-line help of psqrsh to standard output."""
    lines = [
        "psqrsh is part of the package phantom_scheduler.",
        # The former text said the command "returns" right after submitting,
        # contradicting the (correct) statement that it is blocking.
        "This submits a job to the <hostname> listening on <port> and waits "
        "for the job to end. This is a blocking command: it keeps polling for "
        "the end of the job and then exits with the return code reported for "
        "it. Use the option <interval> to specify the amount of sleeping "
        "seconds (type:float, default=0.01) between polls.",
        "See psqsub for a non-blocking command.",
        "Usage: psqrsh -n <hostname> -p <portnum> [-i <interval>] <command>",
    ]
    # sys.stdout.write keeps the output identical on Python 2 and 3.
    sys.stdout.write("\n".join(lines) + "\n")

def _parse_args(argv):
    """Parse psqrsh's command line into (host, port, interval, cmd).

    Prints the usage text and exits with status 0 on -h/--help. Exits with
    status 2 on unknown options, malformed numbers, an out-of-range port, a
    negative interval, or a missing hostname, port or command.
    """
    host = None
    port = None
    interval = 0.01
    try:
        # "help" takes no argument: the former "help=" made a bare --help
        # fail with "option --help requires argument".
        opts, args = getopt.getopt(
            argv, "hn:p:i:", ["help", "hostname=", "port=", "interval="])
    except getopt.GetoptError:
        usage()
        sys.exit(2)
    try:
        for opt, arg in opts:
            if opt in ('-h', '--help'):
                usage()
                sys.exit(0)
            elif opt in ('-n', '--hostname'):
                host = arg
            elif opt in ('-p', '--port'):
                port = int(arg)
            elif opt in ('-i', '--interval'):
                interval = float(arg)
    except ValueError:
        # Non-numeric --port / --interval
        usage()
        sys.exit(2)
    # Previously a missing -n/-p crashed later with an UnboundLocalError.
    # getopt stops at the first non-option, so everything after the options
    # (including the command's own flags) forms the command.
    if (host is None or port is None or not args
            or not 0 < port <= 65535 or interval < 0):
        usage()
        sys.exit(2)
    return host, port, interval, " ".join(args)


def _bind_listener(first_port):
    """Bind a ThreadedTCPServer on the first free port in [first_port, 65535].

    Returns (server, port). Socket errors other than "address already in
    use" are re-raised; exits with an error message if every port is taken.
    """
    for port in range(first_port, 65536):
        try:
            server = ThreadedTCPServer(('', port), ThreadedTCPRequestHandler)
        except socket.error as serr:
            # Use errno.EADDRINUSE instead of a literal 98: the value is
            # platform dependent (98 is the Linux one).
            if serr.errno != errno.EADDRINUSE:
                raise
        else:
            return server, port
    sys.exit("psqrsh: no free port in [%d, 65535] to listen on" % first_port)


def _submit_job(host, port, reply_port, cmd, interval):
    """Deliver a blocking job to the scheduler listening at (host, port).

    Sends ps.BLOCKING, then the port on which this process waits for the
    termination message, then the command. The first two sends are each
    followed by a recv of the scheduler's reply. While the connection is
    refused, retries every ``interval`` seconds; other socket errors are
    re-raised.
    """
    while True:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((host, port))
            # Declare the job blocking
            s.sendall(ps.BLOCKING)
            s.recv(1024)
            # Send the port to notify on termination
            s.sendall(str(reply_port))
            s.recv(1024)
            # Send the job
            s.sendall(cmd)
            return
        except socket.error as serr:
            if serr.errno != errno.ECONNREFUSED:
                raise  # bare raise keeps the original traceback
        finally:
            # Close on every attempt; refused attempts used to leak sockets
            s.close()
        time.sleep(interval)


def main(argv):
    """Run psqrsh: submit a blocking job and exit with its return code.

    ``argv`` is the command line without the program name (see usage()).
    The steps are:
      1. Start a TCP listener on the first free port above the scheduler's
         port.
      2. Send that port and the job to the scheduler.
      3. Poll every ``interval`` seconds until ThreadedTCPRequestHandler
         records termination.
      4. Call sys.exit with the return code it received.
    """
    host, port, interval, cmd = _parse_args(argv)

    # Listen for the signal of termination of the job
    server, server_port = _bind_listener(port + 1)
    termination_flag = TerminationFlag()
    server.terminationFlag = termination_flag
    # Serve from a daemon thread (ThreadingMixIn adds one thread per
    # request) so that it exits when the main thread terminates.
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.daemon = True
    server_thread.start()

    _submit_job(host, port, server_port, cmd, interval)

    # Wait for the job completion and receive the return signal
    while not termination_flag.terminated:
        time.sleep(interval)

    # Exit with the flag status
    sys.exit(termination_flag.returncode)

if __name__ == "__main__":
    # Drop the program name; main() parses only the options and the command
    cli_args = sys.argv[1:]
    main(cli_args)
