#!/usr/bin/env python
import sys
import os
import web
import json
import urllib
import logging

import pika

from time import time, sleep
from threading import Thread

from shoal_server import config as config
from shoal_server import geoip as geoip
from shoal_server import urls as urls

# Record layout handed to logging.basicConfig() by Application.__init__:
# timestamp - level - [source file:line] - message.
LOG_FORMAT = (
    '%(asctime)s - %(levelname)s - '
    '[%(filename)s:%(lineno)s] - %(message)s'
)
# Logger shared by every class in this module.
logger = logging.getLogger('shoal-server')

"""
    Basic class to store and update information about each squid server.
"""
class SquidNode(object):
    """Record holding what is known about a single squid server.

    The constructor stores every argument on the instance under the same
    name and additionally records ``created``, the time the record was
    built.
    """

    def __init__(self, key, hostname, squid_port, public_ip, private_ip, external_ip, load, geo_data, last_active=None):
        """Create a record for one squid.

        :param key: identifier of the squid (used as the shoal dict key).
        :param hostname: hostname reported by the squid.
        :param squid_port: port reported by the squid.
        :param public_ip: public address, or None.
        :param private_ip: private address, or None.
        :param external_ip: external address, or None.
        :param load: load value reported by the squid.
        :param geo_data: geolocation data for the squid.
        :param last_active: timestamp of the last activity; defaults to the
            moment the record is created.
        """
        # A `time()` default in the signature would be evaluated only once,
        # when the function is defined, handing every later node the same
        # stale timestamp. Resolve the default per call instead.
        now = time()
        self.key = key
        self.created = now
        self.last_active = now if last_active is None else last_active
        self.hostname = hostname
        self.squid_port = squid_port
        self.public_ip = public_ip
        self.private_ip = private_ip
        self.external_ip = external_ip
        self.geo_data = geo_data
        self.load = load

    def update(self, load):
        """Mark the squid as active right now and store its latest load."""
        self.last_active = time()
        self.load = load

"""
    Main application that starts the worker threads and monitors them.
"""
class Application(object):
    """Start the shoal-server worker threads and supervise them.

    Constructing an instance runs the whole server: it configures logging,
    changes into the static-files directory, refreshes the GeoLiteCity
    database when needed, starts the RabbitMQ, web.py and cleanup threads,
    and then loops until a thread dies or Ctrl-C is pressed. The constructor
    therefore only ends by exiting the process.
    """

    # NOTE(review): both values are read from `config` when this class body
    # is executed, i.e. before `config.setup()` runs in __init__ -- confirm
    # that `setup()` is not expected to change them.
    LOG_FILE = config.log_file
    DIRECTORY = config.shoal_dir

    def __init__(self):
        # setup configuration settings.
        config.setup()

        # Worker objects are created inside their own threads. They live in
        # private attributes so they never replace the thread-target methods
        # of the same name, and so stop() can tell "not created yet" apart
        # from "running".
        self._consumer = None
        self._web_server = None
        self._updater = None

        try:
            logging.basicConfig(level=logging.ERROR, format=LOG_FORMAT, filename=self.LOG_FILE)
        except IOError as e:
            print("Could not set logger. {0}".format(e))
            sys.exit(1)

        # change working directory so webpy static files load correctly.
        try:
            os.chdir(self.DIRECTORY)
        except OSError:
            print("{0} doesn't seem to exist. Please set `shoal_dir` in shoal-server config file to the location of the shoal-server static files.".format(self.DIRECTORY))
            sys.exit(1)

        # check if geolitecity database needs updating
        if geoip.check_geolitecity_need_update():
            geoip.download_geolitecity()

        # Shared dictionary of known squids, handed to every worker.
        self.shoal = {}
        self.threads = []

        workers = (
            (self.rabbitmq, 'RabbitMQ'),
            (self.webpy, 'Webpy'),
            (self.update, 'ShoalUpdate'),
        )
        for target, name in workers:
            thread = Thread(target=target, name=name)
            # Daemon threads do not keep the process alive after sys.exit().
            thread.daemon = True
            self.threads.append(thread)

        for thread in self.threads:
            thread.start()
        try:
            # Poll once a second; any dead worker shuts the whole server down.
            while True:
                for thread in self.threads:
                    if not thread.is_alive():
                        logger.error('{0} died.'.format(thread))
                        self.stop()
                sleep(1)
        except KeyboardInterrupt:
            self.stop()

    def rabbitmq(self):
        """Thread target: build the AMQP URL and run the RabbitMQ consumer."""
        url, vh = config.amqp_server_url, config.amqp_virtual_host
        host = "{0}/{1}".format(url, urllib.quote_plus(vh))
        self._consumer = RabbitMQConsumer(host, self.shoal)
        self._consumer.run()

    def webpy(self):
        """Thread target: create and run the web.py server."""
        self._web_server = WebpyServer(self.shoal)
        self._web_server.run()

    def update(self):
        """Thread target: create and run the inactive-squid cleanup loop."""
        self._updater = ShoalUpdate(self.shoal)
        self._updater.run()

    def stop(self):
        """Stop every worker that was created, then exit the process.

        Each worker is stopped independently so that a failure in one does
        not leave the others running. Exits with status 1 if any worker
        failed to stop, 0 otherwise.
        """
        print("\nShutting down Shoal-Server... Please wait.")
        exit_code = 0
        workers = (
            (self._web_server, "Web Server stopped."),
            (self._consumer, "RabbitMQ consumer stopped."),
            (self._updater, None),
        )
        for worker, stopped_message in workers:
            if worker is None:
                # The owning thread never got as far as creating it.
                continue
            try:
                worker.stop()
            except Exception as e:
                logger.error(e)
                exit_code = 1
            else:
                if stopped_message:
                    print(stopped_message)
        # Pause before exiting, as the original shutdown path did.
        sleep(2)
        sys.exit(exit_code)

