#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from optparse import OptionParser
from StringIO import StringIO

import cmd
import sys
import readline
import os
import re
import string
import time

try:
    import cql
except ImportError:
    sys.path.append(os.path.abspath(os.path.dirname(__file__)))
    import cql
from cql.cursor import _COUNT_DESCRIPTION, _VOID_DESCRIPTION

# Path of the readline history file; it is loaded when a Shell is created and
# written back when the shell exits.
HISTORY = os.path.join(os.path.expanduser('~'), '.cqlsh')
# Type names offered by tab completion for CREATE COLUMNFAMILY statements.
CQLTYPES = ("bytes", "ascii", "utf8", "timeuuid", "uuid", "long", "int", "int32")

# ANSI escape-sequence templates: each wraps the "%s" text in a bold colour
# code followed by a reset, and is applied as ``COLOUR % text``.
RED = "\033[1;31m%s\033[0m"
GREEN = "\033[1;32m%s\033[0m"
BLUE = "\033[1;34m%s\033[0m"
YELLOW = "\033[1;33m%s\033[0m"
CYAN = "\033[1;36m%s\033[0m"
MAGENTA = "\033[1;35m%s\033[0m"

def startswith(words, text):
    """Return, in order, the entries of *words* that begin with *text*.

    The comparison is case-sensitive; callers upper-case *text* themselves
    when matching against upper-case keywords.
    """
    matches = []
    for candidate in words:
        if candidate.startswith(text):
            matches.append(candidate)
    return matches

