#!/usr/bin/env python
# encoding: utf-8
"""
LISPd manages all LISP control packets sent and received by
a system. By default it listens on UDP port 4342, dispatches
incoming requests to the configurable modules and can send
requests to other systems on behalf of other applications.
"""

from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ipaddress import ip_address, IPv4Address
from multiprocessing.dummy import Pool
from pylisp.application.lispd import settings
from pylisp.application.lispd.message_handler import handle_message
from pylisp.application.lispd.received_message import ReceivedMessage
from pylisp.application.lispd.settings import ConfigurationError
from pylisp.packet.ip.ipv4 import IPv4Packet
from pylisp.packet.ip.ipv6.base import IPv6Packet
from pylisp.packet.lisp.control import ControlMessage
from pylisp.packet.lisp.data import DataPacket
from pylisp.utils.task_thread import TaskThread
import logging
import select
import socket
import sys
from pylisp.application.lispd.send_message import send_message


try:
    import nfqueue  # @UnresolvedImport
except ImportError:
    nfqueue = None


# Module-wide logger. This is the root logger itself, so messages go through
# whatever handlers/level logging.basicConfig() installs on the root logger.
logger = logging.getLogger()


def create_sockets(config):
    """Create and bind the UDP sockets LISPd listens on.

    For every entry in ``config.LISTEN_ON`` two UDP sockets are bound to that
    address: a data-plane socket on port 4341 and a control-plane socket on
    port 4342. Loopback, link-local, multicast and unspecified (wild-card)
    addresses are refused: an error is logged and the entry is skipped.

    :param config: settings object exposing a ``LISTEN_ON`` iterable of addresses
    :returns: a ``(control_plane_sockets, data_plane_sockets)`` tuple of lists
    :raises ValueError: if an entry cannot be parsed as an IP address
    :raises socket.error: if creating or binding a socket fails
    """
    # Checked in this order; the first property that is true causes a refusal.
    refusals = (
        ('is_loopback', u"Cannot bind to loopback address {0}"),
        ('is_link_local', u"Cannot bind to link-local address {0}"),
        ('is_multicast', u"Cannot bind to multicast address {0}"),
        ('is_unspecified', u"Cannot bind to unspecified address {0}"),
    )

    control_sockets = []
    data_sockets = []
    for entry in config.LISTEN_ON:
        # (Re)parse so that any accepted textual notation is normalised
        parsed = ip_address(unicode(entry))

        refusal = next((template for prop, template in refusals
                        if getattr(parsed, prop)), None)
        if refusal is not None:
            logger.error(refusal.format(parsed))
            continue

        family = socket.AF_INET if isinstance(parsed, IPv4Address) else socket.AF_INET6
        host = unicode(parsed)

        # For each address the data plane (4341) is bound before the control plane (4342)
        plan = (
            (data_sockets, 4341, u"Binding data-plane socket to %s port %d"),
            (control_sockets, 4342, u"Binding control-plane socket to %s port %d"),
        )
        for bucket, port, announcement in plan:
            logger.info(announcement % (host, port))
            udp = socket.socket(family, socket.SOCK_DGRAM, socket.SOL_UDP)
            udp.bind((host, port))
            bucket.append(udp)

    return control_sockets, data_sockets


class nfqueue_callback(object):
    """Callable handed to ``nfqueue`` that tunnels intercepted packets to the PETR.

    Every packet delivered through the queue gets an ``NF_DROP`` verdict. Its
    contents are then wrapped in a ``DataPacket`` and sent to the configured
    PETR on UDP port 4341 through the data-plane sockets.
    """

    def __init__(self, config, data_plane_sockets):
        # Settings object; its PETR attribute is consulted for every packet
        self.config = config
        # Sockets handed to send_message() when forwarding encapsulated packets
        self.sockets = data_plane_sockets

    def __call__(self, payload):
        """Handle one queued packet.

        :param payload: the nfqueue payload object for the intercepted packet
        """
        # Always drop the original packet; an encapsulated copy is sent instead
        payload.set_verdict(nfqueue.NF_DROP)

        # Read the PETR once from the config this callback was built with, so
        # the check and the send below use the same value.
        petr = self.config.PETR
        if not petr:
            logger.error("PETR is not configured, dropping packet")
            return

        try:
            logger.debug("NFQUEUE callback called: %s bytes of data", payload.get_length())
            data = payload.get_data()

            # Encapsulate the raw packet and tunnel it to the PETR
            message = DataPacket(payload=data)
            send_message(message=message, my_sockets=self.sockets, destinations=[petr], port=4341)
        except Exception:
            # Best effort: one bad packet must not take down the nfqueue loop,
            # but SystemExit/KeyboardInterrupt are allowed to propagate.
            logger.exception("Unexpected exception when handling data packet")


