"""
Module with common code for loading NetLogger logs into a database.
"""
# The docstring must precede the __future__ import to become the
# module's __doc__ (only docstrings/comments may come before it).
from __future__ import with_statement
__rcsid__ = "$Id: loader.py 1039 2008-09-15 22:44:44Z dang $"
__author__ = 'Dan Gunter'

from base64 import b64decode
import hashlib
import logging
from logging import DEBUG
import os
import re
import sys
import threading
import time
import warnings
#
from netlogger import nlapi
from netlogger import nllog
from netlogger.parsers.base import NLFastParser
from netlogger.analysis import schema

## DB imports
#
# Each driver name below is bound to None when its module is unavailable.

# SQLite
sqlite = None
try:
    # Python 2.5+
    import sqlite3 as sqlite
except ImportError:
    try:
        # Python 2.4
        from pysqlite2 import dbapi2 as sqlite
    except ImportError:
        pass

# PostgreSQL
pgdb = None
try:
    import psycopg2 as pgdb
except ImportError:
    try:
        import pgdb
    except ImportError:
        pass

# MySQL
mysql = None
# Keyword for MySQLdb.connect() naming an option file. Defined
# unconditionally because DB.__init__ refers to it on MySQL code paths.
READ_DEFAULT_FILE = 'read_default_file'
try:
    import MySQLdb as mysql
except ImportError:
    pass
if mysql is not None:
    # Turn "Table ... already exists" warnings into errors and silence
    # "Converting column" warnings. PEP 249 requires the driver module
    # to export its Warning class, so no private module is needed.
    warnings.filterwarnings("error", "Table.*already exists", mysql.Warning)
    warnings.filterwarnings("ignore", "Converting column", mysql.Warning)

DB_MODULES = {
    'test' : True, # special
    'sqlite' : sqlite,
    'postgres' : pgdb,
    'mysql' : mysql }

# Reverse map, driver module -> back-end name. The pseudo-entry 'test'
# and unavailable (None) drivers are left out, so None never maps to an
# arbitrary back-end name.
DB_MODULE_NAMES = { }
for key, value in DB_MODULES.items():
    if key == 'test' or value is None:
        continue
    DB_MODULE_NAMES[value] = key

# Names of the usable back ends (a real list on every Python version).
AVAIL_DB = [name for name in DB_MODULES.keys() if DB_MODULES[name]]

# Get directory of this file

def fileDir():
    """Return the absolute path of the directory holding this module."""
    module_path = os.path.abspath(__file__)
    head, _tail = os.path.split(module_path)
    return head

# Logging
# Placeholder logger (an nllog.NullLogger) used until activateLogging()
log = nllog.NullLogger()
def activateLogging(name=__name__):
    """Replace the module-wide placeholder logger with the logger that
    nllog.getLogger() returns for 'name' (this module's name by default).
    """
    global log
    real_logger = nllog.getLogger(name)
    log = real_logger

# Event attribute holding plain text; Loader.load() stores its value in
# the 'text' table instead of the 'attr' table.
PLAIN_TEXT_ATTR = "text"
# Event attribute holding Base64-encoded text; Loader.load() decodes it
# and stores the result in the 'text' table instead of the 'attr' table.
B64_TEXT_ATTR = "text64"

# Global stop/start flag: set by stop(), cleared by start(), read by
# is_stopped().
g_stop = False
def stop():
    """Set the module-wide stop flag."""
    global g_stop
    g_stop = True

def start():
    """Clear the module-wide stop flag, undoing a previous stop()."""
    global g_stop
    g_stop = False

def is_stopped():
    """Return True if stop() was called and not followed by start()."""
    return g_stop

#
# Exceptions
#

class ProcessingError(Exception):
    """Signal a failure while processing input for the loader.

    Nothing in this module raises it; presumably it is raised by code
    that drives the loader (confirm against users of this module).
    """

#
# Functions
#

