import json
from subprocess import check_call
import random
import socket
import logging

from qds_ops.tparty.knife import Knife
from qds_ops.utils.ssh_tunnel import SSHTunnel
import MySQLdb


class RDS(object):
    """CLI sub-commands for inspecting and connecting to an environment's RDS MySQL servers.

    Credentials and hostnames come from ``Knife.get_credentials``. Optional SSH
    tunnels are opened through the first node returned by
    ``Knife.search(environment, "web")``.
    """

    # Port passed to the mysql client / MySQLdb for a direct (untunnelled) connection,
    # and used as the remote port of the SSH tunnel.
    _MYSQL_PORT = 3306

    def __init__(self):
        pass

    def configure_parsers(self, subparser):
        """Register the ``rds`` command and its ``hostnames``/``connect``/``tunnel`` sub-commands."""
        parser = subparser.add_parser("rds",
                                      help="Commands to monitor and control RDS Mysql servers")
        rds_subparser = parser.add_subparsers()

        # hostnames
        hn_parser = rds_subparser.add_parser("hostnames",
                                             help="List hostnames of the primary Mysql server and read replica if any")
        hn_parser.set_defaults(func=self.hostnames)

        # connect
        cn_parser = rds_subparser.add_parser("connect",
                                             help="Connect to the primary or replica Mysql server")
        cn_parser.add_argument("-a", "--autocomplete", dest="autocomplete",
                               action="store_true",
                               help="enable tab-completion of table and column names (with slower startup)")
        cn_parser.add_argument("-r", "--replica", dest="replica",
                               action="store_true", help="Login into read replica")
        cn_parser.add_argument("-t", "--tunnel", dest="tunnel",
                               action="store_true", help="Use a tunnel through a web node")
        cn_parser.set_defaults(func=self.connect)

        # tunnel (same as "connect --tunnel")
        tn_parser = rds_subparser.add_parser("tunnel",
                                             help="Connect to the primary or replica Mysql server through a tunnel on a web node")
        tn_parser.add_argument("-a", "--autocomplete", dest="autocomplete",
                               action="store_true",
                               help="enable tab-completion of table and column names (with slower startup)")
        tn_parser.add_argument("-r", "--replica", dest="replica",
                               action="store_true", help="Login into read replica")
        tn_parser.set_defaults(func=self.tunnel)

    def hostnames(self, args):
        """Print the primary and read-replica hostnames of ``args.environment`` as JSON."""
        creds = Knife().get_credentials(args.environment)
        output = {"primary": creds["host"], "read_replica": creds["rr_host"]}
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(json.dumps(output, indent=4, sort_keys=True))

    @staticmethod
    def checkport(port):
        """Return True if nothing accepts TCP connections on 127.0.0.1:``port``.

        A failed connect (non-zero ``connect_ex`` result) is taken to mean the
        port is free to bind. This is only a hint: another process may take the
        port between this check and the tunnel binding it.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            return sock.connect_ex(('127.0.0.1', port)) != 0
        finally:
            sock.close()

    @classmethod
    def _pick_free_port(cls, attempts=5):
        """Return a random local port in 33001-33999 that ``checkport`` reports as free.

        Raises RuntimeError if none of ``attempts`` random candidates is free.
        """
        for _ in range(attempts):
            port = random.randrange(1, 1000) + 33000
            if cls.checkport(port):
                return port
        raise RuntimeError("Could not find a free port")

    @classmethod
    def _open_tunnel(cls, knife, environment, remote_host):
        """Open an SSHTunnel to ``remote_host`` through the first web node of ``environment``.

        Returns ``(tunnel, local_port)``; the caller connects to
        127.0.0.1:``local_port`` and should keep ``tunnel`` referenced while the
        connection is in use. NOTE(review): SSHTunnel's lifetime/teardown
        semantics are defined in qds_ops.utils.ssh_tunnel — confirm there.
        Raises RuntimeError if no web node is found or no free port is available.
        """
        web_nodes = knife.search(environment, "web")
        if not web_nodes:
            raise RuntimeError("No web nodes found to tunnel through in %s" % environment)
        pemfile = knife.get_pemfile(environment)
        local_port = cls._pick_free_port()
        # Positional arguments are unchanged from the original call sites
        # (22 presumably being the SSH port on the web node — confirm with SSHTunnel).
        tunnel = SSHTunnel("ec2-user", web_nodes[0], remote_host, pemfile,
                           22, cls._MYSQL_PORT, local_port)
        return tunnel, local_port

    def connect(self, args):
        """Run the interactive ``mysql`` client against the primary or read replica.

        Reads ``args.environment``, ``args.replica`` (use ``rr_host`` instead of
        ``host``), ``args.tunnel`` (connect through an SSH tunnel on a web node)
        and ``args.autocomplete`` (omit ``-A`` so mysql loads table/column names).
        Raises subprocess.CalledProcessError if mysql exits non-zero.
        """
        knife = Knife()
        creds = knife.get_credentials(args.environment)
        hostname = creds["rr_host"] if args.replica else creds["host"]
        port = self._MYSQL_PORT
        tunnel = None  # held until mysql exits so the tunnel object stays alive

        if args.tunnel:
            tunnel, port = self._open_tunnel(knife, args.environment, hostname)
            hostname = "127.0.0.1"

        # NOTE(review): the password on argv is visible in process listings;
        # consider an option file or the MYSQL_PWD environment variable.
        password_arg = "-p%s" % creds["password"]
        mysql_cmd = ["mysql", "-u", creds["username"],
                     password_arg, "-h", hostname, "-P", str(port), "--prompt",
                     # "\_" is mysql's prompt escape for a space; spelled out explicitly.
                     "%s>\\_" % args.environment]
        if not args.autocomplete:
            mysql_cmd.append("-A")

        # Never write the password to the log.
        logging.debug(" ".join("-p****" if arg == password_arg else arg
                               for arg in mysql_cmd))

        check_call(mysql_cmd)

    def tunnel(self, args):
        """Same as ``connect`` but always through an SSH tunnel on a web node."""
        args.tunnel = True
        self.connect(args)

    @staticmethod
    def run_query(env, tunnel, sql_query):
        """Execute ``sql_query`` on the primary's ``rstore`` database and return the first row.

        :param env: environment name passed to Knife.
        :param tunnel: if truthy, connect through an SSH tunnel on a web node.
        :param sql_query: SQL text executed verbatim — callers must not
            interpolate untrusted input into it.
        :returns: the first row from a ``MySQLdb.cursors.DictCursor`` (None if
            there are no rows), or None when a ``MySQLdb.Error`` occurs (the
            error is logged, not raised).
        """
        knife = Knife()
        creds = knife.get_credentials(env)
        hostname = creds["host"]
        port = RDS._MYSQL_PORT
        ssh_tunnel = None  # held until the query finishes so the tunnel object stays alive
        if tunnel:
            ssh_tunnel, port = RDS._open_tunnel(knife, env, hostname)
            hostname = "127.0.0.1"

        con = None
        try:
            con = MySQLdb.connect(host=hostname, user=creds["username"],
                                  passwd=creds["password"], db='rstore', port=port)
            cur = con.cursor(MySQLdb.cursors.DictCursor)
            cur.execute(sql_query)
            return cur.fetchone()
        except MySQLdb.Error as e:
            # MySQLdb errors normally carry (errno, message); don't assume it.
            if len(e.args) >= 2:
                logging.error("Error %s: %s", e.args[0], e.args[1])
            else:
                logging.error("Error: %s", e)
        finally:
            if con is not None:
                con.close()
