"""
Unit tests for "new" (normalized schema) Python database
loader prototype.
"""
__author__ = 'Dan Gunter dkgunter@lbl.gov'
__rcsid__ = '$Id: testLoader.py 914 2008-08-12 16:02:09Z dang $'

import logging
import os
import StringIO
import sys
import tempfile
import time
import unittest
import testBase
#
from netlogger.analysis import loader, schema
from netlogger.nllog import TRACE
from netlogger.parsers.base import NLFastParser

# Two fixed GUIDs used as the 'guid' (and 'parent.guid') values in LOG_4.
G1, G2 = ('e363842a-a90f-11dc-a6a0-001b63926e0d',
          'f21c451a-a90f-11dc-b83a-001b63926e0d')
# Four event records loaded by _load4(): A starts, B starts as a child of A
# (via 'parent.guid'), B ends, then A ends. Tests assert on this exact
# order and count.
LOG_4 = [
{ 'ts':1.0, 'event':'A.start', 'guid':G1 },
{ 'ts':2.0, 'event':'B.start', 'guid':G2, 'parent.guid':G1 },
{ 'ts':3.0, 'event':'B.end', 'guid':G2, 'status': 0 },
{ 'ts':4.0, 'event':'A.end', 'guid':G1, 'status':0 },
]

# A long value containing embedded single quotes; testQuotes uses it to
# check that quote escaping in generated INSERT statements is valid SQL.
LONG_STRING = """Description: Error code: 201 Cause: org.globus.exec.generated.FaultType: Error code: 201 caused by [0: org.oasis.wsrf.faults.BaseFaultType: Script stderr: Sorry, user globus is not allowed to execute '/opt/globus-dev/libexec/globus-gridmap-and-execute /opt/globus-dev/libexec/globus-job-manager-script.pl -m fork -f /opt/globus-dev/tmp/gram_job_mgr5874.tmp -c cache_cleanup' as kjackson on strohs.lbl.gov.]"""

# Runs at import time; presumably enables the loader module's logging
# (NOTE(review): confirm what activateLogging configures).
loader.activateLogging()
# For detailed (TRACE-level) debugging output, uncomment these lines:
#h = logging.StreamHandler()
#h.setLevel(TRACE)
#loader.log.addHandler(h)
#loader.log.setLevel(TRACE)
        
