#!/usr/bin/env python
"""
Unit tests for the database loader (netlogger.analysis.loader).
"""
__author__ = 'Dan Gunter dkgunter@lbl.gov'
__rcsid__ = '$Id: testDatabase.py 1039 2008-09-15 22:44:44Z dang $'

import logging
import math
import random
import threading
import time
import unittest
#
import testBase
from netlogger import nllog
from netlogger.analysis import loader

def getEvents(i):
    """Build the canned batch of five events used for load iteration *i*.

    The batch is: outer.start, inner.start, sample1, inner.end, outer.end.
    Timestamps start at a fixed base offset by 2.123 s per iteration; the
    guid and local.id values are derived from *i* so every batch is unique.
    Note that 'sample1' carries a local.id but no guid.
    """
    base_ts = 1209755749.0 + 2.123 * i
    guid = '%04de138-187c-11dd-9a7b-001b63926e0d' % i
    local_id = '%d' % (100 + i, )
    subject_dn = '/DC=org/DC=doegrids/OU=People/CN=Daniel Gunter 757118'
    outer_start = {'ts': base_ts, 'event': 'outer.start', 'guid': guid}
    inner_start = {'ts': base_ts + 0.11, 'event': 'inner.start',
                   'guid': guid, 'local.id': local_id, 'DN': subject_dn}
    sample = {'ts': base_ts + 0.22, 'event': 'sample1',
              'local.id': local_id,
              'text': "This is a string of text. " * 10}
    inner_end = {'ts': base_ts + 0.34, 'event': 'inner.end', 'guid': guid,
                 'local.id': local_id, 'status': '0'}
    outer_end = {'ts': base_ts + 1.2, 'event': 'outer.end', 'guid': guid,
                 'status': '0'}
    return [outer_start, inner_start, sample, inner_end, outer_end]

def roundToUsec(v):
    """Round the seconds value *v* to the nearest microsecond.

    Exact halves round upward (floor of value + 0.5 after scaling).
    """
    scaled = v * 1e6 + 0.5
    return math.floor(scaled) / 1e6

def pickRandomDB():
    """Return a random 20-character lowercase name for a scratch database."""
    first, last = ord('a'), ord('z')
    return ''.join([chr(random.randint(first, last)) for _ in range(20)])

