#!/usr/bin/env python
"""
NetLogger socket server command-line program.

Listens on a network socket and copies the data it receives to one or
more outputs (stdout, stderr, or files opened in append mode).

Features:
 * listen on either a TCP or UDP socket. TCP input is split into
   newline-terminated records (bytes after the last newline are held until
   more data arrives); each UDP datagram is written as received.
 * write to several outputs at once (-o, repeatable).
 * optionally roll output files over at a given size (-r).
 * optionally fork into the background (-b) or exit after a time (-k).
 * flush outputs on SIGHUP; shut down cleanly on SIGINT, SIGUSR1 or SIGUSR2
   (only where the platform defines those signals).
"""
__author__ = "Dan Gunter <dkgunter@lbl.gov>"
__version__ = "$Revision: 423 $"
__rcsid__ = "$Id: netlogd 423 2007-12-17 08:59:56Z dang $"

import asyncore
import logging
import optparse
import os
import re
import signal
import socket
import string
import sys
import time
import types

# Global is-debugging-on flag (main() sets it when -v is given)
g_debug = False


def debug(s):
    """Emit message s at DEBUG level through the root logger."""
    root = logging.getLogger()
    root.debug(s)

try:
    # Signals that trigger a clean shutdown
    SHUTDOWN_SIGNALS = tuple([getattr(signal, signame)
                              for signame in ('SIGINT', 'SIGUSR1', 'SIGUSR2')])
    # Signals that trigger a flush of all outputs
    FLUSH_SIGNALS = (signal.SIGHUP,)
except AttributeError:
    # Some of these signals are missing on this platform: install no handlers.
    SHUTDOWN_SIGNALS = FLUSH_SIGNALS = ()

# Server and output objects; module-level so the signal handlers can reach them
g_server = g_outputs = None

DEFAULT_PORT = 14380

def shutdown(signo, frame):
    """shutdown(signo:int, frame:signal.frame obj) -> None

    Signal handler for a clean exit: close the server, then the outputs
    (each only if it has been created), and exit with status 0.
    """
    logging.info("Server is shutting down..")
    for resource in (g_server, g_outputs):
        if resource:
            resource.close()
    sys.exit(0)

def flushall(signo, frame):
    """Signal handler: flush every output, if the outputs exist yet."""
    outputs = g_outputs
    if not outputs:
        return
    outputs.flush()
#
## Exception classes
#

class NetlogdException(Exception):
    """Base class for errors defined by netlogd."""
    pass


class FormatException(NetlogdException):
    """Format-related netlogd error (not raised anywhere in this module)."""
    pass


class BadFileSizeException(NetlogdException):
    """Raised by parse_size() when a string is not a valid file size.

    Derives from NetlogdException, like the module's other errors, so
    callers can catch every netlogd error at once. It is still an
    Exception, as before.
    """
    def __init__(self, sz):
        # Wrap sz in a 1-tuple so a tuple argument cannot break formatting.
        NetlogdException.__init__(self, "%s does not represent a valid file size. "
                                  "For a 2GB size, allowable formats include: "
                                  "2G, 2gb, 2000Mb, 2000000KB, 2000000000." % (sz,))
        # The rejected input, so callers need not parse the message.
        self.size = sz
                           
#
## Utility functions
#

# "<digits>[k|m|g][b]" (either case), anchored at both ends with \Z
# (unlike '$', \Z does not accept a trailing newline).
_SIZE_PATTERN = re.compile(r"(\d+)([kKmMgG]?)[bB]?\Z")
# Decimal multipliers, keyed by lower-cased unit letter ('' = bytes).
_SIZE_MULTIPLIERS = {'': 1, 'k': 1000, 'm': 1000000, 'g': 1000000000}

def parse_size(size):
    """parse_size(size:str) -> int

    Parse a string representation of a file size with
    units: GB or G, MB or M, KB or K, or no unit (bytes)
    and return the number of bytes this represents.
    Units are decimal (K = 1000) and case-insensitive; the whole string
    must match, so surrounding whitespace or extra characters are rejected.

    Raises BadFileSizeException if size is not in one of these forms.
    """
    m = _SIZE_PATTERN.match(size)
    if m is None:
        raise BadFileSizeException(size)
    digits, unit = m.groups()
    return int(digits) * _SIZE_MULTIPLIERS[unit.lower()]