def eventHash(d):
    """Return an MD5 hex digest of event dictionary 'd' that does not
    depend on key order.

    Keys are visited in sorted order. Each key is fed to the hash, then
    its value: floats rendered with exactly 9 decimal places, anything
    else with str().

    NOTE(review): keys and values are concatenated without separators,
    so e.g. {'ab': 'c'} and {'a': 'bc'} hash identically. Changing the
    encoding would change every previously stored hash, so it is kept.
    """
    m5 = hashlib.md5()

    def feed(text):
        # Python 2 str is already bytes; Python 3 text must be encoded
        # first (hashlib rejects str there with TypeError).
        try:
            m5.update(text)
        except TypeError:
            m5.update(text.encode('utf-8'))

    for k in sorted(d):
        feed(k)
        v = d[k]
        if isinstance(v, float):
            # guarantee 9 digits
            feed("%.9lf" % v)
        else:
            feed(str(v))
    return m5.hexdigest()

#
# Classes
#

class Counter:
    """Wrap an integer so that it only supports a single
    atomic 'get and increment' operation.
    """
    def __init__(self, value):
        self.__v = value
        self._lock = threading.Lock()

    def incr(self):
        """Return the current value and advance it by one, atomically."""
        # 'with' guarantees the lock is released even if an exception
        # (e.g. KeyboardInterrupt) arrives while it is held.
        with self._lock:
            v = self.__v
            self.__v = v + 1
        return v

class BaseDB:
    """Abstract base class for databases.

    Holds the DB-API module and SQL text helpers; subclasses implement
    connection handling, execution and flushing.
    """
    def __init__(self, db, quote_escape=None):
        """Remember DB-API module 'db' and pick the escape string used by
        fixQuotes(). Without an explicit 'quote_escape', sqlite escapes a
        single quote by doubling it; every other module uses a backslash.
        """
        self.db = db
        # figure out quoting style
        if quote_escape:
            self.sq_esc = quote_escape
        elif db is sqlite:
            self.sq_esc = "'"
        else:
            self.sq_esc = "\\"

    def execute(self, stmt, *args):
        """Execute a statement; no-op here, overridden by subclasses."""
        pass

    def close(self):
        """Release resources; no-op here, overridden by subclasses."""
        pass

    def checkFlush(self):
        """Flush pending work if due; no-op here."""
        pass

    def getPlaceholders(self, attrs):
        """Return a comma-separated string of DB-API parameter placeholders,
        one per name in 'attrs', in the style named by the driver's
        'paramstyle' (e.g. '?, ?' for qmark, ':1, :2' for numeric).
        Using placeholders lets the driver do the quoting, which protects
        against SQL injection.

        If the first attribute is '*', return '*'. An empty 'attrs' gives
        ''. A paramstyle not defined by PEP 249 yields no placeholders.
        """
        if not attrs:
            return ''
        if attrs[0] == '*':
            return '*'
        style = self.db.paramstyle
        params = [ ]
        for i, a in enumerate(attrs):
            if style == 'qmark':
                params.append('?')
            elif style == 'numeric':
                # parenthesised: ':%d' % i+1 would add 1 to a string
                params.append(':%d' % (i + 1))
            elif style == 'named':
                params.append(':%s' % a)
            elif style == 'format':
                params.append('%s')
            elif style == 'pyformat':
                params.append('%%(%s)s' % a)
        return ', '.join(params)

    def fixQuotes(self, s, sq="'"):
        """Return 's' with every unescaped quote character 'sq' escaped.

        The escape string is self.sq_esc. A quote preceded by an odd
        number of escape characters is already escaped and left alone;
        otherwise an escape is inserted before it. If 's' has at least two
        characters and both begins and ends with 'sq', those enclosing
        quotes are kept and only the text between them is processed.
        'sq' must be a single character; open/close pairs such as braces
        are not supported.
        """
        esc = self.sq_esc
        if len(s) >= 2 and s[0] == sq and s[-1] == sq:
            prefix, body, suffix = sq, s[1:-1], sq
        else:
            prefix, body, suffix = '', s, ''
        out = [ ]
        # number of consecutive escape characters just before 'ch';
        # a quote or any other character ends the run
        run = 0
        for ch in body:
            if ch == sq:
                # checked first, so when esc == sq (sqlite style) every
                # quote is doubled
                if run % 2 == 0:
                    out.append(esc)
                run = 0
            elif ch == esc:
                run += 1
            else:
                run = 0
            out.append(ch)
        return prefix + ''.join(out) + suffix

    def getMaxNameLen(self):
        """Subclasses will override this to
        return the maximum length of a name column (None here).
        """
        return None

    def getMaxValueLen(self):
        """Subclasses will override this to
        return the maximum length of a value column (None here).
        """
        return None