"""
    ShoalUpdate is used for trimming inactive squids every set interval.
"""
class ShoalUpdate(object):
    """Periodically remove squids that have been inactive for too long."""

    # Seconds to sleep between two cleanup passes.
    INTERVAL = config.squid_cleanse_interval
    # A squid whose last activity is older than this many seconds is dropped.
    INACTIVE = config.squid_inactive_time

    def __init__(self, shoal):
        """Remember the shared squid dictionary; the loop starts in run()."""
        self.shoal = shoal
        self.running = False

    def run(self):
        """Sleep INTERVAL seconds, run a cleanup pass, repeat until stop()."""
        self.running = True
        while self.running:
            sleep(self.INTERVAL)
            self.update()

    def update(self):
        """Drop every squid whose last activity is older than INACTIVE."""
        curr = time()
        # Iterate over an explicit snapshot: entries are removed inside the
        # loop, and removing from a dict while iterating a live view of it
        # raises RuntimeError.
        for squid in list(self.shoal.values()):
            if curr - squid.last_active > self.INACTIVE:
                # Tolerate the entry having disappeared in the meantime.
                self.shoal.pop(squid.key, None)

    def stop(self):
        """Ask run() to finish after its current sleep/cleanup cycle."""
        self.running = False

"""
    Webpy webserver used to serve up active squid lists and API calls. For now we just use the development webpy server to serve requests.
"""
class WebpyServer(object):
    """Wrapper around the web.py application that serves the squid data."""

    def __init__(self, shoal):
        """Publish `shoal` as `web.shoal` and define the URL routes.

        The web.py application itself is only created in run().
        """
        web.shoal = shoal
        web.config.debug = False
        self.app = None
        # Raw strings keep the regex escapes (\d) literal.
        self.urls = (
            r'/nearest/?(\d+)?/?', 'shoal_server.urls.nearest',
            r'/wpad.dat', 'shoal_server.urls.wpad',
            r'/external', 'shoal_server.urls.external_ip',
            r'/(\d+)?/?', 'shoal_server.urls.index',
        )

    def run(self):
        """Create the web.py application and run it.

        Any exception is logged and followed by ``sys.exit(1)``.
        """
        try:
            self.app = web.application(self.urls, globals())
            self.app.run()
        except Exception as e:
            logger.error("Could not start webpy server.\n{0}".format(e))
            sys.exit(1)

    def stop(self):
        """Stop the web.py application; a no-op if run() never created it."""
        if self.app is not None:
            self.app.stop()

