#!/usr/bin/env python
"""
NetLogger socket server command-line program for listening on either a
TCP or UDP socket.
"""
__author__ = "Dan Gunter <dkgunter@lbl.gov>"
__version__ = "$Revision: 23923 $"
__rcsid__ = "$Id: netlogd 23923 2009-09-18 22:42:26Z ksb $"

import asyncore
import os
import re
import signal
import socket
import string
import sys
import time
#
from netlogger.nllog import OptionParser, get_logger, DoesLogging

# The signal names are looked up inside a try block: if the signal module
# lacks any one of them (AttributeError), both tuples are left empty and
# main() therefore installs no handlers at all.
# NOTE(review): presumably this is for non-POSIX platforms -- confirm.
try:
    # Which signals are caught (for clean shutdown)
    SHUTDOWN_SIGNALS = ( signal.SIGINT, signal.SIGUSR1, signal.SIGUSR2)
    # Which signals are caught for flush
    FLUSH_SIGNALS = (signal.SIGHUP,)
except AttributeError:
    SHUTDOWN_SIGNALS = ()
    FLUSH_SIGNALS = ()

# Server object and output collection.  Both are assigned in main() and
# need to be global so that the shutdown() and flushall() signal handlers
# can reach them.
g_server = None
g_outputs = None

# Port used when -p/--port is not given on the command line
DEFAULT_PORT = 14380

def shutdown(signo, frame):
    """shutdown(signo:int, frame:signal.frame obj) -> None

    Signal handler for a clean exit: close the global server and the
    global outputs (whichever of them are set), log the end of the
    run, and terminate the process with exit status 0.
    """
    logger = get_logger(__file__)
    logger.debug("shutdown.start")
    # Server first, then outputs; either may still be unset.
    for closable in (g_server, g_outputs):
        if closable:
            closable.close()
    logger.debug("shutdown.end")
    logger.info("end", status=0)
    sys.exit(0)

def flushall(signo, frame):
    """flushall(signo:int, frame:signal.frame obj) -> None

    Signal handler: flush the global outputs, if any have been set.
    """
    outputs = g_outputs
    if not outputs:
        return
    outputs.flush()
#
## Exception classes
#

class NetlogdException(Exception):
    """Base class for netlogd-specific exceptions."""
class FormatException(NetlogdException):
    """Format-related netlogd exception.

    Adds nothing to NetlogdException; it is not raised anywhere in the
    visible part of this file.
    """
class BadFileSizeException(Exception):
    """Raised when a string cannot be interpreted as a file size.

    The message echoes the offending value and lists accepted formats.
    """
    def __init__(self, sz):
        template = ("%s does not represent a valid file size. "
                    "For a 2GB size, allowable formats include: "
                    "2G, 2gb, 2000Mb, 2000000KB, 2000000000.")
        Exception.__init__(self, template % sz)

#
## Utility functions
#

# Whole-string pattern for a file size: digits, an optional k/m/g unit
# and an optional trailing 'b', in either case.  Written as a raw string
# so that "\d" is a regex class and not an (invalid) string escape.
_SIZE_RE = re.compile(r"(\d+)([kKmMgG]?)[bB]?\Z")

# Decimal multipliers; '' means the value is already in bytes.
_SIZE_MULTIPLIERS = {'': 1, 'k': 1000, 'm': 1000000, 'g': 1000000000}

def parse_size(size):
    """parse_size(size:str) -> int

    Parse a string representation of a file size with
    units: GB or G, MB or M, KB or K, or no unit (bytes)
    and return the number of bytes this represents.
    Case-insensitive. Units are decimal (1K = 1000 bytes).

    Raises BadFileSizeException if the whole string does not match
    the expected format.
    """
    m = _SIZE_RE.match(size)
    if m is None:
        raise BadFileSizeException(size)
    digits, unit = m.groups()
    return int(digits) * _SIZE_MULTIPLIERS[unit.lower()]

#
## Socket server classes
#

