#!/usr/bin/env python
"""
Unittests for analysis/modules/mongodb.py

If the import of mongodb fails, or the initial connect fails, this 
will just stop with an error.
"""
__author__ = 'Dan Gunter dkgunter@lbl.gov'
__rcsid__ = '$Id: testAnalyzeMongo.py 25126 2010-08-16 20:59:57Z dang $'

# Standard lib
import re
import time
# Third-party
from pymongo.connection import Connection
# NetLogger
from netlogger.tests import shared
import unittest
from netlogger.analysis.modules import mongodb

class TestCase(shared.BaseTestCase):
    """Tests for mongodb.Analyzer, run against a MongoDB server.

    Each test feeds event dictionaries to one or more Analyzer instances
    and then reads the target collection back to verify that exactly the
    expected events are present. The server is contacted at
    MONGO_HOST:MONGO_PORT.
    """

    # Location of the MongoDB server used by the tests.
    MONGO_HOST = 'localhost'
    MONGO_PORT = 27017
    # Scratch database/collection; the database is dropped in setUp().
    DBNAME = 'nl_test_db'
    COLLNAME = 'nl_test_collection'
    # Number of events sent to each analyzer instance.
    NUM_EVENTS = 100

    def setUp(self):
        """Run the base-class setup, then drop the test database."""
        shared.BaseTestCase.setUp(self)
        self._clearDB()

    def tearDown(self):
        """Do nothing; the test database is left in place after a test.

        setUp() drops the database before every test, so no cleanup is
        needed here for test isolation.
        """
        # NOTE(review): the base-class tearDown() is not invoked here even
        # though setUp() calls the base-class setUp() -- confirm that the
        # base class has nothing to release.
        pass

    def _connect(self):
        """Open and return a new connection to the test server."""
        return Connection(host=self.MONGO_HOST, port=self.MONGO_PORT)

    def _clearDB(self):
        """Drop the test database."""
        self._connect().drop_database(self.DBNAME)

    def _create(self, **kw):
        """Create an Analyzer instance pointed at the test collection.

        Any keyword arguments are passed through to mongodb.Analyzer in
        addition to the host, port, database and collection.
        """
        return mongodb.Analyzer(host=self.MONGO_HOST,
                                port=self.MONGO_PORT,
                                database=self.DBNAME,
                                collection=self.COLLNAME,
                                **kw)

    def _connect_to_db(self):
        """Connect to mongodb, return the test database instance."""
        return self._connect()[self.DBNAME]

    def _connect_to_collection(self):
        """Connect to mongodb, return the test collection instance."""
        return self._connect_to_db()[self.COLLNAME]

    def _event(self, count, name="test.event"):
        """Generate an event dictionary.

        The dictionary has the current time under 'ts', `count` under
        'count' and `name` under 'event'.
        """
        return {'ts': time.time(),
                'count': count,
                'event': name}

    @staticmethod
    def _event_key(event):
        """Return the (event name, count) pair used to identify an event."""
        return (event['event'], event['count'])

    def _check_all(self, events):
        """Check that exactly the events in the provided list
        are in the database.
        They do not necessarily have to be in the same order.

        Events are compared by their ('event', 'count') pair only, so
        every event in `events` must have a distinct pair.
        """
        time.sleep(1)  # wait for data to 'hit' db
        coll = self._connect_to_collection()
        # Keys of everything stored; matched keys are removed below so
        # that whatever remains was stored but not expected.
        unmatched = set(self._event_key(doc) for doc in coll.find())
        for e in events:
            key = self._event_key(e)
            self.assertTrue(key in unmatched,
                            "%s not found in db" % (e,))
            unmatched.remove(key)
        if unmatched:
            self.fail("%d extra events in db" % len(unmatched))

    # Tests
    # -----

    def testDefaults(self):
        """Default options.

        With no extra options, every processed event is expected to be
        found in the collection.
        """
        analyzer = self._create()
        # load
        events = []
        for i in range(self.NUM_EVENTS):
            e = self._event(i)
            analyzer.process(e)
            events.append(e)
        analyzer.finish()
        # check
        self._check_all(events)

    def testEventFilter(self):
        """Option 'event_filter'

        Runs several analyzers with different 'event_filter' expressions
        and expects only the events listed as passing to be stored.
        Every analyzer is finished before the final check; previously
        only the last analyzer of each group was finished.
        """
        # list of events that should end up in DB (not filtered)
        events = []
        # half-pass: the 'half-pass-N' events are expected in the DB,
        # the interleaved 'ignore.me' events are not
        half_exprs = (r'half.*', r'.*\-pass', r'.*alf')
        for expr_num, expr in enumerate(half_exprs):
            analyzer = self._create(event_filter=expr)
            for i in range(self.NUM_EVENTS):
                e = self._event(i, "half-pass-%d" % expr_num)
                analyzer.process(e)
                events.append(e)
                # add these, but don't expect them
                analyzer.process(self._event(i, 'ignore.me'))
            analyzer.finish()
        # all-pass: every event is expected in the DB
        for expr_num, expr in enumerate(("", r".*")):
            analyzer = self._create(event_filter=expr)
            for i in range(self.NUM_EVENTS):
                e = self._event(i, "all-pass-%d" % expr_num)
                analyzer.process(e)
                events.append(e)
            analyzer.finish()
        # no-pass: none of these events are expected in the DB
        for expr in ("foobar", r"\s\w*", "^$"):
            analyzer = self._create(event_filter=expr)
            for i in range(self.NUM_EVENTS):
                analyzer.process(self._event(i, "no-pass"))
            analyzer.finish()
        # check
        self._check_all(events)

# Boilerplate to run the tests
def suite():
    """Return the result of shared.suite() for this module's TestCase."""
    test_suite = shared.suite(TestCase)
    return test_suite
# When executed as a script (not imported), hand control to the
# shared.main() helper.
if __name__ == '__main__':
    shared.main()


