#! /usr/bin/env python




import code




# readline (and rlcompleter) are optional: HAVE_READLINE records whether both
# could be imported, so later code can skip history/tab-completion setup
# when they are unavailable.
try:
    import readline
    import rlcompleter
    HAVE_READLINE = True
except ImportError:
    HAVE_READLINE = False




from pytools import cartesian_product, Record
# Line styles for plotted series: one Record(dashes=..., color=...) per
# (dash pattern, color) pair yielded by pytools' cartesian_product over the
# five dash patterns and five colors listed below.
# NOTE(review): the order of the entries is whatever cartesian_product
# yields -- confirm which of the two lists varies fastest if that matters.
PLOT_STYLES = [
        Record(dashes=dashes, color=color)
        for dashes, color in cartesian_product(
            [(), (12, 2), (4, 2),  (2,2), (2,8) ],
            ["blue", "green", "red", "magenta", "cyan"],
            )]




class RunDB(object):
    """Wrapper around a database connection adding MagicSQL, plotting and
    table printing of query results.

    Only the ``execute`` method of the wrapped *db* object is used here.
    """

    def __init__(self, db, interactive):
        """
        :param db: connection-like object providing ``execute(sql)``.
        :param interactive: if true, :meth:`plot_cursor` calls pylab's
            ``show()`` after plotting.
        """
        self.db = db
        self.interactive = interactive

    def q(self, qry):
        """Expand *qry* via :meth:`magic_sql`, execute it and return
        whatever ``db.execute`` returns (the cursor).
        """
        return self.db.execute(self.magic_sql(qry))

    def magic_sql(self, qry):
        """Translate a MagicSQL query string into plain SQL.

        * A query that already contains the keyword ``FROM`` and no ``$$``
          placeholder is returned unchanged.
        * Every ``$name`` is rewritten to ``name.value`` and the table
          ``name`` is inner-joined to ``runs`` on ``run_id``; each table
          after the first is additionally joined on ``step`` with the
          table before it. Tables are joined in order of first
          appearance, so the generated SQL is deterministic.
        * The resulting ``from`` clause replaces ``$$`` if present.
          Otherwise it is inserted in front of the first of UNION,
          INTERSECT, EXCEPT, WHERE, GROUP, HAVING, ORDER, LIMIT or ``;``,
          or appended if none of them occurs.

        :param qry: the MagicSQL query string.
        :returns: the expanded SQL string.
        """
        import re

        # Match FROM as a whole word only, so that a quantity such as
        # "$from_start" does not switch off the expansion.
        has_from = re.search(r"\bFROM\b", qry, re.IGNORECASE) is not None
        if has_from and "$$" not in qry:
            return qry

        # A list (not a set) keeps first-appearance order.
        magic_columns = []

        def replace_magic_column(match):
            qty_name = match.group(1)
            if qty_name not in magic_columns:
                magic_columns.append(qty_name)
            return "%s.value" % qty_name

        magic_column_re = re.compile(r"\$([a-zA-Z][A-Za-z0-9_]*)")
        qry = magic_column_re.sub(replace_magic_column, qry)

        from_clause = "from runs "
        last_tbl = None
        for tbl in magic_columns:
            if last_tbl is not None:
                addendum = " and %s.step = %s.step" % (last_tbl, tbl)
            else:
                addendum = ""

            from_clause += " inner join %s on (%s.run_id = runs.id%s) " % (
                    tbl, tbl, addendum)
            last_tbl = tbl

        if "$$" in qry:
            return qry.replace("$$", " %s " % from_clause)

        other_clauses = ["UNION", "INTERSECT", "EXCEPT", "WHERE", "GROUP",
                "HAVING", "ORDER", "LIMIT"]

        # Search the query itself case-insensitively, so that the offsets
        # found are valid indices into qry.
        insert_at = len(qry)
        for clause in other_clauses:
            clause_match = re.search(r"\b%s\b" % clause, qry, re.IGNORECASE)
            if clause_match is not None:
                insert_at = min(insert_at, clause_match.start())

        # ";" is not a word character, so a \b-delimited pattern never
        # matches a terminating semicolon; look for it literally.
        semicolon_idx = qry.find(";")
        if semicolon_idx != -1:
            insert_at = min(insert_at, semicolon_idx)

        head = qry[:insert_at]
        tail = qry[insert_at:]
        # Keep the clause separated from the text in front of it.
        if not tail or not head[-1:].isspace():
            from_clause = " " + from_clause
        return head + from_clause + tail

    def plot_cursor(self, cursor, **kwargs):
        """Plot the rows of *cursor* using pylab.

        Two-column results are plotted as a single (x, y) line; if no
        plot keyword arguments are given, the first entry of PLOT_STYLES
        is used. Results with more than two columns are split into one
        line per run of consecutive rows sharing the same values in
        columns 3 and up; these lines cycle through PLOT_STYLES and are
        labelled ``column:value ...``.

        :param cursor: executed cursor; ``cursor.description`` gives the
            column count and the label names.
        :param kwargs: passed on to pylab's ``plot``. ``small_legend``
            (default True) is consumed here and selects the small-font
            legend for multi-column results.
        :raises ValueError: if the result has fewer than two columns.
        """
        from pylab import plot, show, legend

        # Consumed for every result shape so it never reaches plot().
        small_legend = kwargs.pop("small_legend", True)

        column_count = len(cursor.description)
        if column_count < 2:
            raise ValueError("invalid number of columns")

        # NOTE(review): `hold` (plot) and `pad`/`labelsep` (legend) are
        # keyword arguments of older matplotlib releases -- confirm against
        # the matplotlib version in use.
        if column_count == 2:
            if not kwargs:
                style = PLOT_STYLES[0]
                kwargs["dashes"] = style.dashes
                kwargs["color"] = style.color

            rows = list(cursor)
            # An empty result cannot be unpacked into x and y; plot nothing.
            if rows:
                x, y = zip(*rows)
                plot(x, y, hold=True, **kwargs)
        else:
            from itertools import groupby

            label_names = [column[0] for column in cursor.description[2:]]
            rows = (tuple(row) for row in cursor)
            # groupby starts a new series whenever the trailing columns
            # change between consecutive rows.
            series_iter = groupby(rows, key=lambda row: row[2:])
            for series_idx, (rest, series) in enumerate(series_iter):
                points = list(series)
                style = PLOT_STYLES[series_idx % len(PLOT_STYLES)]
                kwargs["dashes"] = style.dashes
                kwargs["color"] = style.color
                kwargs["label"] = " ".join(
                        "%s:%s" % (name, value)
                        for name, value in zip(label_names, rest))
                plot([point[0] for point in points],
                        [point[1] for point in points],
                        hold=True, **kwargs)

            if small_legend:
                from matplotlib.font_manager import FontProperties
                legend(pad=0.04, prop=FontProperties(size=8), loc="best",
                        labelsep=0)

        if self.interactive:
            show()

    def print_cursor(self, cursor):
        """Print the rows of *cursor* as a pytools ``Table`` whose first
        row holds the column names.
        """
        from pytools import Table
        tbl = Table()
        tbl.add_row([column[0] for column in cursor.description])
        for row in cursor:
            tbl.add_row(row)
        # Single-argument call form: valid as Python 2 statement and
        # Python 3 function call alike.
        print(tbl)