class TCPServer(asyncore.dispatcher, DoesLogging):
    """TCP socket server.

    Binds a listening stream socket to `port` on all interfaces.  For
    every accepted connection an InputChannel is created, which is
    handed the `callback` (called with received records) and the
    `report_cb` (throughput reporting callback, may be None).
    """
    def __init__(self, port, callback, report_cb):
        DoesLogging.__init__(self, name='netlogd.TCPServer')
        asyncore.dispatcher.__init__(self)
        self._callback, self._reportcb = callback, report_cb
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(('', port))
        self.listen(5)

    def handle_accept(self):
        """Accept a pending connection and start reading from it."""
        # dispatcher.accept() returns None instead of a (conn, addr)
        # pair when no connection could actually be accepted;
        # unpacking that would raise TypeError.
        pair = self.accept()
        if pair is None:
            return
        conn, addr = pair
        self.log.info("connection.start", host=addr[0], port=addr[1])
        channel = InputChannel(conn, addr, self._callback, self._reportcb)
        channel.add_channel()

    def handle_close(self):
        """Log and close the listening socket."""
        self.log.info("connection.end")
        self.close()

class Reporter:
    """Report using the provided callback every 'n' items.

    Call report(num) to add `num` items to a running count.  Once the
    count reaches `n`, the callback is invoked with the accumulated
    count and the count starts again from zero.  If the callback is
    None, report() does nothing.

    The callback receives the item count (bytes, for the socket
    readers in this file) because the callback installed by main(),
    MultiOutput.report(nbytes), uses its argument as a byte count and
    measures elapsed time itself.
    """
    def __init__(self, cb, n=65536):
        if cb is None:
            self.report = self._nullReport
        else:
            self._rcb = cb
            self._interval = n
            self._n = 0
            self.report = self._report

    def _report(self, num):
        """Add `num` items; fire the callback when the interval is hit."""
        self._n += num
        if self._n >= self._interval:
            self._rcb(self._n)
            self._n = 0

    def _nullReport(self, num):
        """Stand-in for _report when no callback was given."""
        pass

class InputChannel(asyncore.dispatcher, Reporter, DoesLogging):
    """
    Per-connection reader: received bytes are buffered, each complete
    newline-terminated record is passed to the onRead callback, and
    the number of bytes read is fed to the Reporter.
    """
    # Read block size
    block_size = 64*1024
    # Throughput report size (not referenced by the methods below)
    report_size = 64*1024

    def __init__ (self, sock, addr, onRead, onReport):
        DoesLogging.__init__(self, name='netlogd.InputChannel')
        asyncore.dispatcher.__init__(self, sock)
        Reporter.__init__(self, onReport)
        self._addr = addr
        self._callback = onRead
        # Received data not yet forwarded (trailing partial record)
        self._buf = ""

    def handle_read(self):
        chunk = self.recv(self.block_size)
        if not chunk:
            return
        self._buf += chunk
        consumed = sendRecordsToCallback(self._buf, self._callback)
        # Keep only what has not been sent to the callback yet
        if consumed > 0:
            self._buf = self._buf[consumed:]
        self.report(len(chunk))

    def handle_close(self):
        peer_host, peer_port = self._addr[0], self._addr[1]
        self.log.info("connection.close", host=peer_host,
                      port=peer_port)
        self.close()

    def writable(self):
        # This channel only reads
        return False


def sendRecordsToCallback(buf, cb):
    """Call cb once for each newline-terminated record in buf.

    Each record is passed with its trailing newline.  Returns the
    index just past the last record sent, i.e. the start of any
    trailing partial record (0 if buf holds no newline).
    """
    consumed = 0
    newline_at = buf.find('\n')
    while newline_at != -1:
        record_end = newline_at + 1
        cb(buf[consumed:record_end]) # include separator
        consumed = record_end
        newline_at = buf.find('\n', consumed)
    return consumed

class UDPServer(asyncore.dispatcher, Reporter):
    """UDP socket server.

    Binds a datagram socket to `port` on all interfaces and passes
    each received packet to `callback`; packet sizes are fed to the
    Reporter built around `report_cb`.
    """
    MAX_PACKET_SIZE = 65536

    def __init__(self, port, callback, report_cb):
        asyncore.dispatcher.__init__(self)
        Reporter.__init__(self, report_cb)
        self._readcb = callback
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.set_reuse_addr()
        self.bind(('', port))
        self.add_channel()

    def writable(self):
        # Nothing is ever sent back
        return False

    def handle_read(self):
        # A socket.timeout raised anywhere in here is ignored
        try:
            packet, _sender = self.socket.recvfrom(self.MAX_PACKET_SIZE)
            self._readcb(packet)
            self.report(len(packet))
        except socket.timeout:
            pass

    # The remaining event handlers are deliberately no-ops.

    def handle_connect(self):
        pass

    def handle_accept(self):
        pass

    def handle_close(self):
        pass

#
## Output classes
#