class BackgroundProcessingTask(TaskThread):
    """Periodic task that calls ``process`` on every configured instance handler.

    ``config.INSTANCES`` is walked as a two-level mapping (instance -> AFI ->
    handler). Each handler's ``process`` method receives the sockets given at
    construction. The walk is done while holding ``config.lock``.
    """

    def __init__(self, config, my_sockets):
        # interval=10 is passed to TaskThread (defined elsewhere); presumably
        # seconds between task() runs -- confirm against TaskThread.
        super(BackgroundProcessingTask, self).__init__(interval=10)
        self.config = config
        self.my_sockets = my_sockets

    def task(self):
        """Run one round of background processing over all instances."""
        logger.debug("Running background processing")
        with self.config.lock:
            instances = self.config.INSTANCES
            for instance_id in instances:
                afi_table = instances[instance_id]
                for afi in afi_table:
                    afi_table[afi].process(self.my_sockets)
        logger.debug("Finished background processing")


def _build_argument_parser():
    """Return the command-line parser for LISPd (module docstring as description)."""
    parser = ArgumentParser(description=__doc__,
                            formatter_class=RawDescriptionHelpFormatter)
    parser.add_argument("-v",
                        "--verbose",
                        dest="verbose",
                        action="store_true",
                        help="be verbose")
    parser.add_argument("-d",
                        "--debug",
                        dest="debug",
                        action="store_true",
                        help="show debugging output")
    parser.add_argument("-C",
                        "--show-config",
                        dest="show_config",
                        action="store_true",
                        help="show the configuration and exit")
    return parser


def _open_nfqueue(callback, queue_number, family):
    """Open one NFQUEUE for ``family`` and make it usable with select.select()."""
    queue = nfqueue.queue()
    queue.set_callback(callback)
    queue.fast_open(queue_number, family)
    queue.set_queue_maxlen(5000)
    queue.set_mode(nfqueue.NFQNL_COPY_PACKET)
    # select() calls fileno() on the objects it watches; expose the queue's fd that way
    queue.fileno = queue.get_fd
    return queue


def _dispatch_control_packet(sock, control_plane_sockets, pool):
    """Read one control-plane packet from ``sock`` and queue it on ``pool`` for handling."""
    data, addr = sock.recvfrom(65536)
    logger.debug("Received %d bytes of control-plane traffic from %r", len(data), addr)

    # Parse input
    try:
        message = ControlMessage.from_bytes(data)
        received_message = ReceivedMessage(source=addr,
                                           destination=sock.getsockname(),
                                           message=message,
                                           socket=sock)
    except Exception as e:
        logger.error("Error in control-plane message from %r: %s", addr, e)
        # Nothing valid to dispatch: stop here instead of handing an undefined
        # (or previously received) message to the handler.
        return

    # Dispatch
    try:
        kwds = {'received_message': received_message,
                'my_sockets': control_plane_sockets}
        # NOTE: the AsyncResult is discarded, so exceptions raised inside
        # handle_message itself are not reported here.
        pool.apply_async(handle_message, kwds=kwds)
    except Exception as e:
        logger.error("Uncaught exception when handling control-plane message from %r: %s", addr, e)


def _forward_data_packet(sock, raw_socket_ipv4, raw_socket_ipv6):
    """Read one encapsulated data packet from ``sock`` and send its inner packet out a raw socket."""
    data, addr = sock.recvfrom(65536)
    logger.debug("Received %d bytes of data-plane traffic from %r", len(data), addr)

    message = DataPacket.from_bytes(data)
    logger.debug("Seems to contain %r", message)

    inner = message.payload
    inner_bytes = inner.to_bytes()
    # Anything that is not an IPv4Packet is sent via the IPv6 raw socket
    raw_socket = raw_socket_ipv4 if isinstance(inner, IPv4Packet) else raw_socket_ipv6

    sent = raw_socket.sendto(inner_bytes, (unicode(inner.destination), 0))
    if sent != len(inner_bytes):
        logger.error("Could not send decapsulated packet %r", inner)
    else:
        logger.debug("Sent decapsulated packet of %d bytes", sent)