class BatchedInsertDB(BaseDB):
    """Database with ability to insert many values at once.

    Rows given to insert() are buffered per table. flush() writes them
    out; it runs when 'batch' rows are pending, or from checkFlush() once
    more than FLUSH_INTERVAL seconds have passed since the last
    time-based flush.

    If using MySQL, put multiple rows into a single statement, e.g.:
    mysql> INSERT INTO tbl_name (a,b,c) VALUES(1,2,3),(4,5,6);

    Otherwise, use one insert statement per tuple, e.g.:
        sqlite> INSERT INTO tbl_name (a,b,c) VALUES(1,2,3);
        sqlite> INSERT INTO tbl_name (a,b,c) VALUES(4,5,6);

    Subclasses must set 'self.conn' (a DB-API connection) before the
    first execute().
    """
    # Flush interval, in seconds, when no data is coming in
    FLUSH_INTERVAL = 1.0

    def __init__(self, db, batch=50, **kw):
        BaseDB.__init__(self, db, **kw)
        # never batch fewer than 3 rows
        self._sz = max(batch, 3)
        # table name -> (list of field names, list of value sequences)
        self._batches = { }
        self._last_flush = time.time()
        self._pending = 0 # pending number of inserts
        # created lazily by execute() from self.conn
        self._cursor = None

    def execute(self, stmt, args=()):
        """Execute 'stmt' with arguments 'args'.

        Return 0 for success or -1 for an integrity error (which is
        logged, not raised).
        """
        if self._cursor is None:
            self._cursor = self.conn.cursor()
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            log.trace("execute.start", stmt=stmt, args=args,
                      cursor=self._cursor)
        integrity = 0
        try:
            self._cursor.execute(stmt, args)
        except self.db.IntegrityError:
            # sys.exc_info() keeps this valid on every Python version
            err = sys.exc_info()[1]
            integrity = -1
            log.warn("execute.integrity.error", msg=str(err),
                     statement=stmt, stmt_args=args)
        if debug:
            log.trace("execute.end", cursor=self._cursor,
                      status=integrity)
        return integrity

    def checkFlush(self):
        """Check if current batches should be flushed out.
        """
        t = time.time()
        dt = t - self._last_flush
        if dt > self.FLUSH_INTERVAL:
            self.flush()
            self._last_flush = t

    def insert(self, tbl, fields, values):
        """Buffer one row for table 'tbl' and flush once enough rows are
        pending. The field list of the first row buffered for a table is
        used for the whole batch; later rows are assumed to follow it.
        """
        log.debug("BatchedInsert.insert.start", table=tbl,
                  num__fields=len(fields))
        batch = self._batches.get(tbl)
        if batch is None:
            self._batches[tbl] = (list(fields), [values])
        else:
            batch[1].append(values)
        # Update number of pending inserts and flush when the maximum is
        # reached ('>=' so that a missed reset cannot stop flushing)
        self._pending += 1
        if self._pending >= self._sz:
            self.flush()
        log.debug("BatchedInsert.insert.end", table=tbl,
                  num__fields=len(fields), status=0)

    def flush(self):
        """Write out every buffered row, the 'event' table first."""
        log.debug("BatchedInsert.flush.start",
                  num__tables=len(self._batches))
        if 'event' in self._batches:
            # always do this one first
            self._flushTable('event')
        for tbl in list(self._batches.keys()):
            if tbl != 'event':
                self._flushTable(tbl)
        self._pending = 0
        log.debug("BatchedInsert.flush.end", status=0)

    def _flushTable(self, tbl):
        """Flush a single table and reset its cached batch
        of values to be empty
        """
        fields, values_list = self._batches[tbl]
        if values_list:
            self._execBatch(tbl, fields, values_list)
            self._batches[tbl] = (fields, [])

    def _formatRow(self, values):
        """Render one row as an SQL literal tuple such as "(1,2.5,'x')".

        The format is chosen per value, so rows whose value types differ
        from the batch's first row are still rendered correctly.
        """
        parts = [ ]
        for v in values:
            code = self._getFormatCode(v)
            # wrap in a 1-tuple so tuple values are not unpacked by '%'
            literal = code % (v,)
            if '%s' in code:
                literal = self.fixQuotes(literal)
            parts.append(literal)
        return '(' + ','.join(parts) + ')'

    def _execBatch(self, tbl, fields, values_list):
        """Insert all rows of 'values_list' (sequences ordered like
        'fields') into table 'tbl' using literal SQL."""
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            log.debug("BatchedInsert._execBatch.start", table=tbl)
        stmt = "INSERT INTO %s " % tbl
        stmt += " (" + ','.join(fields) + ") VALUES "
        value_tuples = [self._formatRow(values) for values in values_list]
        if debug:
            log.debug("BatchedInsert._execBatch.execute.start",
                      num__values=len(value_tuples))
        # currently only MySQL does multiple-row inserts
        if mysql is not None and self.db == mysql:
            # optimization: all in one statement
            if self.execute(stmt + ','.join(value_tuples)) == -1:
                # Some row failed but we can't know which one, so retry
                # every row individually (including the first). Rows the
                # failed statement already stored just report another
                # integrity error here.
                num_integrity_errors = 0 # for logging
                for vt in value_tuples:
                    if self.execute(stmt + vt) == -1:
                        num_integrity_errors += 1
                if num_integrity_errors > 0:
                    log.warn("BatchedInsert.integrityErrors",
                             n=num_integrity_errors, table=tbl)
        else:
            # default behavior: multiple statements
            for vt in value_tuples:
                self.execute(stmt + vt)
        if debug:
            log.debug("BatchedInsert._execBatch.execute.end", status=0)
            log.debug("BatchedInsert._execBatch.end", table=tbl, status=0)

    def _getFormatCode(self, v):
        """Return the %-format used to render 'v' as an SQL literal:
        integers (including bool) with %d, floats with %lf, anything else
        as a quoted string."""
        if isinstance(v, int):
            code = "%d"
        elif isinstance(v, float):
            code = "%lf"
        else:
            # default to quoted string
            code = "'%s'"
        return code