class Output(DoesLogging):
    """A single output destination for received records.

    `url` selects the destination: '-' is standard output, '&' is
    standard error, anything else is a file path opened for append.

    flush    -- if true, flush after every write
    rollover -- size in bytes written (since open or last rollover)
                after which a file output is renamed and re-opened;
                None disables it.  Ignored for stdout/stderr, which
                cannot be renamed.
    fmt, add_rcvtime -- accepted but not used by this class.
    """
    def __init__(self, url, flush=False, fmt=None, rollover=None,
                 add_rcvtime=False):
        DoesLogging.__init__(self)
        if url == '-':
            self.url = "{stdout}"
            self._output = sys.stdout
            self._is_file = False
        elif url == '&':
            self.url = "{stderr}"
            self._output = sys.stderr
            self._is_file = False
        else:
            self.url = url
            self._output = open(url, 'a')
            self._is_file = True
        self._report_rec, self._report_bytes = 0, 0
        self._trans = string.maketrans(' \t','__')
        # flush flag
        self._flush = flush
        # rollover: only meaningful for a real file on disk
        if rollover is not None and self._is_file:
            self._rollover = rollover
            self._bytes = 0
        else:
            self._rollover = None

    def report(self, sec):
        """Print throughput over the last `sec` seconds, reset counters."""
        kb = (self._report_bytes/1024)/sec
        rec = self._report_rec/sec
        print ("%s throughput: "
               "%d KB/sec, %d records/sec" % (self.url, kb, rec))
        self._report_rec, self._report_bytes = 0, 0

    def write(self, buf):
        """
        Write buffer to output.

        Flushes if requested, rolls the file over once more than the
        rollover size has been written, and updates report counters.
        """
        self._output.write(buf)
        if self._flush:
            self.flush()
        if self._rollover:
            self._bytes += len(buf)
            if self._bytes > self._rollover:
                self.roll_over()
                self._bytes = 0
        self._report_rec += 1
        self._report_bytes += len(buf)

    def flush(self):
        """Flush the underlying stream."""
        self._output.flush()

    def roll_over(self):
        """
        Rename current file and re-open original (file) URL.

        The current file gets a '.<UTC timestamp>' suffix; a new file
        is then opened for append under the original name.
        """
        time_str = "%04d-%02d-%02dT%02d:%02d:%02d" % \
                 time.gmtime(time.time())[0:6]
        filename = self._output.name
        rolled_filename = filename + '.' + time_str
        self.log.debug("roll.start", from__file=filename,
                       to__file=rolled_filename)
        self._output.close()
        os.rename(filename, rolled_filename)
        # Re-open the same way the constructor does.
        self._output = open(filename, 'a')
        self.log.debug("roll.end", status=0)

    def close(self):
        """Close the underlying stream."""
        self._output.close()

class MultiOutput:
    """Collection of outputs; each operation is forwarded to all of them.
    """
    def __init__(self, nl_action=False):
        # rpc/action field
        self._action = nl_action
        self._last_report_time = time.time()
        self._outs = []

    def add(self, outp):
        """Register another output."""
        self._outs.append(outp)

    def write(self, buf):
        """Write buf to every output, in the order they were added."""
        for sink in self._outs:
            sink.write(buf)

    def report(self, nbytes):
        """Print throughput since the previous report.

        With more than one output, an overall figure computed from
        nbytes is printed first; every output then prints its own.
        """
        now = time.time()
        elapsed = now - self._last_report_time
        if len(self._outs) > 1:
            overall_kb = (nbytes/1024) / elapsed
            print("Overall throughput: %d KB/sec" % overall_kb)
        for sink in self._outs:
            sink.report(elapsed)
        self._last_report_time = time.time()

    def flush(self):
        """Flush every output."""
        for sink in self._outs:
            sink.flush()

    def close(self):
        """Close every output."""
        for sink in self._outs:
            sink.close()

def doLoop(loop_time, kill_time):
    """Run the asyncore event loop.

    loop_time is the timeout handed to asyncore.  If kill_time is
    None, asyncore.loop() is run; otherwise asyncore.poll() is called
    repeatedly until kill_time seconds have passed, after which a
    warning is logged and the function returns.
    """
    if kill_time is not None:
        deadline = time.time() + kill_time
        while time.time() < deadline:
            asyncore.poll(loop_time)
        get_logger(__file__).warn("abort",
                       msg="As requested, netlogd is killing itself. Urk!")
        return
    asyncore.loop(loop_time)