#
## Socket server classes
#

class TCPServer(asyncore.dispatcher):
    """TCP socket server.

    Listens on all interfaces at 'port' and creates one InputChannel per
    accepted connection. Each channel passes records to 'callback' and
    byte counts to 'report_cb' (which may be None).
    """
    def __init__(self, port, callback, report_cb):
        asyncore.dispatcher.__init__(self)
        self._callback, self._reportcb = callback, report_cb
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        # SO_REUSEADDR, so a restarted server can rebind the same port
        self.set_reuse_addr()
        self.bind(('', port))
        self.listen(5)

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore returns None when the connection did not take place;
            # ignore the event and keep listening.
            return
        conn, addr = pair
        logging.info("accepted connection from %s:%d" % addr)
        # Constructing a dispatcher around a socket already registers it in
        # asyncore's socket map, so no separate add_channel() is needed.
        InputChannel(conn, addr, self._callback, self._reportcb)

    def handle_close(self):
        logging.info("Closing server socket")
        self.close()

class Reporter:
    """Mixin that calls a report callback once every 'n' units.

    Subclasses call self.report(num) with the size of each chunk they
    process; InputChannel and UDPServer pass byte counts. When the running
    total reaches 'n', the callback is called with that total (the number
    of units since the previous callback) and the total restarts at zero.
    If the callback is None, self.report() does nothing.
    """
    def __init__(self, cb, n=65536):
        if cb is None:
            self.report = self._nullReport
        else:
            self._rcb = cb
            self._interval = n
            self._n = 0
            self.report = self._report

    def _report(self, num):
        self._n += num
        if self._n >= self._interval:
            # Pass the accumulated count, not an elapsed time: the callback
            # main() installs, MultiOutput.report(nbytes), takes a byte
            # count and measures elapsed time itself.
            total, self._n = self._n, 0
            self._rcb(total)

    def _nullReport(self, num):
        pass
    
class InputChannel(asyncore.dispatcher, Reporter):
    """
    Per-connection reader. Reads bytes from a client, buffers them, and
    sends each complete newline-terminated record to the onRead callback.
    Bytes after the last newline are kept until more data arrives.
    """
    # Read block size
    block_size = 64*1024
    # Throughput report size (bytes between onReport callbacks)
    report_size = 64*1024

    def __init__ (self, sock, addr, onRead, onReport):
        asyncore.dispatcher.__init__(self, sock)
        # Use the class's report_size rather than Reporter's default.
        Reporter.__init__(self, onReport, self.report_size)
        self._callback = onRead
        self._buf = ""
        self._addr = addr

    def writable(self):
        # Input-only channel: never ask asyncore for write events.
        return False

    def handle_close(self):
        logging.info("closing connection with client %s:%d" % self._addr)
        if self._buf:
            # The client closed mid-record. Log the discarded bytes
            # instead of dropping them silently.
            logging.warning("discarding %d bytes of incomplete record from %s:%d"
                            % ((len(self._buf),) + tuple(self._addr)))
        self.close()

    def handle_read(self):
        data = self.recv(self.block_size)
        if len(data) == 0:
            return
        self._buf += data
        start = sendRecordsToCallback(self._buf, self._callback)
        # Shorten buffer to unsent portion
        if start > 0:
            self._buf = self._buf[start:]
        self.report(len(data))


def sendRecordsToCallback(buf, cb):
    SEP = '\n'
    start = 0
    while 1:           
        end = buf.find(SEP,start,len(buf))
        if end == -1:
            break
        cb(buf[start:end+1]) # include separator
        start = end + 1
    return start
    