class RunalyzerConsole(code.InteractiveConsole,RunDB):
    """Interactive Python console over a run database.

    Lines starting with "." are handled as magic commands (see
    :meth:`execute_magic`); everything else is ordinary Python evaluated
    in a namespace that provides ``db``, ``magic_sql``, ``q``, ``dbplot``
    and ``dbprint``.
    """

    def __init__(self, db):
        """
        :param db: database connection; exposed to the console as ``db``.
        """
        RunDB.__init__(self, db, interactive=True)

        symbols = {
                "__name__": "__console__",
                "__doc__": None,
                "db": db,
                "magic_sql": self.magic_sql,
                "q": self.q,
                "dbplot": self.plot_cursor,
                "dbprint": self.print_cursor,
                }
        code.InteractiveConsole.__init__(self, symbols)

        # Pull the pylab names into the console only if pylab and
        # matplotlib can be imported.
        try:
            import pylab
            import matplotlib
        except ImportError:
            pass
        else:
            self.runsource("from pylab import *")

        if HAVE_READLINE:
            import os
            import atexit

            # expanduser also works when HOME is not set in the environment,
            # where os.environ["HOME"] would raise KeyError.
            histfile = os.path.join(
                    os.path.expanduser("~"), ".runalyzerhist")
            if os.access(histfile, os.R_OK):
                readline.read_history_file(histfile)
            atexit.register(readline.write_history_file, histfile)
            readline.parse_and_bind("tab: complete")

        self.last_push_result = False

    def push(self, cmdline):
        """Process one input line.

        Magic commands are executed directly; an error raised by one is
        printed as a traceback rather than propagated. Other lines go to
        ``code.InteractiveConsole.push``.

        :returns: the result of the most recent non-magic push, so a
            magic command leaves the console's continuation state alone.
        """
        if cmdline.startswith("."):
            try:
                self.execute_magic(cmdline)
            except Exception:
                # Not a bare except: SystemExit and KeyboardInterrupt must
                # still propagate out of the console.
                import traceback
                traceback.print_exc()
        else:
            self.last_push_result = code.InteractiveConsole.push(self, cmdline)

        return self.last_push_result

    def execute_magic(self, cmdline):
        """Run the magic command in *cmdline*.

        :param cmdline: input line starting with "."; the text up to the
            first space is the command name, the remainder its argument.
        """
        cmd, _, args = cmdline[1:].partition(" ")

        if cmd == "help":
            print("""
Commands:
 .help        show this help message
 .q SQL       execute MagicSQL query
 .runprops    show a list of run properties
 .quantities  show a list of time-dependent quantities

Plotting:
 .plot SQL    plot results of MagicSQL query
              result sets can be (x,y) or (x,y,descr1,descr2,...),
              in which case a new plot will be started for each
              tuple (descr1, descr2, ...)
 .title TEXT  set the title of the current plot

MagicSQL:
    select $quantity where pred(feature)

Custom SQLite aggregates:
    stddev, var, norm1, norm2

Available Python functions:
    db: the SQLite database
    magic_sql(query_str): get MagicSQL query for query_str
    q(query_str): get db cursor for MagicSQL query_str
    dbplot(cursor): same as .plot, but for cursors
    dbprint(cursor): print result of cursor
""")
        elif cmd == "q":
            self.print_cursor(self.q(args))
        elif cmd == "runprops":
            # The run properties are the column names of the runs table.
            cursor = self.db.execute("select * from runs")
            for col in sorted(column[0] for column in cursor.description):
                print(col)
        elif cmd == "quantities":
            self.print_cursor(self.q("select * from quantities order by name"))
        elif cmd == "title":
            from pylab import title
            title(args)
        elif cmd == "plot":
            self.plot_cursor(self.q(args))
        else:
            print("invalid magic command")