"""
    Basic RabbitMQ blocking consumer. Consumes messages from `config.amqp_server_queue`, takes the JSON in the body, and puts it into the dictionary `shoal`.
"""
class RabbitMQConsumer(object):
    """Callback-driven pika consumer that fills the shared `shoal` dict.

    Setup is a chain of callbacks: connection open -> channel open ->
    exchange declared -> queue declared -> queue bound -> consuming.
    Each received message is decoded from JSON and either refreshes an
    existing SquidNode or registers a new one.
    """

    QUEUE = config.amqp_server_queue
    EXCHANGE = config.amqp_exchange
    EXCHANGE_TYPE = config.amqp_exchange_type
    # Topic pattern used when binding the queue to the exchange.
    ROUTING_KEY = '#'
    # Messages whose timestamp is older than this many seconds do not
    # register a new squid.
    INACTIVE = config.squid_inactive_time

    def __init__(self, amqp_url, shoal):
        """Create a new instance of the consumer class, passing in the AMQP
        URL used to connect to RabbitMQ.

        :param str amqp_url: The AMQP url to connect with
        :param dict shoal: shared dictionary of squids, keyed by squid key

        """
        self.shoal = shoal
        self._connection = None
        self._channel = None
        self._closing = False
        self._consumer_tag = None
        self._url = amqp_url

    def connect(self):
        """Return a new pika.SelectConnection for the configured URL.

        `on_connection_open` is registered as the open callback.
        """
        return pika.SelectConnection(pika.URLParameters(self._url),
                                     self.on_connection_open,
                                     stop_ioloop_on_close=False)

    def close_connection(self):
        """Close the current connection."""
        self._connection.close()

    def add_on_connection_close_callback(self):
        """Register `on_connection_closed` as the connection close callback."""
        self._connection.add_on_close_callback(self.on_connection_closed)

    def on_connection_closed(self, connection, reply_code, reply_text):
        """Handle a closed connection.

        Stops the ioloop when the close was requested through stop();
        otherwise logs a warning and schedules reconnect() in 5 seconds.
        """
        self._channel = None
        if self._closing:
            self._connection.ioloop.stop()
        else:
            logger.warning('Connection closed, reopening in 5 seconds: (%s) %s',
                           reply_code, reply_text)
            self._connection.add_timeout(5, self.reconnect)

    def on_connection_open(self, unused_connection):
        """Connection is up: register the close callback, open a channel."""
        self.add_on_connection_close_callback()
        self.open_channel()

    def reconnect(self):
        """Stop the old ioloop and, unless shutting down, connect again."""
        # This is the old connection IOLoop instance, stop its ioloop
        self._connection.ioloop.stop()

        if not self._closing:

            # Create a new connection
            self._connection = self.connect()

            # There is now a new connection, needs a new ioloop to run
            self._connection.ioloop.start()

    def add_on_channel_close_callback(self):
        """Register `on_channel_closed` as the channel close callback."""
        self._channel.add_on_close_callback(self.on_channel_closed)

    def on_channel_closed(self, channel, reply_code, reply_text):
        """Log the channel closure and close the connection."""
        logger.warning('Channel was closed: (%s) %s', reply_code, reply_text)
        self._connection.close()

    def on_channel_open(self, channel):
        """Store the opened channel and continue with the exchange setup."""
        self._channel = channel
        self.add_on_channel_close_callback()
        self.setup_exchange(self.EXCHANGE)

    def setup_exchange(self, exchange_name):
        """Declare `exchange_name`; continues in on_exchange_declareok."""
        self._channel.exchange_declare(self.on_exchange_declareok,
                                       exchange_name,
                                       self.EXCHANGE_TYPE)

    def on_exchange_declareok(self, unused_frame):
        """Exchange is declared: go on to declare the queue."""
        self.setup_queue(self.QUEUE)

    def setup_queue(self, queue_name):
        """Declare `queue_name`; continues in on_queue_declareok."""
        self._channel.queue_declare(self.on_queue_declareok, queue_name)

    def on_queue_declareok(self, method_frame):
        """Queue is declared: bind it to the exchange with ROUTING_KEY."""
        self._channel.queue_bind(self.on_bindok, self.QUEUE,
                                 self.EXCHANGE, self.ROUTING_KEY)

    def add_on_cancel_callback(self):
        """Register `on_consumer_cancelled` as the channel cancel callback."""
        self._channel.add_on_cancel_callback(self.on_consumer_cancelled)

    def on_consumer_cancelled(self, method_frame):
        """Close the channel, if one is open, when the consumer is cancelled."""
        if self._channel:
            self._channel.close()

    def acknowledge_message(self, delivery_tag):
        """Acknowledge the message identified by `delivery_tag`."""
        self._channel.basic_ack(delivery_tag)

    def on_cancelok(self, unused_frame):
        """Cancellation was confirmed: close the channel."""
        self.close_channel()

    def stop_consuming(self):
        """Cancel this consumer if a channel is open."""
        if self._channel:
            self._channel.basic_cancel(self.on_cancelok, self._consumer_tag)

    def start_consuming(self):
        """Start consuming from QUEUE, delivering messages to on_message."""
        self.add_on_cancel_callback()
        self._consumer_tag = self._channel.basic_consume(self.on_message,
                                                         self.QUEUE)

    def on_bindok(self, unused_frame):
        """Queue is bound: start consuming."""
        self.start_consuming()

    def close_channel(self):
        """Close the current channel."""
        self._channel.close()

    def open_channel(self):
        """Open a channel; continues in on_channel_open."""
        self._connection.channel(on_open_callback=self.on_channel_open)

    def run(self):
        """Connect and run the ioloop.

        A failure to create the connection is logged and followed by
        ``sys.exit(1)``.
        """
        try:
            self._connection = self.connect()
        except Exception as e:
            logger.error("Unable to connect ot RabbitMQ Server. {0}".format(e))
            sys.exit(1)
        self._connection.ioloop.start()

    def stop(self):
        """Flag the consumer as closing and cancel consumption."""
        self._closing = True
        self.stop_consuming()
        # NOTE(review): the ioloop is started again here, presumably so the
        # pending cancel/close callbacks are processed -- confirm against the
        # pika asynchronous consumer example.
        self._connection.ioloop.start()

    def on_message(self, unused_channel, basic_deliver, properties, body):
        """Decode one message and update the shoal accordingly.

        Required JSON keys: ``uuid``, ``hostname``, ``timestamp``, ``load``
        and ``squid_port``. Optional keys: ``external_ip``, ``public_ip``
        and ``private_ip``. The message is acknowledged in every case,
        including when it is discarded.
        """
        curr = time()
        delivery_tag = basic_deliver.delivery_tag

        try:
            data = json.loads(body)
        except ValueError:
            logger.error("Message body could not be decoded. Message: {0}".format(body))
            self.acknowledge_message(delivery_tag)
            return

        # Valid JSON that is not an object (a list, a number, ...) cannot
        # carry the expected keys; discard it instead of letting a TypeError
        # escape from the callback.
        if not isinstance(data, dict):
            logger.error("Message received was not a JSON object, discarding...")
            self.acknowledge_message(delivery_tag)
            return

        try:
            key = data['uuid']
            hostname = data['hostname']
            time_sent = data['timestamp']
            load = data['load']
            squid_port = data['squid_port']
        except KeyError as e:
            logger.error("Message received was not the proper format (missing:{0}), discarding...".format(e))
            self.acknowledge_message(delivery_tag)
            return

        # Optional fields default to None when absent.
        external_ip = data.get('external_ip')
        public_ip = data.get('public_ip')
        private_ip = data.get('private_ip')

        if key in self.shoal:
            self.shoal[key].update(load)
        elif (curr - time_sent < self.INACTIVE) and (public_ip or private_ip):
            # Locate by the public address (text before the first ':') when
            # there is one, otherwise by the external address. The previous
            # `except IndexError` fallback could never run because
            # str.split() always returns at least one element, and a message
            # carrying only a private_ip crashed on None.split().
            lookup_ip = public_ip.split(':')[0] if public_ip else external_ip
            geo_data = geoip.get_geolocation(lookup_ip) if lookup_ip else None
            if not geo_data:
                logger.error("Unable to generate geo location data, discarding message")
            else:
                new_squid = SquidNode(key, hostname, squid_port, public_ip, private_ip, external_ip, load, geo_data, time_sent)
                self.shoal[key] = new_squid

        self.acknowledge_message(delivery_tag)

def main():
    """Run the server by constructing Application; exit quietly on Ctrl-C."""
    try:
        # The constructor runs the whole server, so the instance is not kept.
        Application()
    except KeyboardInterrupt:
        sys.exit()

# NOTE(review): called unconditionally, so importing this module also starts
# the server -- confirm before adding an `if __name__ == "__main__"` guard.
main()