class TestCase(testBase.BaseTestCase):
    """Unit test cases for the database loader.

    setUp() tries to create a scratch database with a random name in each
    available backend (MySQL and/or PostgreSQL). A backend whose module is
    missing, or whose server cannot be reached, is skipped. tearDown() drops
    whatever scratch databases were created.
    """
    # Number of getEvents() batches loaded into (and checked in) each database.
    LOAD_NUM = 200

    def setUp(self):
        """Seed the RNG, optionally enable loader tracing, and connect to
        scratch MySQL / PostgreSQL databases where possible.
        """
        random.seed(time.time())
        # at TRACE level, crank up program debugging
        if self.TRACE and isinstance(loader.log, nllog.NullLogger):
            loader.activateLogging("netlogger.loader")
            log = logging.getLogger("netlogger")
            h = logging.StreamHandler()
            h.setLevel(logging.DEBUG)
            log.addHandler(h)
            log.setLevel(logging.DEBUG)
        self.mysql_db = None
        # try to connect to a local MySQL database
        if loader.DB_MODULES['mysql']:
            mysql = loader.DB_MODULES['mysql']
            kw = dict(read_default_file='~/.my.cnf', db=pickRandomDB())
            # login
            try:
                self.mysql_db = loader.DB(db_module=mysql, conn_kw=kw,
                                          create=1)
            except (mysql.OperationalError, RuntimeError):
                self.debug_("cannot connect to MySQL, skipping those tests")
        self.pgsql_db = None
        # try to connect to a local PostgreSQL database
        # which uses 'ident' authentication (i.e. by user)
        self.debug_("connecting to postgres")
        if loader.DB_MODULES['postgres']:
            pgsql = loader.DB_MODULES['postgres']
            self._pgdb = pickRandomDB()
            kw = dict(db=self._pgdb)
            try:
                self.pgsql_db = loader.DB(db_module=pgsql, conn_kw=kw,
                                          create=1)
            except (pgsql.OperationalError, RuntimeError):
                # Same policy as the MySQL branch: an unreachable server
                # skips the PostgreSQL tests rather than erroring them all.
                # (OperationalError is the DB-API 2.0 module-level class.)
                self.debug_("cannot connect to PostgreSQL, "
                            "skipping those tests")
            else:
                self.debug_("connected to postgres db = %s" % self._pgdb)
        else:
            self.debug_("no postgres module found, skipping tests")

    def tearDown(self):
        """Drop the scratch databases created by setUp() and close the
        connections used to do so, even if the DROP fails.
        """
        if self.mysql_db:
            self.debug_("dropping MySQL database %s" % self.mysql_db.database)
            try:
                c = self.mysql_db.cursor()
                c.execute("drop database %s" % self.mysql_db.database)
            finally:
                self.mysql_db.close()
        if self.pgsql_db:
            self.debug_("dropping PostgreSQL database %s" % self._pgdb)
            # Close the connection to the scratch db and issue the DROP from
            # a fresh connection to the 'postgres' db (presumably because
            # PostgreSQL refuses to drop a database that is in use).
            self.pgsql_db.close()
            pgsql = loader.DB_MODULES['postgres']
            self.pgsql_db = loader.DB(db_module=pgsql,
                                      conn_kw={'db': 'postgres'})
            conn = self.pgsql_db.conn
            try:
                c = self.pgsql_db.cursor()
                lvl = conn.isolation_level
                # Level 0 = autocommit for psycopg-style modules; DROP
                # DATABASE presumably must run outside a transaction block.
                conn.set_isolation_level(0)
                try:
                    c.execute("drop database %s" % self._pgdb)
                finally:
                    conn.set_isolation_level(lvl)
            finally:
                conn.close()

    def testLoad(self):
        """Load some data into all found databases, then verify it.
        """
        if self.mysql_db:
            self._testLoadMysql()
            self._dbCheck(self.mysql_db.cursor())
        if self.pgsql_db:
            self._testLoadPostgres()
            self._dbCheck(self.pgsql_db.cursor())

    def _testLoadMysql(self):
        """Load some data into the (MySQL) database.
        """
        self._loadEvents(self.mysql_db)

    def _testLoadPostgres(self):
        """Load some data into the (PostgreSQL) database.
        """
        self._loadEvents(self.pgsql_db)

    def _loadEvents(self, db):
        """Load LOAD_NUM batches of getEvents() into *db*, using a fresh
        loader from one LoaderFactory for each batch.
        """
        factory = loader.LoaderFactory(db)
        for i in range(self.LOAD_NUM):
            events = getEvents(i)
            self.trace_("factory for events #%d: %s" % (i, events))
            ldr = factory.new()
            self.trace_("loading events #%d: %s" % (i, events))
            # Explicit loop, not map(): map() is lazy on Python 3 and
            # would silently load nothing.
            for event in events:
                ldr.load(event)

    def _dbCheck(self, c):
        """Verify the loaded data through the DB cursor *c*.

        For each batch, select the events tagged with that batch's guid,
        ordered by time, and compare id/time/name/startend against
        getEvents(). Event #2 ('sample1') has no guid, so it is excluded.
        """
        for i in range(self.LOAD_NUM):
            events = getEvents(i)
            events = events[:2] + events[3:] # remove event#2
            guid = events[0]['guid']
            # guid comes from getEvents(), not external input, so plain
            # string interpolation is acceptable here. ORDER BY makes the
            # row order match the (time-ordered) expected events list.
            c.execute("select e.id, e.time, e.name,e.startend from event as e"
                      " left join ident on e.id = ident.e_id"
                      " where ident.name = 'guid' and ident.value = '%s'"
                      " order by e.time" % guid)
            for e in events:
                ename = e['event']
                # expected stored name drops the last dotted component,
                # e.g. 'outer.start' -> 'outer'
                ename_stripped = '.'.join(ename.split('.')[:-1])
                row = c.fetchone()
                # fetchone() returns None when rows run out; check that
                # before indexing so a short result fails the assertion
                # instead of raising TypeError.
                self.assertTrue(row is not None,
                                "too few rows at %s" % ename)
                self.assertTrue(row[0] is not None,
                                "too few rows at %s" % ename)
                self.assertEqual(row[1], roundToUsec(e['ts']))
                self.assertEqual(row[2], ename_stripped)
                # startend code: 0 = '.start', 1 = '.end', 2 = anything else
                if ename.endswith('.start'):
                    self.assertEqual(row[3], 0)
                elif ename.endswith('.end'):
                    self.assertEqual(row[3], 1)
                else:
                    self.assertEqual(row[3], 2)
 
# Boilerplate to run the tests
def suite():
    """Return the test suite for this module's TestCase."""
    return testBase.suite(TestCase)


if __name__ == "__main__":
    testBase.main()
