#!/usr/bin/env python
import sys
import os
import web
import json
import logging

import pika

from time import time, sleep
from threading import Thread

from shoal_server import config as config
from shoal_server import geoip as geoip
from shoal_server import urls as urls


# Column-aligned record layout: level, timestamp, logger name, function,
# line number and finally the message itself.
LOG_FORMAT = (
    '%(levelname) -10s %(asctime)s '
    '%(name) -30s %(funcName) -35s '
    '%(lineno) -5d: %(message)s'
)
# The returned logger is discarded; the call only ensures that a logger
# named 'shoal' exists.
logging.getLogger('shoal')
# Configure the root logger: ERROR and above, using the layout above.
logging.basicConfig(format=LOG_FORMAT, level=logging.ERROR)

"""
    Basic class to store and update information about each squid server.
"""
class SquidNode(object):
    """Record describing one squid server and when it was last heard from.

    Attributes:
        key: identifier under which the node is stored.
        created: time (seconds since the epoch) at which this record was built.
        last_active: time of the most recent report from the squid.
        hostname, public_ip, private_ip, external_ip, load, geo_data:
            values supplied by the caller, stored unchanged.
    """

    def __init__(self, key, hostname, public_ip, private_ip, external_ip, load, geo_data, last_active=None):
        """Store the supplied values on the new record.

        :param last_active: timestamp of the last report; when omitted (None)
            the creation time of this record is used.
        """
        self.key = key
        self.created = time()
        # Resolve the default per call. A `time()` default in the signature is
        # evaluated once, when the class is defined, and would stamp every
        # node with the module's import time.
        self.last_active = self.created if last_active is None else last_active
        self.hostname = hostname

        self.public_ip = public_ip
        self.private_ip = private_ip
        self.external_ip = external_ip
        self.load = load
        self.geo_data = geo_data

    def update(self, public_ip, private_ip, load, geo_data):
        """Refresh the mutable fields and mark the node as active now.

        `hostname`, `external_ip`, `key` and `created` are left untouched.
        """
        self.last_active = time()

        self.public_ip = public_ip
        self.private_ip = private_ip
        self.load = load
        self.geo_data = geo_data

"""
    Main application that will delegate threads.
"""
class Application(object):
    """Top-level object that starts the worker threads and supervises them.

    Constructing an instance loads the configuration, starts three daemon
    threads (RabbitMQ consumer, web.py server, inactive-squid cleaner) and
    then blocks in a supervision loop until a thread dies or the user
    interrupts the process.
    """

    def __init__(self):
        # setup configuration settings.
        config.setup()
        # change working directory so webpy static files load correctly.
        try:
            os.chdir(config.shoal_dir)
        except OSError:
            logging.error("{0} doesn't seem to exist. Please set `shoal_dir` in shoal-server config file to the location of the shoal-server static files. (Perhaps its ~/shoal_server/?)".format(config.shoal_dir))
            sys.exit(1)

        # Dictionary of tracked squids, shared by all three workers.
        self.shoal = {}
        self.threads = []

        # check if geolitecity database needs updating
        if geoip.check_geolitecity_need_update():
            geoip.download_geolitecity()

        rabbitmq_thread = Thread(target=self.rabbitmq, name='RabbitMQ')
        rabbitmq_thread.daemon = True
        self.threads.append(rabbitmq_thread)

        webpy_thread = Thread(target=self.webpy, name='Webpy')
        webpy_thread.daemon = True
        self.threads.append(webpy_thread)

        update_thread = Thread(target=self.update, name="ShoalUpdate")
        update_thread.daemon = True
        self.threads.append(update_thread)

        for thread in self.threads:
            thread.start()

        # Supervise: if any worker thread exits, shut the whole server down.
        try:
            while True:
                for thread in self.threads:
                    if not thread.is_alive():
                        logging.error('{0} died.'.format(thread))
                        self.stop()
                sleep(1)
        except KeyboardInterrupt:
            self.stop()

    # NOTE: each of the three thread targets below rebinds its own name on the
    # instance to the worker object it creates, which is what `stop()` later
    # calls `.stop()` on.

    def rabbitmq(self):
        """Thread target: create the RabbitMQ consumer and run it."""
        url = config.amqp_server_url
        self.rabbitmq = RabbitMQConsumer(url, self.shoal)
        self.rabbitmq.run()

    def webpy(self):
        """Thread target: create the web server and run it."""
        self.webpy = WebpyServer(self.shoal)
        self.webpy.run()

    def update(self):
        """Thread target: create the inactive-squid cleaner and run it."""
        self.update = ShoalUpdate(self.shoal)
        self.update.run()

    def stop(self):
        """Ask every worker to stop, wait briefly, then exit the process.

        Each worker is stopped independently so that a failure in one (for
        example a worker whose thread never got as far as creating its
        object) does not prevent the others from being stopped.
        """
        # Single-argument print(...) behaves the same with or without
        # print_function.
        print("Shutting down Shoal-Server... Please wait.")
        try:
            workers = (
                (self.webpy, "Web Server stopped."),
                (self.rabbitmq, "RabbitMQ consumer stopped."),
                (self.update, None),
            )
            for worker, confirmation in workers:
                try:
                    worker.stop()
                    if confirmation is not None:
                        print(confirmation)
                except Exception as e:
                    logging.error(e)
        finally:
            # give them time to properly exit.
            sleep(2)
            sys.exit()

"""
    ShoalUpdate is used for trimming inactive squids every set interval.
"""
class ShoalUpdate(object):
    """Periodically remove squids that have not reported recently.

    `interval` (seconds between sweeps) and `inactive` (maximum allowed
    silence in seconds) are read from the configuration on construction.
    """

    def __init__(self, shoal):
        """:param shoal: dictionary of tracked squids, pruned in place."""
        self.shoal = shoal
        self.interval = config.squid_cleanse_interval
        self.inactive = config.squid_inactive_time
        self.running = False

    def run(self):
        """Sleep for `interval` seconds and sweep, until `stop()` is called."""
        self.running = True
        while self.running:
            sleep(self.interval)
            self.update()

    def update(self):
        """Drop every squid whose `last_active` is older than `inactive`."""
        curr = time()
        # Iterate over a snapshot: entries are removed during the sweep, and
        # removing from a dict while iterating its live view raises
        # RuntimeError on Python 3.
        for squid in list(self.shoal.values()):
            if curr - squid.last_active > self.inactive:
                # Tolerate an entry that has already been removed.
                self.shoal.pop(squid.key, None)

    def stop(self):
        """Make `run()` return after its current sleep/sweep cycle."""
        self.running = False

"""
    Webpy webserver used to serve up active squid lists and API calls. For now we just use the development webpy server to serve requests.
"""
class WebpyServer(object):
    """Wrapper around a web.py application serving the shoal URLs.

    The shared `shoal` dictionary is published as `web.shoal` so the URL
    handler classes named in `urls` can reach it.
    """

    def __init__(self, shoal):
        """:param shoal: dictionary of tracked squids to expose as `web.shoal`."""
        web.shoal = shoal
        web.config.debug = False
        # Created in run(); stays None until then.
        self.app = None
        # (pattern, handler) pairs. Raw strings keep the regex `\d` from being
        # read as a string escape; the values are unchanged.
        self.urls = (
            r'/nearest/?(\d+)?/?', 'shoal_server.urls.nearest',
            r'/wpad.dat', 'shoal_server.urls.wpad',
            r'/external', 'shoal_server.urls.external_ip',
            r'/(\d+)?/?', 'shoal_server.urls.index',
        )

    def run(self):
        """Build the web.py application and serve requests.

        Any exception while building or running is logged and followed by
        `sys.exit(1)`.
        """
        try:
            self.app = web.application(self.urls, globals())
            self.app.run()
        except Exception as e:
            logging.error("Could not start webpy server.\n{0}".format(e))
            sys.exit(1)

    def stop(self):
        """Stop the web.py application if `run()` has created one."""
        # Nothing to stop when run() has not been reached yet.
        if self.app is not None:
            self.app.stop()

"""
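    Basic RabbitMQ consumer. Consumes messages from `config.amqp_server_queue`, parses the JSON in the message body, and stores the result in the dictionary `shoal`.
    Messages received must be a JSON string containing the keys `uuid`, `hostname`, `load` and `timestamp` (a number), or they will be discarded; `public_ip`, `private_ip` and `external_ip` are optional.
    {
      'uuid': '1231232',
      'hostname': 'squid.example.org',
      'public_ip': '142.11.52.1',
      'private_ip': '192.168.0.1',
      'load': '12324432',
      'timestamp': 2121231313,
    }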
"""
class RabbitMQConsumer(object):
    """Callback-driven RabbitMQ consumer that feeds the `shoal` dictionary.

    The connection is set up as a chain of callbacks: connection open ->
    channel open -> exchange declared -> queue declared -> queue bound ->
    consuming. Each received message is handled by `on_message`.
    """

    QUEUE = config.amqp_server_queue
    EXCHANGE = config.amqp_exchange
    EXCHANGE_TYPE = config.amqp_exchange_type
    ROUTING_KEY = '#'

    def __init__(self, amqp_url, shoal):
        """Create a new instance of the consumer class, passing in the AMQP
        URL used to connect to RabbitMQ.

        :param str amqp_url: The AMQP url to connect with
        :param dict shoal: dictionary of tracked squids, keyed by uuid

        """
        self.shoal = shoal
        self._connection = None
        self._channel = None
        self._closing = False
        self._consumer_tag = None
        self._url = amqp_url

    def connect(self):
        """Return a new `pika.SelectConnection` for the configured URL."""
        return pika.SelectConnection(pika.URLParameters(self._url),
                                     self.on_connection_open,
                                     stop_ioloop_on_close=False)

    def close_connection(self):
        """Close the current connection."""
        self._connection.close()

    def add_on_connection_close_callback(self):
        """Register `on_connection_closed` on the current connection."""
        self._connection.add_on_close_callback(self.on_connection_closed)

    def on_connection_closed(self, connection, reply_code, reply_text):
        """Stop the ioloop if shutting down, otherwise reconnect in 5 seconds."""
        self._channel = None
        if self._closing:
            self._connection.ioloop.stop()
        else:
            logging.warning('Connection closed, reopening in 5 seconds: (%s) %s',
                           reply_code, reply_text)
            self._connection.add_timeout(5, self.reconnect)

    def on_connection_open(self, unused_connection):
        """Connection is up: hook the close callback and open a channel."""
        self.add_on_connection_close_callback()
        self.open_channel()

    def reconnect(self):
        """Stop the old ioloop and, unless shutting down, connect again."""
        # This is the old connection IOLoop instance, stop its ioloop
        self._connection.ioloop.stop()

        if not self._closing:

            # Create a new connection
            self._connection = self.connect()

            # There is now a new connection, needs a new ioloop to run
            self._connection.ioloop.start()

    def add_on_channel_close_callback(self):
        """Register `on_channel_closed` on the current channel."""
        self._channel.add_on_close_callback(self.on_channel_closed)

    def on_channel_closed(self, channel, reply_code, reply_text):
        """Log the closure and close the connection as well."""
        logging.warning('Channel was closed: (%s) %s', reply_code, reply_text)
        self._connection.close()

    def on_channel_open(self, channel):
        """Channel is up: remember it and declare the exchange."""
        self._channel = channel
        self.add_on_channel_close_callback()
        self.setup_exchange(self.EXCHANGE)

    def setup_exchange(self, exchange_name):
        """Declare `exchange_name` with the configured exchange type."""
        self._channel.exchange_declare(self.on_exchange_declareok,
                                       exchange_name,
                                       self.EXCHANGE_TYPE)

    def on_exchange_declareok(self, unused_frame):
        """Exchange declared: declare the queue next."""
        self.setup_queue(self.QUEUE)

    def setup_queue(self, queue_name):
        """Declare `queue_name`."""
        self._channel.queue_declare(self.on_queue_declareok, queue_name)

    def on_queue_declareok(self, method_frame):
        """Queue declared: bind it to the exchange with `ROUTING_KEY`."""
        self._channel.queue_bind(self.on_bindok, self.QUEUE,
                                 self.EXCHANGE, self.ROUTING_KEY)

    def add_on_cancel_callback(self):
        """Register `on_consumer_cancelled` on the current channel."""
        self._channel.add_on_cancel_callback(self.on_consumer_cancelled)

    def on_consumer_cancelled(self, method_frame):
        """Consumer was cancelled: close the channel if one is open."""
        if self._channel:
            self._channel.close()

    def acknowledge_message(self, delivery_tag):
        """Acknowledge the message identified by `delivery_tag`."""
        self._channel.basic_ack(delivery_tag)

    def on_cancelok(self, unused_frame):
        """Cancel confirmed: close the channel."""
        self.close_channel()

    def stop_consuming(self):
        """Cancel the consumer if a channel is open."""
        if self._channel:
            self._channel.basic_cancel(self.on_cancelok, self._consumer_tag)

    def start_consuming(self):
        """Begin delivering messages from `QUEUE` to `on_message`."""
        self.add_on_cancel_callback()
        self._consumer_tag = self._channel.basic_consume(self.on_message,
                                                         self.QUEUE)

    def on_bindok(self, unused_frame):
        """Queue bound: start consuming."""
        self.start_consuming()

    def close_channel(self):
        """Close the current channel."""
        self._channel.close()

    def open_channel(self):
        """Open a channel; `on_channel_open` is called when it is ready."""
        self._connection.channel(on_open_callback=self.on_channel_open)

    def run(self):
        """Connect and run the connection's ioloop."""
        self._connection = self.connect()
        self._connection.ioloop.start()

    def stop(self):
        """Cancel the consumer and run the ioloop so shutdown can complete."""
        self._closing = True
        self.stop_consuming()
        # Nothing to drive when run() never created a connection.
        if self._connection is not None:
            self._connection.ioloop.start()

    def on_message(self, unused_channel, basic_deliver, properties, body):
        """Handle one message: update a known squid or register a new one.

        The body must be JSON with the keys `uuid`, `hostname`, `timestamp`
        and `load`; `external_ip`, `public_ip` and `private_ip` are optional.
        Malformed messages are logged and discarded. The message is always
        acknowledged, whether or not it was usable.
        """
        try:
            squid_inactive_time = config.squid_inactive_time
            curr = time()
            data = json.loads(body)

            key = data['uuid']
            hostname = data['hostname']
            last_active = data['timestamp']
            load = data['load']

            # Optional fields default to None when absent.
            external_ip = data.get('external_ip')
            public_ip = data.get('public_ip')
            private_ip = data.get('private_ip')

            # Prefer the public address (anything after a ':' is dropped);
            # fall back to the external address if that lookup fails for any
            # reason, including a missing public_ip.
            try:
                geo_data = geoip.get_geolocation(public_ip.split(':')[0])
            except Exception:
                geo_data = geoip.get_geolocation(external_ip)

            if not geo_data:
                logging.error("Unable to generate geo location data, discarding message")
            else:
                if key in self.shoal:
                    self.shoal[key].update(public_ip, private_ip, load, geo_data)
                elif curr - last_active < squid_inactive_time:
                    # Only register a new squid if its report is still fresh.
                    new_squid = SquidNode(key, hostname, public_ip, private_ip, external_ip, load, geo_data, last_active)
                    self.shoal[key] = new_squid

        except KeyError as e:
            # `basic_deliver` carries the delivery information for this message.
            logging.error("Message received was not the proper format (missing:{0}), discarding...\nmethod_frame:{1}\nproperties:{2}\nbody:{3}\n".format(e,basic_deliver,properties,body))

        except (ValueError, TypeError) as e:
            # Invalid JSON, or fields of the wrong type (e.g. a non-numeric
            # timestamp): discard rather than let the error escape the callback.
            logging.error("Message received could not be processed ({0}), discarding...\nbody:{1}\n".format(e, body))

        finally:
            self.acknowledge_message(basic_deliver.delivery_tag)

def main():
    """Start the server by constructing `Application`.

    The `Application` constructor starts the worker threads and then stays in
    its supervision loop, so this call does not return in normal operation.
    """
    Application()

# Called unconditionally: importing this module also starts the server.
main()