class UDPServer(asyncore.dispatcher, Reporter):
    """UDP socket server.

    Binds all interfaces at 'port'. Every datagram read is passed whole to
    'callback', and its length to the Reporter built from 'report_cb'.
    """
    MAX_PACKET_SIZE = 65536

    def __init__(self, port, callback, report_cb):
        asyncore.dispatcher.__init__(self)
        Reporter.__init__(self, report_cb)
        self._readcb = callback
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.set_reuse_addr()
        self.bind(('', port))
        self.add_channel()

    def handle_read(self):
        try:
            payload, _sender = self.socket.recvfrom(self.MAX_PACKET_SIZE)
            self._readcb(payload)
            self.report(len(payload))
        except socket.timeout:
            pass

    def _ignore_event(self):
        """Connection-oriented events have no meaning for a datagram socket."""

    handle_accept = handle_connect = handle_close = _ignore_event

    def writable(self):
        # Receive-only: never request write events.
        return False

#
## Output classes
#

class Output:
    """A single destination for received data.

    The url '-' selects stdout and '&' selects stderr; any other value is
    a file path opened in append mode.
    """
    def __init__(self, url, flush=False, fmt=None, rollover=None,
                 add_rcvtime=False):
        """Open the output.

        url - '-' (stdout), '&' (stderr) or a file path
        flush - if true, flush after every write()
        fmt, add_rcvtime - accepted for interface compatibility; unused
        rollover - if set, roll the file over once its size exceeds this
                   many bytes (ignored, with a warning, for stdout/stderr)
        """
        self._is_std = url in ('-', '&')
        if url == '-':
            self.url = "{stdout}"
            self._output = sys.stdout
        elif url == '&':
            self.url = "{stderr}"
            self._output = sys.stderr
        else:
            self.url = url
            self._output = open(url, 'a')
        self._report_rec, self._report_bytes = 0, 0
        # flush flag
        self._flush = flush
        # rollover: the limit, and the file's current size in bytes
        self._rollover = None
        self._bytes = 0
        if rollover and self._is_std:
            # A standard stream cannot be renamed and reopened.
            logging.warning("ignoring rollover for %s" % self.url)
        elif rollover:
            self._rollover = rollover
            # In append mode the file may already hold data; count it so
            # the limit applies to the file's size.
            self._bytes = os.fstat(self._output.fileno()).st_size

    def report(self, sec):
        """Log throughput over the last 'sec' seconds and reset the counters.

        If sec is not positive, no rate can be computed. Nothing is logged
        and the counters are kept for the next report.
        """
        if sec <= 0:
            return
        kb = (self._report_bytes/1024)/sec
        rec = self._report_rec/sec
        logging.debug("%s throughput: "
                      "%d KB/sec, %d records/sec" % (self.url,kb,rec))
        self._report_rec, self._report_bytes = 0,0

    def write(self, buf):
        """
        Write buffer to output, flushing and rolling over as configured.
        """
        self._output.write(buf)
        if self._flush:
            self.flush()
        if self._rollover:
            self._bytes += len(buf)
            if self._bytes > self._rollover:
                self.roll_over()
                self._bytes = 0
        self._report_rec += 1
        self._report_bytes += len(buf)

    def flush(self): self._output.flush()

    def roll_over(self):
        """
        Rename the current file to '<name>.<UTC timestamp>' and reopen
        '<name>' in append mode. If that rolled name already exists (two
        rollovers within one second), append '.1', '.2', ... so that no
        earlier rolled file is overwritten.
        """
        time_str = "%04d-%02d-%02dT%02d:%02d:%02d" % \
                 time.gmtime(time.time())[0:6]
        filename = self._output.name
        rolled_filename = base = filename + '.' + time_str
        suffix = 1
        while os.path.exists(rolled_filename):
            rolled_filename = "%s.%d" % (base, suffix)
            suffix += 1
        logging.debug("rolling %s to %s" % (filename, rolled_filename))
        self._output.close()
        os.rename(filename, rolled_filename)
        self._output = open(filename, 'a')

    def close(self):
        """Close the file. stdout/stderr are only flushed, because other
        code (e.g. logging) may still write to them."""
        if self._is_std:
            self._output.flush()
        else:
            self._output.close()
            