class TestDB(BatchedInsertDB):
    """Database used for testing: statements are written as SQL text to
    the file-like 'output' instead of being sent to a server.
    """
    def __init__(self, output=sys.stdout, batch=50, **kw):
        self.output = output
        BatchedInsertDB.__init__(self, None, batch=batch, **kw)

    def execute(self, stmt, args=()):
        """Write 'stmt' with each '?' filled from 'args', followed by
        ';' and a newline. Returns 0 (success), as BatchedInsertDB does.
        """
        stmt = stmt.replace('?', '%s')
        self.output.write(stmt % args)
        self.output.write(';\n')
        return 0

    def flush(self):
        """Flush the output stream."""
        self.output.flush()

    def cursor(self):
        """Act as its own cursor."""
        return self

    def fetchone(self):
        return [ ]

    def fetchmany(self, howmany):
        return [ ]

    def _quote(self, s):
        """Render 's' as text: None -> '', strings in single quotes,
        floats with %lf, anything else with %d."""
        if s is None:
            return ''
        if hasattr(s, 'lower'):
            return "'" + s + "'"
        if isinstance(s, float):
            return "%lf" % s
        return "%d" % s

    def insert(self, tbl, fields, values):
        """Write one INSERT statement immediately (no batching)."""
        field_str = ','.join(map(self._quote, fields))
        value_str = ','.join(map(self.fixQuotes, map(self._quote, values)))
        self.execute("INSERT INTO ? (?) VALUES(?)",
                     (tbl, field_str, value_str))

    def getMaxNameLen(self):
        return 255

    def getMaxValueLen(self):
        return 255

    def close(self):
        """Close the output stream. The process's standard streams (the
        default output is sys.stdout) are only flushed, never closed."""
        std_streams = (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__)
        if any(self.output is stream for stream in std_streams):
            self.output.flush()
        else:
            self.output.close()