# custom aggregates -----------------------------------------------------------
from pytools import VarianceAggregator
class Variance(VarianceAggregator):
    """Aggregate registered with SQLite as ``var``.

    A pytools ``VarianceAggregator`` that is always constructed with
    ``entire_pop=True``.
    NOTE(review): presumably this selects the population (rather than
    sample) variance -- confirm against pytools.
    """

    def __init__(self):
        VarianceAggregator.__init__(self, entire_pop=True)

class StdDeviation(Variance):
    """Aggregate yielding the square root of what Variance yields."""

    def finalize(self):
        """Return sqrt of the variance, or None if the variance is None."""
        from math import sqrt

        variance = Variance.finalize(self)
        return None if variance is None else sqrt(variance)

class Norm1(object):
    """Aggregate computing the 1-norm: the sum of absolute values."""

    def __init__(self):
        self.abs_sum = 0

    def step(self, value):
        """Add ``abs(value)`` to the running sum.

        SQL NULLs arrive as None and are skipped rather than raising
        TypeError from abs().
        """
        if value is not None:
            self.abs_sum += abs(value)

    def finalize(self):
        """Return the accumulated sum (0 if no value was seen)."""
        return self.abs_sum

class Norm2(object):
    """Aggregate computing the 2-norm: sqrt of the sum of squares."""

    def __init__(self):
        self.square_sum = 0

    def step(self, value):
        """Add ``value**2`` to the running sum.

        SQL NULLs arrive as None and are skipped rather than raising
        TypeError from the exponentiation.
        """
        if value is not None:
            self.square_sum += value**2

    def finalize(self):
        """Return the square root of the accumulated sum of squares."""
        from math import sqrt
        return sqrt(self.square_sum)

def my_sprintf(format, arg):
    """Apply the %-style template *format* to the single value *arg*
    and return the result.
    """
    formatted = format % arg
    return formatted




# main program ----------------------------------------------------------------
def main():
    """Command-line entry point: ``DBFILE [SCRIPT]``.

    Opens the SQLite database DBFILE and registers the custom aggregates
    (stddev, var, norm1, norm2) and the ``sprintf`` function. With SCRIPT,
    the script is executed with ``db`` bound to a non-interactive RunDB;
    without it, an interactive RunalyzerConsole is started.

    Prints the usage message and exits with status 1 unless one or two
    positional arguments are given.
    """
    import sys
    from optparse import OptionParser

    parser = OptionParser(usage="%prog DBFILE [SCRIPT]")
    options, args = parser.parse_args()

    if len(args) not in [1, 2]:
        parser.print_help()
        sys.exit(1)

    import sqlite3
    connection = sqlite3.connect(args[0])
    try:
        connection.create_aggregate("stddev", 1, StdDeviation)
        connection.create_aggregate("var", 1, Variance)
        connection.create_aggregate("norm1", 1, Norm1)
        connection.create_aggregate("norm2", 1, Norm2)

        connection.create_function("sprintf", 2, my_sprintf)

        if len(args) == 2:
            script_path = args[1]
            script_file = open(script_path)
            try:
                script_source = script_file.read()
            finally:
                script_file.close()

            symbols = {"db": RunDB(connection, interactive=False)}
            # compile() + exec() instead of the Python-2-only execfile();
            # this call form is accepted by Python 2 and 3 alike.
            # The script is trusted user code given on the command line.
            exec(compile(script_source, script_path, "exec"), symbols)
        else:
            cons = RunalyzerConsole(connection)
            cons.interact("Runalyzer running on Python %s\n"
                    "Copyright (c) Andreas Kloeckner 2008\n"
                    "Run .help to see help for magic commands" % sys.version)
    finally:
        # Release the database handle even if the script or console raises.
        connection.close()




# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