def _parse_kill_time(text):
    """Convert a self-kill time such as '90s', '5m', '2h' or '10' to seconds.

    A value without a trailing 's', 'm' or 'h' is taken as minutes.
    Raises ValueError if the numeric part is not a valid float.
    """
    if len(text) > 1 and text[-1] in ('s', 'm', 'h'):
        num, unit = float(text[:-1]), text[-1]
    else:
        num, unit = float(text), 'm'
    return num * {'s':1, 'm':60, 'h':3600}[unit]

def main(cmdline):
    """main(cmdline:list of str) -> int

    Program entry point.  Installs the signal handlers, parses the
    command line (cmdline[0] is the program name), creates the outputs
    and the TCP or UDP server, and runs the event loop -- in a forked
    child if --fork was given, in which case the parent writes the
    child's pid to 'netlogd.pid' and returns.

    Returns 0.  Invalid option values end the program through
    parser.error().
    """
    global g_server, g_outputs

    # Setup signal handlers.  Plain loops, because the calls are made
    # only for their side effect (a lazy map() would never run them).
    for signo in SHUTDOWN_SIGNALS:
        signal.signal(signo, shutdown)
    for signo in FLUSH_SIGNALS:
        signal.signal(signo, flushall)
    # setup options
    desc = ' '.join(__doc__.split())
    parser = OptionParser(description=desc)
    parser.add_option('-b', '--fork', action="store_true", dest="fork",
                      default=False,
                      help="fork into the background after starting up")
    # NOTE(review): store_true combined with default=True means that
    # giving -f cannot change the value; flushing is always enabled.
    parser.add_option('-f', '--flush', action="store_true", dest="flush",
                      default=True,
                      help="flush all outputs after each record")
    parser.add_option('-k', '--kill', dest="seppuku", type="string",
                      default="", metavar="TIME",
                      help="Kill self after some time. Time can be given "
                      "in units 's', 'm' or 'h' for seconds, minutes or hours. "
                      "Default units are minutes ('m')")
    parser.add_option('-o', "--output", action="append", dest="urls",
                      default=[], metavar="URL",
                      help="Output file(s), repeatable (default=stdout)")
    parser.add_option('-p', '--port', action="store", dest="port", type="int",
                      default=DEFAULT_PORT, metavar="PORT",
                      help="port number (default=%d)" % DEFAULT_PORT)
    parser.add_option('-r', '--rollover', action="store", dest="rollover",
                      default=None, metavar="SIZE",
                      help="roll over files at given file size (units allowed)")
    parser.add_option('-U', '--udp', action='store_true', dest='udpmode',
                      default=False,
                      help="listen on a UDP instead of TCP socket")
    # parse options
    options, args = parser.parse_args(cmdline[1:])
    # sanity check: with no -o given, write to stdout
    if not options.urls:
        options.urls.append('-')
    log = get_logger(__file__)  # Should be first done, just after parsing args
    # rollover
    rollover = None
    if options.rollover:
        try:
            rollover = parse_size(options.rollover)
        except BadFileSizeException as err:
            parser.error("%s" % err)
    # create output object(s)
    outp = MultiOutput()
    for oname in options.urls:
        outp.add(Output(oname, options.flush, rollover=rollover))
    g_outputs = outp
    # seppuku (self-kill)
    kill_time = None
    if options.seppuku != "":
        try:
            kill_time = _parse_kill_time(options.seppuku)
        except ValueError:
            parser.error("Bad value for self-kill time")
    # init server
    # NOTE(review): no 'verbose' option is added above; presumably it is
    # supplied by netlogger's OptionParser -- confirm.
    if options.verbose > 0:
        report_fn = outp.report
    else:
        report_fn = None
    if options.udpmode:
        g_server = UDPServer(options.port, outp.write, report_fn)
        transport = "UDP"
    else:
        g_server = TCPServer(options.port, outp.write, report_fn)
        transport = "TCP"
    log.info("start", protocol=transport, port=options.port)
    # run server -- forever!
    if options.fork:
        pid = os.fork()
        if pid == 0:
            # child: serve
            doLoop(0.1, kill_time)
        else:
            # parent: record the child's pid, then return
            try:
                pidfile = open("netlogd.pid","w")
                try:
                    pidfile.write("%d\n" % pid)
                finally:
                    # close even if the write failed
                    pidfile.close()
                log.debug("pid.write", pid=pid, file='netlogd.pid')
            except IOError as err:
                log.warn("pid.write.error", pid=pid, file='netlogd.pid',
                         msg=err)
    else:
        doLoop(0.1, kill_time)
    return 0

# When run as a script: hand the full argument vector to main() and use
# its return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main(sys.argv))