class TestCase(testBase.BaseTestCase):
    """Tests for the netlogger.analysis.loader database loader.

    Most tests build a fresh SQLite database in a temporary file with
    _sqliteConnect(), load events through the loader, and then query the
    resulting tables directly.
    """

    def tearDown(self):
        """Close (and thereby delete) the temporary database file, if any.

        The NamedTemporaryFile opened by _sqliteConnect() is otherwise left
        open for the lifetime of the test object.
        """
        tmpfile = getattr(self, 'tmpfile', None)
        if tmpfile is not None:
            tmpfile.close()
            self.tmpfile = None
        testBase.BaseTestCase.tearDown(self)

    def testProcessSQLite(self):
        """Loading LOG_4 gives one event row per record, with sequential
        ids and the event name minus its last dotted component.
        """
        self._sqliteConnect()
        self._load4()
        c = self.conn.cursor()
        # Order explicitly: the id check below depends on row order.
        c.execute("select id, name from event order by id")
        rows = c.fetchall()
        n_expected = len(LOG_4)
        # Check the count first, so a wrong number of rows is reported
        # clearly (too many rows would otherwise raise IndexError on LOG_4).
        self.assertEqual(len(rows), n_expected,
                         "wrong number of events, wanted %d got %d" %
                         (n_expected, len(rows)))
        for i, row in enumerate(rows):
            # check primary key counter: ids start at 1
            self.assertEqual(row[0], i + 1)
            # check event type, e.g. 'A.start' -> 'A'
            event = LOG_4[i]['event']
            event_stripped = '.'.join(event.split('.')[:-1])
            self.assertEqual(row[1], event_stripped)

    def testQuotes(self):
        """A long value with embedded single quotes, written by TestDB with
        quote_escape="'", must yield an INSERT statement SQLite accepts.
        """
        import sqlite3
        mystring = StringIO.StringIO()
        db = loader.TestDB(output=mystring, batch=10, quote_escape="'")
        db.insert("my_table", ("value",), (LONG_STRING,))
        db.flush()
        stmt = mystring.getvalue()
        c = sqlite3.connect(":memory:")
        try:
            c.execute("create table my_table (value varchar(1024))")
            self.debug_("statement: <<%s>>" % stmt)
            try:
                c.execute(stmt)
            except sqlite3.Error, E:
                self.fail("failed to insert long string: %s" % E)
        finally:
            # always release the in-memory database
            c.close()

    def testUnique(self):
        """Test whether duplicate inserts end up adding duplicate values.
        """
        self._sqliteConnect()
        self._load4()
        self._load4()
        c = self.conn.cursor()
        c.execute("select id from event;")
        n = len(c.fetchall())
        n_expected = len(LOG_4)
        self.assert_(n >= n_expected, "only %d events found, %d expected" %
                     (n, n_expected))
        self.assert_(n <= n_expected, "%d events found, %d expected: "
                     "duplicate event should not be allowed" %
                     (n, n_expected))

    def testEventHash(self):
        """Hash of two very similar events is different
        """
        # The two lines differ only in the millisecond digit of 'ts'.
        estr = ["ts=2008-07-08T00:11:48.404-05:00 event=pegasus.invocation level=Info status=0 nsignals=0 workflow.id=CyberShake_USC_4 host=tg-c490.ncsa.teragrid.org user=tera3d duration=122.526000 type=scec::seismogram_synthesis:1.0",
                "ts=2008-07-08T00:11:48.405-05:00 event=pegasus.invocation level=Info status=0 nsignals=0 workflow.id=CyberShake_USC_4 host=tg-c490.ncsa.teragrid.org user=tera3d duration=122.526000 type=scec::seismogram_synthesis:1.0"]
        parser = NLFastParser()
        edicts = [parser.parseLine(line) for line in estr]
        h0, h1 = [loader.eventHash(d) for d in edicts]
        self.failIf(h0 == h1, "hashes of different events are the same")

    def testTruncate(self):
        """Use names/values that are too long, make sure they get truncated
        """
        self._sqliteConnect()
        x300 = 'x' * 300
        e = {'ts': 0, 'event': x300, 'dn': x300,
             'guid': x300, x300 + '.id': x300,
             'foo': x300, x300: x300}
        self.loader.load(e)
        self.loader.flush()
        # Read the length limits from the schema, so the expected truncated
        # strings follow the schema instead of hard-coded sizes.
        schema_file = self.loader.conn._findSchema(None)
        db_stmt = schema.DBStatements(schema_file, type='sqlite')
        max_name_len = db_stmt.getConstant(db_stmt.NAME_MAX)
        max_value_len = db_stmt.getConstant(db_stmt.VALUE_MAX)
        xnm = 'x' * max_name_len
        xval = 'x' * max_value_len
        trunc_nm_err = "badly truncated %s name: '%s'"
        trunc_val_err = "badly truncated value for %s: '%s'"
        c = self.conn.cursor()
        # check event name
        v = self._fetchRow(c, "select name from event", "event")[0]
        self.assert_(v == xnm, trunc_nm_err % ('event', v))
        # check dn
        v = self._fetchRow(c, "select value from dn", "dn")[0]
        self.assert_(v == xval, trunc_val_err % ('dn', v))
        # check guid
        v = self._fetchRow(c, "select value from ident where name = 'guid'",
                           "guid")[0]
        self.assert_(v == xval, trunc_val_err % ('guid', v))
        # check long id: the 'x300.id' key is expected in ident under the
        # truncated name (presumably the loader strips the '.id' suffix)
        v = self._fetchRow(c, "select value from ident where name = '%s'"
                           % xnm, "long id")[0]
        self.assert_(v == xval, trunc_val_err % ('long id', v))
        # check attr val
        v = self._fetchRow(c, "select value from attr where name = 'foo'",
                           "attr foo")[0]
        self.assert_(v == xval, trunc_val_err % ('attr foo', v))
        # check attr name/val: the query only matches if the name was
        # truncated to xnm; then the value must be truncated to xval.
        # (Previously the *name* column was compared against xval.)
        row = self._fetchRow(c, "select name, value from attr "
                             "where name = '%s'" % xnm, "long attr name")
        v = row[1]
        self.assert_(v == xval, trunc_val_err % ('long attr', v))

    def _fetchRow(self, cursor, sql, what):
        """Execute `sql` on `cursor` and return the first result row.

        Fails the test with "cannot find <what>" when the query returns no
        rows, instead of letting a None row raise TypeError on indexing.
        """
        cursor.execute(sql)
        row = cursor.fetchone()
        self.assert_(row is not None, "cannot find %s" % what)
        return row

    def _sqliteConnect(self):
        """Create a new SQLite database in a temporary file and a loader
        for it; sets self.module, self.tmpfile, self.conn and self.loader.
        """
        self.module = loader.sqlite
        self.tmpfile = tempfile.NamedTemporaryFile()
        self.conn = loader.DB(db_module=self.module, dsn=self.tmpfile.name,
                              create=1)
        self.loader = loader.LoaderFactory(self.conn).new()

    def _load4(self):
        """Load every record of LOG_4 through self.loader, then flush."""
        self.debug_("loading: %s" % LOG_4)
        # explicit loop: map() for side effects is a no-op on Python 3
        for event in LOG_4:
            self.loader.load(event)
        self.loader.flush()

# Boilerplate to run the tests
def suite():
    """Return the test suite for this module's TestCase class.

    Building the suite is delegated to ``testBase.suite``.
    """
    case_class = TestCase
    return testBase.suite(case_class)


# When executed as a script, hand control to testBase's test runner.
if __name__ == '__main__':
    testBase.main()