class DB(BatchedInsertDB):
    """Main database class.
    Can represent a MySQL, PostgreSQL, or sqlite database.
    """

    def __init__(self, db_module=sqlite, dsn=None, batch=100, create=0,
                 unique=True,
                 schema_file=None, schema_init_keys=None,
                 schema_finalize_keys=None, conn_kw=None, **kw):
        """Connect to DB and init relevant data structures.

        db_module - DB-API module (a value of DB_MODULES)
        dsn - sqlite: database file; PostgreSQL: host; otherwise passed
              as the first positional argument of the driver's connect()
        create - 0=no, 1=yes, 2=yes, after dropping (bit value 2 = drop)
        conn_kw - extra connect() keywords; copied, never modified
        Raises ValueError if no database name is known for a non-sqlite
        back end, and RuntimeError if the schema file cannot be opened or
        connecting fails with the driver's Error.
        """
        BatchedInsertDB.__init__(self, db_module, batch=batch, **kw)
        # Private copy: entries are removed/added here and in _connect(),
        # so neither the caller's dict nor a shared default is modified.
        conn_kw = dict(conn_kw or {})
        if schema_init_keys is None:
            schema_init_keys = []
        self.dsn = dsn
        self._dbstmt = None
        self.database = None
        for k in ('db', 'database'):
            if k in conn_kw:
                self.database = conn_kw.pop(k)
                break
        # guard against None == None when MySQLdb is not installed
        is_mysql = mysql is not None and db_module is mysql
        # if the read_default_file parameter is not given, look
        # in default location for it
        if is_mysql and READ_DEFAULT_FILE not in conn_kw:
            default_path = os.path.expanduser("~/.my.cnf")
            if os.path.isfile(default_path):
                conn_kw[READ_DEFAULT_FILE] = default_path
        # if no db param and MySQL, also look in ~/.my.cnf or
        # whichever file the user gave with read_default_file
        if self.database is None and is_mysql and \
                READ_DEFAULT_FILE in conn_kw:
            import ConfigParser
            p = ConfigParser.SafeConfigParser()
            p.read(conn_kw[READ_DEFAULT_FILE])
            if p.has_option('client', 'database'):
                self.database = p.get('client', 'database')
        if self.database is None and db_module is not sqlite:
            if is_mysql:
                my_cnf = conn_kw.get(READ_DEFAULT_FILE)
                if my_cnf is None:
                    raise ValueError("database not specified and "
                                     "no MySQL configuration file")
                else:
                    raise ValueError("database not specified and "
                                     "not found in '%s'" % my_cnf)
            else:
                raise ValueError("database not specified")
        self.conn_kw = conn_kw
        self.conn = None
        # Init the schema-statement object
        schema_file = self._findSchema(schema_file)
        dbtype = DB_MODULE_NAMES[self.db]
        self._dbstmt = schema.DBStatements(schema_file, type=dbtype)
        # Connect to DB and initialize it
        try:
            self._connect(create, unique, schema_init_keys)
        except self.db.Error:
            E = sys.exc_info()[1]
            raise RuntimeError("while connecting to '%s': %s" % (self.dsn, E))
        self._cursor = None

    def _findSchema(self, path):
        """Return the schema configuration file: 'path' if given, else
        'schema.conf' in this module's directory.
        Raise a RuntimeError if it can't be opened.
        """
        if not path:
            path = os.path.join(fileDir(), "schema.conf")
        try:
            f = open(path, 'r')
        except (IOError, OSError):
            # Python 2's open() reports failures as IOError
            E = sys.exc_info()[1]
            raise RuntimeError("Cannot open schema file '%s' (%s)" % (
                    path, E))
        f.close()
        return path

    def _pg2ConnArg(self):
        """Build a "key='value' ..." connection string from conn_kw."""
        kwlist = ["%s='%s'" % (k,v) for k,v in self.conn_kw.items()]
        return ' '.join(kwlist)

    def _pg2Isolate(self):
        self._isolvl = self.conn.isolation_level
        self.conn.set_isolation_level(0)

    def _pg2Deisolate(self):
        self.conn.set_isolation_level(self._isolvl)

    def _isolate(self):
        """PostgreSQL only: save the isolation level and set it to 0."""
        if self.db == pgdb:
            self._pg2Isolate()

    def _deisolate(self):
        """PostgreSQL only: restore the level saved by _isolate()."""
        if self.db == pgdb:
            self._pg2Deisolate()

    def _connect(self, create, unique, init_keys):
        """Connect and select the database, creating it first if asked.

        'create' is the flag described in __init__: non-zero creates the
        database (non-sqlite) and its tables; bit value 2 first drops the
        database (sqlite: deletes the file). 'unique' is not used here.
        'init_keys' is passed to the schema's init().
        """
        # XXX: This is too long and should be refactored
        if self.db == pgdb: # psyco-db is a bit crazy..
            if self.database:
                database = self.database
            else:
                database = "NO_SUCH_DB" # should fail
            if self.dsn:
                self.conn_kw['host'] = self.dsn
            if create:
                # connect to 'dummy' database
                self.conn_kw['dbname'] = 'postgres'
            else:
                self.conn_kw['dbname'] = database
            self.conn = self.db.connect(self._pg2ConnArg())
        else:
            if self.db == sqlite:
                self.conn_kw['isolation_level'] = None
                if create & 2:
                    # drop database by removing file
                    log.info("drop.database.sqlite.start", file=self.dsn)
                    fexist = 1
                    try:
                        os.unlink(self.dsn)
                    except OSError:
                        fexist = 0
                    log.info("drop.database.sqlite.end", file=self.dsn, status=0, removed=fexist)
                # let's not do this:
                # create = 1 # always try to 'create' sqlite DB
            log.info("DB.connect.start", conn__host=self.dsn,
                     param=self.conn_kw)
            if self.dsn is None:
                self.conn = self.db.connect(**self.conn_kw)
            else:
                self.conn = self.db.connect(self.dsn, **self.conn_kw)
            log.info("DB.connect.end", conn__host=self.dsn,  status=0)
        # use database, optionally creating and even dropping it first
        c = self.conn.cursor()
        if self.db is not sqlite:
            if create > 0:
                # drop database
                if create & 2: # bit indicating 'drop'
                    with nllog.logged(log, "drop.database"):
                        # if DB doesn't exist, log warning and continue
                        try:
                            self._isolate()
                            try:
                                if self.db is pgdb:
                                    cmd = "drop database if exists %s" % self.database
                                else:
                                    cmd = "drop database %s" % self.database
                                c.execute(cmd)
                            finally:
                                # restore the isolation level even on failure
                                self._deisolate()
                        except self.db.Error:
                            E = sys.exc_info()[1]
                            log.warn("drop.database.end",
                                     dbname=self.database, status=-2,
                                     msg="cannot drop database: %s" % E)
                # create database
                with nllog.logged(log, "create.database",
                                  dbname=self.database):
                    self._isolate()
                    try:
                        c.execute("create database %s" % self.database)
                    finally:
                        self._deisolate()
                    if self.db is pgdb:
                        # re-connect to database we just created
                        self.conn.close()
                        self.conn_kw['dbname'] = database
                        self.conn = self.db.connect(self._pg2ConnArg())
                        c = self.conn.cursor()
            # use database
            with nllog.logged(log, "use.database"):
                if self.db is pgdb:
                    # reconnect, closing the previous connection rather than
                    # leaking it, and take the cursor from the new one
                    self.conn_kw['dbname'] = self.database
                    self.conn.close()
                    self.conn = self.db.connect(self._pg2ConnArg())
                    c = self.conn.cursor()
                else:
                    c.execute("use %s" % self.database)
        if create:
            with nllog.logged(log, "init.tables", dbname=self.database):
                self._dbstmt.setExecuteFunc(c.execute)
                self._dbstmt.init(init_keys)
                if self.db is pgdb:
                    c.execute("commit;")
        c.close()

    def cursor(self):
        return self.conn.cursor()


    def close(self):
        """Flush buffered rows and run the schema's finalize step.
        The connection itself is left open (Loaders made by one
        LoaderFactory share the same object).
        """
        log.debug("DB.close.start")
        self.flush()
        with nllog.logged(log, "finalize.tables", dbname=self.database):
            keywords = () # XXX: add_index, compress..
            c = self.conn.cursor()
            self._dbstmt.setExecuteFunc(c.execute)
            self._dbstmt.finalize(keywords)
            c.close()
        log.debug("DB.close.end", status=0)

    def getMaxNameLen(self):
        """Return maximum length for 'name' columns
        """
        return self._dbstmt.getConstant(self._dbstmt.NAME_MAX)

    def getMaxValueLen(self):
        """Return maximum length for string value columns
        """
        # NOTE(review): uses NAME_MAX, same as getMaxNameLen(); confirm
        # whether schema.DBStatements defines a separate value limit.
        return self._dbstmt.getConstant(self._dbstmt.NAME_MAX)