class Shell(cmd.Cmd):
    """Interactive CQL shell built on ``cmd.Cmd``.

    Lines that are not built-in commands are routed to ``default()``, which
    accumulates them into a statement buffer (handling ``/* ... */``
    comments and ``BEGIN BATCH`` / ``APPLY BATCH`` blocks), executes each
    completed statement on a ``cql`` cursor and prints the result.
    """
    default_prompt  = "cqlsh> "
    continue_prompt = "   ... "

    def __init__(self, hostname, port, color=False, username=None,
            password=None):
        """Connect to the server and initialise the shell state.

        hostname, port     -- passed to ``cql.connect``.
        color              -- when true, output is wrapped in ANSI colours.
        username, password -- optional credentials passed to ``cql.connect``.
        """
        cmd.Cmd.__init__(self)
        self.conn = cql.connect(hostname, port, user=username, password=password)
        self.cursor = self.conn.cursor()

        if os.path.exists(HISTORY):
            readline.read_history_file(HISTORY)

        # No prompt at all when input is piped in rather than typed.
        if sys.stdin.isatty():
            self.prompt = Shell.default_prompt
        else:
            self.prompt = ""

        self.statement = StringIO()   # text of the statement being assembled
        self.color = color
        self.in_comment = False       # inside an unterminated /* ... */
        self.in_batch = False         # between BEGIN BATCH and APPLY BATCH

    def reset_statement(self):
        """Discard the buffered statement and restore the primary prompt."""
        self.set_prompt(Shell.default_prompt)
        self.statement.truncate(0)

    def _take_statement(self):
        """Return the buffered statement text and reset the buffer."""
        try:
            return self.statement.getvalue()
        finally:
            self.reset_statement()

    def get_statement(self, line):
        """Feed one input line into the statement buffer.

        Returns the full statement text once it is complete (the line ends
        with ";" outside a batch, or the line is APPLY BATCH inside one);
        returns None while more input is needed or the line was swallowed
        by a block comment.
        """
        if self.in_comment:
            if "*/" not in line:
                return None
            # The block comment closes here; only the text after it counts.
            self.in_comment = False
            fragment = line[line.index("*/") + 2:]
            if not fragment.strip():
                # Drop back to the primary prompt only when no partial
                # statement is still waiting in the buffer.
                if not self.statement.getvalue().strip():
                    self.set_prompt(Shell.default_prompt)
                return None
            line = fragment

        if "/*" in line:
            # A block comment opens and is not closed on this line (one-line
            # comments are scrubbed by default() before this is called).
            self.in_comment = True
            self.set_prompt(Shell.continue_prompt)
            # Index into the line itself: an index taken from the lstripped
            # copy would cut the kept text short by the leading whitespace.
            before = line[:line.index("/*")]
            if before.strip():
                # Newline-terminated like every other buffered line, so the
                # text cannot fuse with whatever follows the comment.
                self.statement.write("%s\n" % before)
            return None

        words = line.upper().split()
        if words[:2] == ["BEGIN", "BATCH"]:
            self.set_prompt(Shell.continue_prompt)
            self.in_batch = True
            self.statement.write("%s\n" % line)
            return None

        # Tolerate a trailing ";" on the closing APPLY BATCH; without this
        # "APPLY BATCH;" would never terminate the batch.
        closing = line.upper().rstrip().rstrip(";").split()
        if self.in_batch and closing == ["APPLY", "BATCH"]:
            self.in_batch = False
            self.statement.write("%s\n" % line)
            return self._take_statement()

        self.statement.write("%s\n" % line)

        # Inside a batch ";" only separates the inner statements. Trailing
        # whitespace (e.g. left behind by a scrubbed "--" comment) must not
        # hide the terminator.
        if self.in_batch or not line.rstrip().endswith(";"):
            self.set_prompt(Shell.continue_prompt)
            return None

        return self._take_statement()

    def default(self, arg):
        """Handle a line that is not a shell command: buffer and execute CQL.

        One-line comments are stripped, the line is fed to get_statement(),
        and a completed statement is executed (retried up to three times on
        ``cql.IntegrityError``) and its result printed.
        """
        def scrub_oneline_comments(s):
            # Non-greedy, so code between two inline comments is preserved.
            res = re.sub(r'/\*.*?\*/', '', s)
            res = re.sub(r'--.*$', '', res)
            return res

        text = scrub_oneline_comments(arg)
        if not text.strip():
            return
        statement = self.get_statement(text)
        if not statement:
            return

        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                self.cursor.execute(statement)
                break
            except cql.IntegrityError as err:
                self.printerr("Attempt #%d: %s" % (attempt, str(err)))
                # Linear back-off; no point waiting after the final attempt.
                if attempt < max_attempts:
                    time.sleep(attempt)
        else:
            # Every attempt failed: the cursor holds no fresh result, so do
            # not print whatever an earlier statement left behind.
            return

        # The cursor marks these result kinds with sentinel descriptions
        # imported from cql.cursor, compared by identity.
        if self.cursor.description is _VOID_DESCRIPTION:
            return
        elif self.cursor.description is _COUNT_DESCRIPTION:
            self.print_count_result()
        else:
            self.print_result()

    def print_count_result(self):
        """Print the first element of the cursor result under a 'count' header."""
        if not self.cursor.result:
            return
        self.printout('count')
        self.printout('-----')
        self.printout(str(self.cursor.result[0]))
        self.printout("")

    def print_result(self):
        """Print the current result set as a table or as name/value pairs.

        The table layout is used when every row reports the same cursor
        description; otherwise each row is printed with its own names.
        """
        # First pass: see whether the column set is the same for all rows.
        static = True
        last_description = None
        for _row in self.cursor:
            description = self.cursor.description
            if last_description is not None and description != last_description:
                static = False
                break
            # Remember this row's description so the next row is compared
            # against it (without this the check could never fail).
            last_description = description
        self.cursor._reset()

        if static:
            self.print_static_result()
        else:
            self.print_dynamic_result()
        self.printout("")

    def print_static_result(self):
        """Print rows as a right-aligned table with one header line."""
        # First pass: each column is as wide as its widest name or value.
        widths = defaultdict(lambda: 0)
        for row in self.cursor:
            for desc, value in zip(self.cursor.description, row):
                name = desc[0]
                widths[name] = max(widths[name], len(str(name)), len(str(value)))
        self.cursor._reset()

        # Header line.
        for desc in self.cursor.description:
            name = desc[0]
            self.printout(" ", newline=False)
            self.printout(str(name).rjust(widths[name]), MAGENTA, False)
            self.printout(" |", newline=False)
        self.printout("")

        # Row data.
        for row in self.cursor:
            for desc, value in zip(self.cursor.description, row):
                self.printout(" ", newline=False)
                self.printout(str(value).rjust(widths[desc[0]]), YELLOW, False)
                self.printout(" |", newline=False)
            self.printout("")

    def print_dynamic_result(self):
        """Print each row as a line of ``name,value | `` pairs."""
        for row in self.cursor:
            self.printout(" ", newline=False)
            for desc, value in zip(self.cursor.description, row):
                name = desc[0]
                self.printout(str(name), MAGENTA, False)
                self.printout(",", newline=False)
                self.printout(str(value), YELLOW, False)
                self.printout(" | ", newline=False)
            self.printout("")

    def emptyline(self):
        """Do nothing on an empty line (cmd.Cmd would repeat the last command)."""
        pass

    def complete_select(self, text, line, begidx, endidx):
        """Complete keywords that can follow SELECT."""
        keywords = ('FIRST', 'REVERSED', 'FROM', 'WHERE', 'KEY')
        return startswith(keywords, text.upper())
    complete_SELECT = complete_select

    def complete_update(self, text, line, begidx, endidx):
        """Complete keywords that can follow UPDATE."""
        keywords = ('WHERE', 'KEY', 'SET')
        return startswith(keywords, text.upper())
    complete_UPDATE = complete_update

    def complete_create(self, text, line, begidx, endidx):
        """Complete CREATE: the object kind first, then its keywords/options."""
        words = line.split()
        if len(words) < 3:
            return startswith(['COLUMNFAMILY', 'KEYSPACE'], text.upper())

        common = ['WITH', 'AND']
        kind = words[1].upper()

        if kind == 'COLUMNFAMILY':
            types = startswith(CQLTYPES, text)
            keywords = startswith(('KEY', 'PRIMARY'), text.upper())
            props = startswith(("comparator",
                                "comment",
                                "row_cache_size",
                                "key_cache_size",
                                "read_repair_chance",
                                "gc_grace_seconds",
                                "default_validation",
                                "min_compaction_threshold",
                                "max_compaction_threshold",
                                "row_cache_save_period_in_seconds",
                                "key_cache_save_period_in_seconds",
                                "memtable_flush_after_mins",
                                "memtable_throughput_in_mb",
                                "memtable_operations_in_millions",
                                "replicate_on_write"), text)
            return startswith(common, text.upper()) + types + keywords + props

        if kind == 'KEYSPACE':
            props = ("replication_factor", "strategy_options", "strategy_class")
            return startswith(common, text.upper()) + startswith(props, text)

        # Unknown object kind: no candidates (always hand back a list).
        return []
    complete_CREATE = complete_create

    def complete_drop(self, text, line, begidx, endidx):
        """Complete the object kind after DROP; nothing beyond that."""
        words = line.split()
        if len(words) < 3:
            return startswith(['COLUMNFAMILY', 'KEYSPACE'], text.upper())
        return []
    complete_DROP = complete_drop

    def completenames(self, text, *ignored):
        """Complete the first word: built-in commands plus CQL verbs."""
        cmds = startswith(('USE', 'SELECT', 'UPDATE', 'DELETE', 'CREATE', 'DROP'),
                          text.upper())
        return cmd.Cmd.completenames(self, text, *ignored) + cmds

    def set_prompt(self, prompt):
        """Change the prompt, but only when reading from a terminal."""
        if sys.stdin.isatty():
            self.prompt = prompt

    # NOTE: the do_* methods below are documented with comments rather than
    # docstrings on purpose: cmd.Cmd uses do_* docstrings as help text, so
    # adding them would change the output of the "help" command.

    # End of input (Ctrl-D): finish the terminal line, then exit.
    def do_EOF(self, arg):
        if sys.stdin.isatty():
            self.printout("")
        self.do_exit(None)

    # Leave the shell by raising SystemExit.
    def do_exit(self, arg):
        sys.exit()
    do_quit = do_exit

    def printout(self, text, color=None, newline=True, out=sys.stdout):
        """Write *text* to *out*, coloured only if colour output is enabled.

        color   -- a "%s" template such as RED, or None for plain text.
        newline -- append a newline after the text when true.
        out     -- file-like object to write to.
        """
        if not color or not self.color:
            out.write(text)
        else:
            out.write(color % text)

        if newline:
            out.write("\n")

    def printerr(self, text, color=RED, newline=True):
        """Write *text* to standard error (red by default when colour is on)."""
        self.printout(text, color, newline, sys.stderr)