class MultiOutput:
    """Fan-out wrapper that forwards each call to every added output."""

    def __init__(self,nl_action=False):
        self._outs = []
        self._last_report_time = time.time()
        # rpc/action field (stored; not otherwise used by this class)
        self._action = nl_action

    def add(self,outp):
        """Append outp to the list of outputs written to."""
        self._outs.append(outp)

    def write(self, buf):
        """Write buf to every output, in the order they were added."""
        for o in self._outs:
            o.write(buf)

    def report(self,nbytes):
        """Log throughput since the previous report.

        nbytes - bytes received since the previous report. Each output's
        report() is called with the elapsed seconds. If no time has elapsed
        (or the clock went backwards), nothing is reported and the window
        stays open for the next call.
        """
        report_time = time.time()
        sec = report_time - self._last_report_time
        if sec <= 0:
            return
        if len(self._outs) > 1:
            tput = (nbytes/1024) / sec
            logging.debug("Overall throughput: %d KB/sec" % tput)
        for o in self._outs:
            o.report(sec)
        self._last_report_time = report_time

    def flush(self):
        for o in self._outs: o.flush()

    def close(self):
        for o in self._outs: o.close()

def doLoop(loop_time, kill_time):
    """Run the asyncore event loop.

    loop_time - timeout in seconds for each poll of the sockets
    kill_time - None to run until every channel is closed; otherwise the
                number of seconds after which to stop (the loop also stops
                early once no channels remain)
    """
    if kill_time is None:
        asyncore.loop(loop_time)
        return
    stop = time.time() + kill_time
    while time.time() < stop:
        if not asyncore.socket_map:
            # asyncore.poll() returns at once on an empty map, so without
            # this check the loop would spin the CPU until 'stop'.
            logging.info("No open channels remain; leaving event loop")
            return
        asyncore.poll(loop_time)
    logging.info("As requested, netlogd is killing itself. Urk!")
        
def usage(prog, m=None):
    """usage(prog:str, m:str=None) -> does not return

    Write the optional message m and a usage line for prog to stderr,
    then exit with status 1.
    """
    if m:
        sys.stderr.write("%s\n" % (m,))
    # Use a module-level __usage__ template if one is defined. Otherwise
    # fall back to the same usage line main() gives optparse.
    usage_fmt = globals().get('__usage__', "Usage: %s [options] [-h]")
    sys.stderr.write((usage_fmt % prog) + "\n")
    sys.exit(1)
    
def _build_parser():
    """Return the optparse.OptionParser describing netlogd's options."""
    parser = optparse.OptionParser(usage="%prog [options] [-h]",
                                   version="2.0")
    parser.add_option('-b','--fork',action="store_true",dest="fork",
                      default=False,
                      help="fork into the background after starting up")
    parser.add_option('-f','--flush',action="store_true",dest="flush",
                      default=True,
                      help="flush all outputs after each record (default)")
    # -f defaults to on, so this is the only way to turn flushing off.
    parser.add_option('--no-flush',action="store_false",dest="flush",
                      help="do not flush outputs after each record")
    parser.add_option('-k','--kill',dest="seppuku",type="string",
                      default="",metavar="TIME",
                      help="Kill self after some time. Time can be given "
                      "with units 's','m', or 'h' for seconds, minutes or hours. "
                      "Default units are minutes ('m')")
    parser.add_option('-o',"--output", action="append", dest="urls",
                      default=[], metavar="URL",
                      help="Output file(s), repeatable (default=stdout)")
    parser.add_option('-p','--port',action="store",dest="port",type="int",
                      default=DEFAULT_PORT, metavar="PORT",
                      help="port number (default=%d)" % DEFAULT_PORT)
    parser.add_option('-q','--quiet',action="store_true",dest="quiet",
                      default=False,
                      help="print nothing to stderr, overrides '-v'")
    parser.add_option('-r','--rollover',action="store",dest="rollover",
                      default=None, metavar="SIZE",
                      help="roll over files at given file size (units allowed)")
    parser.add_option('-U','--udp',action='store_true',dest='udpmode',
                      default=False,
                      help="listen on a UDP instead of TCP socket")
    parser.add_option('-v','--verbose',action="store_true",dest="verbose",
                      default=False,
                      help="verbose mode (report throughput)")
    return parser