class LoaderFactory:
    """Create Loader objects that all share one database object 'conn'
    and one event-id Counter, seeded from the largest id already stored
    in the 'event' table.
    """
    def __init__(self, conn):
        self.conn = conn
        first_id = self._getNextEventID()
        self.eid = Counter(first_id)

    def new(self):
        """Return a new Loader bound to the shared connection and
        event-id counter.
        """
        return Loader(self.conn, self.eid)

    def _getNextEventID(self):
        """Return MAX(id) + 1 from the 'event' table, or 1 when the
        query returns no row or a NULL maximum.
        """
        cursor = self.conn.cursor()
        cursor.execute("select MAX(id) from event")
        row = cursor.fetchone()
        if row and row[0] is not None:
            return row[0] + 1
        return 1
        
class Loader:
    """Load NetLogger events into a database.

    Each event dictionary becomes one row of the 'event' table plus rows
    in the 'dn', 'text', 'ident' and 'attr' tables, all written through
    conn.insert().
    """
    debug = False

    # Keys load() never copies into the 'attr' table: they go into the
    # 'event', 'dn' or 'text' tables instead. Built once, not per event.
    _EXCLUDED_ATTRS = frozenset(('ts', 'event', 'level', 'dn', 'DN',
                                 B64_TEXT_ATTR, PLAIN_TEXT_ATTR))

    def __init__(self, conn, eid, **kw):
        """'conn' receives the rows (it must provide insert(), flush(),
        checkFlush(), close(), getMaxNameLen() and getMaxValueLen(), as DB
        does); 'eid' supplies event ids through incr(). Extra keywords
        become instance attributes.
        """
        self.conn = conn
        self._nmax = conn.getMaxNameLen()
        self._vmax = conn.getMaxValueLen()
        for n, v in kw.items():
            setattr(self, n, v)
        self._N = 0
        self._eid = eid
        self._calc_hash = True

    def load(self, d):
        """Load a single event record, represented as a dictionary.

        Will raise a ValueError if the event is invalid, specifically:
          * if it's missing the 'ts' keyword
          * if it's missing the 'event' keyword
        """
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            log.debug("load.start", sz=len(d), dict__id=id(d))
        try:
            timestamp = d['ts']
            event = d['event']
        except KeyError:
            raise ValueError("event is missing 'ts' and/or 'event' keyword")
        # severity: an int, or a level name resolved by nlapi.Level;
        # names it rejects with ValueError map to 100
        severity = d.get('level', 'Info')
        if not isinstance(severity, int):
            try:
                severity = nlapi.Level.getLevel(severity.upper())
            except ValueError:
                severity = 100
        # constrain severity to 0..255
        severity = min(max(severity, 0), 255)
        # event type: '0' = start, '1' = end, '2' = other; the
        # '.start'/'.end' suffix is removed from the stored name
        if event.endswith('.start'):
            se_val, se_event = '0', event[:-6]
        elif event.endswith('.end'):
            se_val, se_event = '1', event[:-4]
        else:
            se_val, se_event = '2', event
        # ~ event
        eid = '%d' % self._eid.incr()
        if self._calc_hash:
            _hash = eventHash(d)
        else:
            _hash = eid
        if debug:
            log.debug("event.insert.start", id=eid)
        self._insert('event', ('id', 'hash', 'time', 'name',
                               'startend', 'severity'),
                     (eid,  _hash, timestamp, self._vtrunc(se_event),
                      se_val, severity))
        if debug:
            log.debug("event.insert.end", id=eid)
        # per-event copy, extended below with keys stored as idents
        exclude = set(self._EXCLUDED_ATTRS)
        # ~ dn
        dnval = d.get('DN', d.get('dn', None))
        if dnval:
            self._insert('dn', ('e_id', 'value'),
                         (eid, self._vtrunc(dnval)))
        # ~ text
        encoded_txt = d.get(B64_TEXT_ATTR, None)
        if encoded_txt is not None:
            self._insert('text', ('e_id', 'value'),
                         (eid, b64decode(encoded_txt)))
        plain_txt = d.get(PLAIN_TEXT_ATTR, None)
        if plain_txt is not None:
            self._insert('text', ('e_id', 'value'),
                         (eid, plain_txt))
        # ~ ident: 'guid', 'id' and '<name>.id' attributes
        for k, v in d.items():
            if k in exclude:
                continue
            if k == 'guid' or k == 'id':
                relationship = k
            elif k.endswith('.id'):
                relationship = k[:-3]
            else:
                continue
            self._insert('ident', ('e_id', 'name', 'value'),
                         (eid, self._ntrunc(relationship), self._vtrunc(v)))
            exclude.add(k)
        # ~attr: all other attributes
        for k, v in d.items():
            if k not in exclude:
                self._insert('attr', ('e_id', 'name', 'value'),
                             (eid, self._ntrunc(k), self._vtrunc(v)))
        self._N += 1
        if self._N % 1000 == 0:
            log.debug("events.processed", num=self._N)
        if debug:
            log.debug("load.end", event__nm=d['event'], dict__id=id(d))

    def checkFlush(self):
        """See BatchedInsertDB.checkFlush()
        """
        self.conn.checkFlush()

    def flush(self):
        log.debug("flush.start")
        self.conn.flush()
        log.debug("flush.end", status=0)

    def close(self):
        log.debug("close.start")
        self.conn.close()
        log.debug("close.end", status=0)

    def setNoHash(self):
        """Turn off calculation of a hash value; the event id is stored
        in the 'hash' column instead.
        """
        self._calc_hash = False

    def _insert(self, *args):
        """Forward (table, fields, values) to conn.insert()."""
        if log.isEnabledFor(DEBUG):
            log.debug("insert", table=args[0], fields=args[1], values=args[2])
        self.conn.insert(*args)

    def _trunc(self, s, maxlen, what):
        """Cut str 's' to 'maxlen' characters, logging event 'what' when
        it had to be shortened; non-str values are returned unchanged."""
        if not isinstance(s, str) or len(s) <= maxlen:
            return s
        log.warn(what, len=len(s), maxlen=maxlen, value=s)
        return s[:maxlen]

    def _ntrunc(self, s):
        """Truncate a name to the connection's name-column limit."""
        return self._trunc(s, self._nmax, "truncate.name")

    def _vtrunc(self, s):
        """Truncate a value to the connection's value-column limit."""
        return self._trunc(s, self._vmax, "truncate.value")