def main(argv=None):
    """Run the LISPd daemon.

    Parse the command line, load the settings and bind the control- and
    data-plane sockets. When a PETR is configured, also open the NFQUEUEs.
    Then process incoming packets until interrupted.

    :param argv: extra arguments, appended to ``sys.argv`` before parsing
    :returns: process exit status: 0 after a normal shutdown or keyboard
              interrupt, 1 on ``ConfigurationError``, 2 on setup errors,
              unexpected exceptions and ``--show-config``
    """
    if argv is None:
        argv = sys.argv
    else:
        # The parser reads sys.argv, so extra arguments are appended to it
        sys.argv.extend(argv)

    try:
        args = _build_argument_parser().parse_args()

        # Configure the logging process
        if args.debug:
            logging_level = logging.DEBUG
        elif args.verbose:
            logging_level = logging.INFO
        else:
            logging_level = logging.WARNING

        logging.basicConfig(level=logging_level,
                            format='%(asctime)s [%(module)s %(levelname)s] %(message)s')

        # Init the settings
        settings.config = settings.Settings()

        # Show config? Only UPPER_CASE attributes are treated as settings.
        if args.show_config:
            for setting, value in settings.config.__dict__.items():
                if setting == setting.upper():
                    sys.stdout.write("%s=%r\n" % (setting, value))
            # NOTE(review): exits with status 2 even though nothing failed -- confirm intended
            return 2

        # Determine local sockets
        control_plane_sockets, data_plane_sockets = create_sockets(settings.config)
        # Raw sockets used to re-inject decapsulated data-plane packets
        raw_socket_ipv4 = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
        raw_socket_ipv6 = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_RAW)

        if settings.config.PETR and (settings.config.NFQUEUE_IPV4 is None or settings.config.NFQUEUE_IPV6 is None):
            logger.error("PETR configured but not NFQUEUE_IPV4 and NFQUEUE_IPV6")
            return 2

        # NFQUEUEs are only needed to send traffic towards a PETR
        nfqueues = []
        if settings.config.PETR:
            if not nfqueue:
                logger.error("Python nfqueue module not found")
                return 2

            callback = nfqueue_callback(settings.config, data_plane_sockets)
            nfqueues = [_open_nfqueue(callback, settings.config.NFQUEUE_IPV4, socket.AF_INET),
                        _open_nfqueue(callback, settings.config.NFQUEUE_IPV6, socket.AF_INET6)]

        if not control_plane_sockets:
            logger.error("Not listening on any addresses")
            return 2

        # Create the thread pool
        pool = Pool(processes=settings.config.THREAD_POOL_SIZE)

        # Create the timer for processing
        processing_timer = BackgroundProcessingTask(settings.config, control_plane_sockets)
        processing_timer.start()

        # The set of watched descriptors never changes, so build it once
        watched = control_plane_sockets + data_plane_sockets + nfqueues

        logger.info("Waiting for incoming messages to process")
        while True:
            try:
                readable, _, _ = select.select(watched, [], [])

                for sock in readable:
                    if sock in control_plane_sockets:
                        _dispatch_control_packet(sock, control_plane_sockets, pool)
                    elif sock in data_plane_sockets:
                        _forward_data_packet(sock, raw_socket_ipv4, raw_socket_ipv6)
                    elif sock in nfqueues:
                        logger.debug("Triggered NFQUEUE")
                        sock.process_pending(5)

            except KeyboardInterrupt:
                logger.info("Interupted")
                break
            except Exception as e:
                # Keep the daemon running; log and continue with the next select()
                logger.exception("Unexpected exception: %s", e)

        logger.info("Shutting down")

        # Wait for the workers
        processing_timer.shutdown()
        pool.close()
        pool.join()

        logger.info("LISPd shut down")

        return 0
    except KeyboardInterrupt:
        return 0
    except ConfigurationError:
        return 1
    except Exception:
        logger.exception("Unexpected exception")
        return 2


if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    exit_status = main()
    sys.exit(exit_status)