def _configure_logging(options):
    """Give the root logger a stderr handler and set its level from -q/-v.

    Returns True if verbose (debug) logging was enabled.
    """
    # Without a handler, Python 2 drops records and Python 3 shows only
    # WARNING and above, so -v throughput reports would never appear.
    # basicConfig() does nothing if handlers are already configured.
    logging.basicConfig()
    root_logger = logging.getLogger()
    if options.quiet:
        # A level above CRITICAL suppresses every record.
        root_logger.setLevel(logging.CRITICAL + 1)
        return False
    if options.verbose:
        root_logger.setLevel(logging.DEBUG)
        return True
    root_logger.setLevel(logging.WARN)
    return False


def _parse_kill_time(parser, spec):
    """Convert a -k/--kill value to seconds; None if spec is empty.

    A trailing 's', 'm' or 'h' selects the unit; otherwise minutes.
    Invalid values exit through parser.error().
    """
    if spec == "":
        return None
    if len(spec) > 1 and spec[-1] in ('s', 'm', 'h'):
        num_str, unit = spec[:-1], spec[-1]
    else:
        num_str, unit = spec, 'm'
    try:
        num = float(num_str)
    except ValueError:
        parser.error("Bad value for self-kill time")
    return num * {'s': 1, 'm': 60, 'h': 3600}[unit]


def _write_pidfile(pid):
    """Record pid in ./netlogd.pid, logging a warning if it can't be written."""
    try:
        with open("netlogd.pid", "w") as pidfile:
            pidfile.write("%d\n" % pid)
        logging.info("Placed child PID %d in 'netlogd.pid'" % pid)
    except IOError:
        logging.warning("Cannot write PID file 'netlogd.pid' (child PID %d)",
                        pid)


def main(cmdline):
    """main(cmdline:list of str) -> None

    Parse cmdline (cmdline[0] is the program name), create the outputs and
    the TCP or UDP server, then run the event loop. With -b the loop runs
    in a forked child; the parent writes the child's PID to 'netlogd.pid'
    and exits.
    """
    global g_server, g_outputs, g_debug

    # Setup signal handlers. Use explicit loops: map() is lazy on Python 3
    # and would never call signal.signal().
    for signo in SHUTDOWN_SIGNALS:
        signal.signal(signo, shutdown)
    for signo in FLUSH_SIGNALS:
        signal.signal(signo, flushall)
    # parse options
    parser = _build_parser()
    options, _args = parser.parse_args(cmdline[1:])
    # default output is stdout
    if not options.urls:
        options.urls.append('-')
    if _configure_logging(options):
        g_debug = True
    # rollover size and self-kill time are validated before any output
    # file is opened, so a bad argument creates no files.
    rollover = None
    if options.rollover:
        try:
            rollover = parse_size(options.rollover)
        except BadFileSizeException as err:
            parser.error("%s" % err)
    kill_time = _parse_kill_time(parser, options.seppuku)
    # create output object(s)
    outp = MultiOutput()
    for oname in options.urls:
        outp.add(Output(oname, options.flush, rollover=rollover))
    g_outputs = outp
    # init server; throughput is reported only in verbose mode
    report_fn = outp.report if options.verbose else None
    if options.udpmode:
        server_cls, transport = UDPServer, "UDP"
    else:
        server_cls, transport = TCPServer, "TCP"
    g_server = server_cls(options.port, outp.write, report_fn)
    logging.info("Listening on %s port %d" % (transport, options.port))
    # run server -- forever (or until the self-kill time)!
    if options.fork:
        pid = os.fork()
        if pid != 0:
            _write_pidfile(pid)
            sys.exit(0)
    doLoop(0.1, kill_time)

if __name__ == "__main__":
    # Script entry point: pass the full argv (program name included) to main().
    main(sys.argv[:])
