#!/usr/bin/env python

import argparse
from expvar.stats import stats
import logging
import oid_translate
from pysnmp.carrier.asynsock.dispatch import AsynsockDispatcher
from pysnmp.carrier.asynsock.dgram import udp
import threading
import tornado.ioloop
import tornado.web

from trapperkeeper.callbacks import TrapperCallback
from trapperkeeper.config import Config
from trapperkeeper.models import get_db_engine, Session
from trapperkeeper.utils import get_template_env, get_loglevel, CachingResolver
from trapperkeeper import __version__


def stats_server(port):
    """Expose the expvar stats over HTTP and run the tornado IOLoop.

    Registers one GET endpoint, ``/debug/stats``, listens on ``port`` and
    then starts the loop returned by ``IOLoop.instance()``. The call blocks
    until that loop is stopped, so it is meant to run in its own thread.
    """

    class Stats(tornado.web.RequestHandler):
        """Request handler that replies with the current stats snapshot."""

        def get(self):
            # presumably to_dict() yields a dict, which tornado's write()
            # serialises as JSON -- confirm against expvar.
            snapshot = stats.to_dict()
            return self.write(snapshot)

    routes = [(r"/debug/stats", Stats)]
    app = tornado.web.Application(routes)
    app.listen(port)

    io_loop = tornado.ioloop.IOLoop.instance()
    io_loop.start()



def main():
    """Run the SNMP trap collector until it is interrupted.

    Steps, in order:

    1. Parse command-line options (config path, verbosity, version).
    2. Configure logging from the verbosity flags.
    3. Load MIBs, read the config file and bind the database session.
    4. Register the trap callback on a pysnmp dispatcher listening for UDP
       on ``0.0.0.0:config["trap_port"]``.
    5. Start the stats HTTP server (``config["stats_port"]``) in a helper
       thread and run the dispatcher on the calling thread.

    A ``KeyboardInterrupt`` is treated as a normal shutdown request. Whatever
    way the dispatcher exits, the dispatcher is closed and the stats thread
    is stopped and joined before this function returns.
    """
    parser = argparse.ArgumentParser(description="SNMP Trap Collector.")
    parser.add_argument("-c", "--config", default="/etc/trapperkeeper.yaml",
                        help="Path to config file.")
    parser.add_argument("-v", "--verbose", action="count", default=0,
                        help="Increase logging verbosity.")
    parser.add_argument("-q", "--quiet", action="count", default=0,
                        help="Decrease logging verbosity.")
    parser.add_argument("-V", "--version", action="version",
                        version="%%(prog)s %s" % __version__,
                        help="Display version information.")

    args = parser.parse_args()

    # Configure logging before any other initialisation. basicConfig() is a
    # no-op once the root logger has a handler, so anything logged earlier
    # through the logging module functions would otherwise discard this
    # level/format; it also ensures start-up messages are formatted.
    logging.basicConfig(
        level=get_loglevel(args),
        format="%(asctime)-15s\t%(levelname)s\t%(message)s"
    )

    oid_translate.load_mibs()

    config = Config.from_file(args.config)

    db_engine = get_db_engine(config["database"])
    Session.configure(bind=db_engine)

    conn = Session()
    resolver = CachingResolver(timeout=300)
    template_env = get_template_env(
        hostname_or_ip=resolver.hostname_or_ip
    )
    cb = TrapperCallback(conn, template_env, config, resolver)

    transport_dispatcher = AsynsockDispatcher()
    transport_dispatcher.registerRecvCbFun(cb)
    transport_dispatcher.registerTransport(
        udp.domainName,
        udp.UdpSocketTransport().openServerMode(
            ("0.0.0.0", int(config["trap_port"]))
        )
    )

    # Register a permanent job so runDispatcher() keeps running.
    transport_dispatcher.jobStarted(1)
    stats_thread = threading.Thread(
        target=stats_server, args=(str(config["stats_port"]),)
    )

    try:
        stats_thread.start()
        transport_dispatcher.runDispatcher()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the collector.
        pass
    finally:
        print("Stopping Transport Dispatcher...")
        transport_dispatcher.closeDispatcher()
        print("Stopping Stats Thread...")
        # The IOLoop runs in stats_thread; add_callback() is the only
        # IOLoop method documented as safe to call from another thread, so
        # schedule stop() on the loop instead of calling it directly.
        # NOTE(review): this relies on IOLoop.instance() returning the same
        # loop in both threads (true for tornado < 5) -- confirm the pinned
        # tornado version.
        io_loop = tornado.ioloop.IOLoop.instance()
        io_loop.add_callback(io_loop.stop)
        stats_thread.join()
        print("Bye")


# Run the collector only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
