import json
import tempfile
import logging
import boto.ec2
import boto.ec2.cloudwatch
import MySQLdb
import random
import socket
import os
import yaml
import paramiko
from qds_ops.utils.fork import Fork
from qds_ops.utils.ssh_tunnel import SSHTunnel


class Knife:
    """Wrapper around the Chef ``knife`` CLI plus EC2/CloudWatch/MySQL helpers.

    All instances share one ``__dict__`` (the "Borg" recipe linked below), so
    the cached credentials and the temporary PEM file path are shared by every
    ``Knife()`` created in the process.
    """
    # http://code.activestate.com/recipes/66531-singleton-we-dont-need-no-stinkin-singleton-the-bo/
    __shared_state = {}
    # Path of the temporary private-key file written by get_pemfile(), or None.
    pemfile = None
    # Credentials loaded by get_credentials(), or None.
    creds = None
    # Environment the cached creds / pemfile belong to. None means "unknown"
    # (e.g. a caller assigned ``creds`` directly); such a cache is trusted.
    _creds_env = None
    _pemfile_env = None

    def __init__(self):
        self.__dict__ = self.__shared_state

    @staticmethod
    def _knife(*args):
        """Run ``knife`` with the given arguments and return its output."""
        return Fork.check_output(["knife"] + list(args), stderr=None)

    @staticmethod
    def _search_string(env, tier, extra_filters=None):
        """Build the node search query for an environment/tier.

        ``extra_filters`` is an optional iterable of additional query terms
        that are AND-ed onto the base query. ``None`` or an empty iterable
        adds nothing (an empty list used to produce a trailing " AND ").
        """
        search_string = "chef_environment:%s AND qubole_tier:%s" % (env, tier)
        if extra_filters:
            search_string = "%s AND %s" % (search_string, " AND ".join(extra_filters))
        return search_string

    @staticmethod
    def _deploy_lock_query(action, tier):
        """Return ``(sql, params)`` for a deploy-lock action, or None if unknown.

        The "web" tier uses the global ``rstore.deploy_statuses`` table; any
        other tier uses ``rstore.tier_deploy_statuses`` filtered by tier. The
        tier is passed as a bound parameter rather than interpolated into SQL.
        """
        is_web = tier == "web"
        if action == "check":
            if is_web:
                return ("select running from rstore.deploy_statuses", None)
            return ("select running from rstore.tier_deploy_statuses where tier = %s", (tier,))
        if action == "lock":
            if is_web:
                return ("update rstore.deploy_statuses set running = true where running = false", None)
            return ("update rstore.tier_deploy_statuses set running = true "
                    "where running = false and tier = %s", (tier,))
        if action == "unlock":
            if is_web:
                return ("update rstore.deploy_statuses set running = false where running = true", None)
            return ("update rstore.tier_deploy_statuses set running = false "
                    "where running = true and tier = %s", (tier,))
        return None

    def _pick_free_port(self, attempts=5):
        """Return a random local port in 33001-33999 that checkport() accepts.

        Raises RuntimeError when none of ``attempts`` candidates is free.
        """
        for _ in range(attempts):
            candidate = random.randrange(33001, 34000)
            if self.checkport(candidate):
                return candidate
        raise RuntimeError("Could not find a free port")

    def cleanup(self):
        """Delete the temporary PEM file, if one was created.

        The cached path is cleared even if the unlink fails, so calling this
        twice is safe and get_pemfile() will recreate the file on next use.
        """
        if self.pemfile is None:
            return
        try:
            os.unlink(self.pemfile)
        finally:
            self.pemfile = None
            self._pemfile_env = None

    def search(self, env, tier, extra_filters=None):
        """Return the EC2 public hostnames of nodes in ``env``/``tier``."""
        output = self._knife("search", "node", self._search_string(env, tier, extra_filters),
                             "-a", "ec2.public_hostname", "--format", "yaml")
        # NOTE(review): yaml.load can construct arbitrary objects; the input
        # here is knife's own output. Consider yaml.safe_load if the output
        # never carries custom tags.
        nodes = yaml.load(output)
        return [n['ec2.public_hostname'] for n in nodes[':rows']]

    def packages(self, env, tier):
        """Return per-node public hostname and ``default.qubole`` attributes.

        Runs a Ruby snippet through ``knife exec`` and parses its JSON output.
        """
        # NOTE(review): env and tier are interpolated into Ruby source; they
        # must not contain double quotes or "#{".
        ruby_code = """
nodeinfo = []; 
nodes.find(:chef_environment => "%s",:qubole_tier => "%s") { |n| 
  info={};
  info[:public_hostname] = n.ec2.public_hostname; 
  info[:packages]=n.default.qubole.to_hash;nodeinfo << info
}; 
puts nodeinfo.to_json""" % (env, tier)

        return json.loads(self._knife("exec", "-E", ruby_code))

    def roles(self):
        """Return the lines of ``knife role list`` output."""
        return self._knife("role", "list").split("\n")

    def envs(self):
        """Return the parsed JSON output of ``knife environment list``."""
        return json.loads(self._knife("environment", "list", "-f", "json"))

    def nodes(self, environment):
        """Return the parsed JSON output of ``knife node list`` for an environment."""
        return json.loads(self._knife("node", "list", "-E", environment, "-f", "json"))

    def getnodes(self, selector):
        """Return the parsed JSON result of a node search, limited to ``fqdn``."""
        # The original relied on adjacent-literal concatenation ("-a" "fqdn");
        # the flag and its value are now passed as separate arguments.
        return json.loads(self._knife("search", "node", selector, "-fj", "-a", "fqdn"))

    def describe(self, environment):
        """Return the parsed JSON output of ``knife environment show``."""
        return json.loads(self._knife("environment", "show", environment, "-f", "json"))

    def get_credentials(self, environment):
        """Return the decrypted ``creds/env_<environment>`` data bag item.

        The result is cached; a cache filled for a different environment is
        discarded instead of being returned for the wrong environment.
        """
        if self.creds is not None and self._creds_env in (None, environment):
            logging.debug("Obtained creds from cache")
            return self.creds
        output = self._knife("data", "bag", "show", "creds",
                             "env_%s" % (environment), "--secret-file", "/etc/chef/encrypted_data_bag_secret",
                             "-f", "yaml")
        # NOTE(review): see search() regarding yaml.load vs yaml.safe_load.
        self.creds = yaml.load(output)
        self._creds_env = environment
        logging.debug("Obtained creds from knife")
        return self.creds

    def get_pemfile(self, env):
        """Return the path of a temp file holding the environment's private key.

        The file is created once and reused; call cleanup() to delete it. A
        file written for a different environment is removed and replaced.
        """
        if self.pemfile is not None and self._pemfile_env in (None, env):
            logging.debug("Obtained pemfile from cache")
            return self.pemfile
        # Drop a key that belongs to another environment before replacing it.
        self.cleanup()
        creds = self.get_credentials(env)
        # Text mode so a str key can be written on both Python 2 and 3.
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as handle:
            handle.write(creds["private_key"])
        self.pemfile = handle.name
        self._pemfile_env = env
        logging.info("Pemfile @ %s" % self.pemfile)
        logging.debug("Obtained pemfile from knife")
        return self.pemfile

    def ssh(self, cmd, env, tier, extra_filters=None):
        """Run ``cmd`` via ``knife ssh`` on every matching node; return the output."""
        pem_path = self.get_pemfile(env)
        return self._knife("ssh", self._search_string(env, tier, extra_filters), cmd,
                           "-x", "ec2-user", "-i", pem_path, "-a", "ec2.public_hostname",
                           "--no-host-key-verify")

    def ssh_one(self, cmd, env, tier):
        """Run ``cmd`` over SSH on the first node of ``env``/``tier``.

        Returns the command's stdout. Raises IndexError when the search finds
        no node.
        """
        pem_path = self.get_pemfile(env)
        key = paramiko.RSAKey.from_private_key_file(pem_path)
        hosts = self.search(env, tier)
        if not hosts:
            raise IndexError("No nodes found for env %s tier %s" % (env, tier))
        logging.debug("Executing command on %s" % hosts[0])
        client = paramiko.SSHClient()
        try:
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(hosts[0], username="ec2-user", pkey=key)
            _stdin, stdout, _stderr = client.exec_command(cmd)
            # Read before the client is closed in the finally clause.
            return stdout.read()
        finally:
            client.close()

    def remove_node(self, node, cw_obj):
        """Delete a node and its client from Chef, then delete its alarms.

        ``cw_obj`` is a CloudWatch connection; every alarm whose name starts
        with ``node`` is deleted.
        """
        logging.info("Found dead node: " + str(node))
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(self._knife("node", "delete", node, "--yes"))
        print(self._knife("client", "delete", node, "--yes"))
        alarms = cw_obj.describe_alarms(alarm_name_prefix=node)
        logging.debug("Alarms for node: " + str(node) + " found = " + str(len(alarms)))
        for alarm in alarms:
            alarm.delete()

    def reap_instances(self, args):
        """Remove Chef nodes whose EC2 instance is gone or not running.

        ``args.environment`` selects the environment; "development" and
        "_default" are skipped. Nodes whose instance is shutting-down,
        terminated or stopped are removed, as are nodes for which EC2 returned
        no instance at all.
        """
        if args.environment in ("development", "_default"):
            return
        logging.info("Checking reap instances for env: " + str(args.environment))
        nodes = self.nodes(args.environment)
        if not nodes:
            # NOTE(review): an empty id list appears to make boto return every
            # instance in the region, so there is nothing safe to do here.
            return
        creds = self.get_credentials(args.environment)
        conn = boto.ec2.connect_to_region("us-east-1",
                                          aws_access_key_id=creds["access"],
                                          aws_secret_access_key=creds["secret"])

        cw_obj = boto.ec2.cloudwatch.connect_to_region('us-east-1',
                                                       aws_access_key_id=creds["access"],
                                                       aws_secret_access_key=creds["secret"])

        instances = conn.get_only_instances(nodes)
        seen = set()
        # Reap instances of terminated nodes
        for instance in instances:
            seen.add(instance.id)
            if instance.state in ("shutting-down", "terminated", "stopped"):
                self.remove_node(instance.id, cw_obj)
            else:
                logging.info("Node: " + str(instance.id) + " has state: " + str(instance.state))

        # Reap unknown instances (the original call omitted cw_obj here).
        for node in nodes:
            if node not in seen:
                self.remove_node(node, cw_obj)

    def deploy_lock(self, environment, tunnel, tier, action):
        """Check, take or release the deploy lock stored in the rstore database.

        ``action`` is "check", "lock" or "unlock". When ``tunnel`` is True the
        database is reached through an SSH tunnel via the first "web" node.
        Returns True on success; False when the lock could not be changed, the
        action is unknown, no status row exists, or a MySQL error occurred.
        Raises RuntimeError if no free local port is found for the tunnel and
        IndexError if no web node is available to tunnel through.
        """
        query = self._deploy_lock_query(action, tier)
        if query is None:
            logging.error("Unknown deploy-lock action: %s" % (action,))
            return False
        sql_query, params = query

        pemfile = self.get_pemfile(environment)
        creds = self.get_credentials(environment)
        hostname = creds["host"]
        port = 3306
        # Setup a tunnel. The comparison is kept as "== True" so that truthy
        # non-boolean values such as the string "false" do not enable it.
        ssh_tunnel = None
        if tunnel == True:  # noqa: E712
            web_nodes = self.search(environment, "web")
            if not web_nodes:
                raise IndexError("No web nodes found for env %s" % (environment,))
            port = self._pick_free_port()
            # Held in a local for the duration of the DB session.
            # NOTE(review): presumably SSHTunnel tears down when released — confirm.
            ssh_tunnel = SSHTunnel("ec2-user", web_nodes[0], hostname, pemfile, 22, 3306, port)
            hostname = "127.0.0.1"

        con = None
        # Connect to DB and run query
        success = False
        try:
            con = MySQLdb.connect(hostname, creds["username"], creds["password"], 'rstore', port)
            logging.debug("%s params=%r", sql_query, params)
            if action == "check":
                cur = con.cursor(MySQLdb.cursors.DictCursor)
                cur.execute(sql_query, params)
                row = cur.fetchone()
                logging.debug(str(row))
                if row is None:
                    logging.error("No deploy status row found for tier %s" % (tier,))
                else:
                    print("Unlocked" if row["running"] == 0 else "Locked")
                    success = True
            else:
                cur = con.cursor()
                cur.execute(sql_query, params)
                # The update only matches a row in the opposite state, so
                # rowcount tells whether the lock actually changed hands.
                success = cur.rowcount > 0
                if success:
                    print("Deploy-Lock %sed successfully" % action)
                else:
                    print("Deploy-Lock could not be %sed" % action)
                con.commit()
        except MySQLdb.Error as e:
            if len(e.args) >= 2:
                logging.error("Error %d: %s" % (e.args[0], e.args[1]))
            else:
                logging.error("Error: %s" % (e,))
        finally:
            if con:
                con.close()
        return success

    def checkport(self, port):
        """Return True when nothing accepts TCP connections on 127.0.0.1:port."""
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            result = s.connect_ex(('127.0.0.1', port))
        finally:
            s.close()

        return result != 0