if __name__ == '__main__':
    # Positional arguments: an optional host, optionally followed by a port.
    parser = OptionParser(usage="Usage: %prog [host [port]]")
    parser.add_option("-C", "--color", action="store_true",
                      help="Enable color output.")
    parser.add_option("-u", "--username", help="Authenticate as user.")
    parser.add_option("-p", "--password", help="Authenticate using password.")
    options, arguments = parser.parse_args()

    # A missing (or empty) host argument selects the local machine.
    hostname = "localhost"
    if arguments and arguments[0]:
        hostname = arguments[0]

    port = 9160
    if len(arguments) > 1:
        try:
            port = int(arguments[1])
        except ValueError:
            sys.stderr.write("%s is not a valid port number\n" % arguments[1])
            parser.print_help(file=sys.stderr)
            sys.exit(1)

    shell = Shell(hostname, port,
                  color=options.color,
                  username=options.username,
                  password=options.password)

    # cmdloop() is re-entered after every handled error, so only an explicit
    # exit (SystemExit) ends the session; history is saved at that point.
    while True:
        try:
            shell.cmdloop()
        except SystemExit:
            readline.write_history_file(HISTORY)
            break
        except cql.Error as cqlerr:
            shell.printerr(str(cqlerr))
        except KeyboardInterrupt:
            # Ctrl-C abandons the statement being typed and starts a new line.
            shell.reset_statement()
            sys.stdout.write("\n")
        except Exception as err:
            shell.printerr("Exception: %s" % err)
