from wdmmg import model
from wdmmg.lib import loader

# Fixture data for TestLoader: each key is a (colour, smell) classification
# pair and each value is the amount of the entry loaded for that pair.
# The (u'red', None) pair passes None to ValueCache.get_value_id for the
# smell key. Every entry carries the same amount.
test_data = dict.fromkeys([
    (u'red', u'flowery'),
    (u'red', u'pungent'),
    (u'red', None),
    (u'green', u'flowery'),
    (u'green', u'pungent'),
], 30.0)

class TestValueCache(object):
    """Check that loader.ValueCache hands out stable, distinct value ids."""

    @classmethod
    def setup_class(cls):
        pass

    @classmethod
    def teardown_class(cls):
        # Clean up: wipe everything written to the repository and discard
        # the current session so later test classes start from scratch.
        model.repo.delete_all()
        model.Session.remove()

    def test_value_cache(self):
        key_colour = model.Key(name=u'colour')
        model.Session.add(key_colour)
        model.Session.commit()
        cache_colour = loader.ValueCache(key_colour)
        model.Session.commit()
        # Discard the session before the lookup, so the cache is exercised
        # without the objects loaded by the session that created the key.
        model.Session.remove()
        ev_id = cache_colour.get_value_id(u'red')
        ev = model.Session.query(model.EnumerationValue).get(ev_id)
        # The returned id must refer to a persisted value belonging to the
        # 'colour' key whose code is the requested value. (Previously these
        # fields were only printed, so a wrong value went unnoticed.)
        assert ev is not None, ev_id
        assert ev.key.name == u'colour', ev.key
        assert ev.code == u'red', ev.code
        model.Session.commit()
        model.Session.remove()
        # Asking again, even after a session reset, must give the same id.
        # A different value must get a different id.
        assert ev_id == cache_colour.get_value_id(u'red')
        assert ev_id != cache_colour.get_value_id(u'green')
        model.Session.commit()
        model.Session.remove()

class TestLoader(object):
    """Load test_data through loader.Loader and verify it reads back intact."""

    @classmethod
    def setup_class(cls):
        key_colour = model.Key(name=u'colour')
        key_smell = model.Key(name=u'smell')
        model.Session.add_all([key_colour, key_smell])
        model.Session.commit()
        # Load an example data set.
        cache_colour = loader.ValueCache(key_colour)
        cache_smell = loader.ValueCache(key_smell)
        # commit_every=2 is smaller than len(test_data), presumably so that
        # the loader's intermediate commits are exercised too.
        my_loader = loader.Loader(u'test', [key_colour, key_smell], commit_every=2)
        for (colour, smell), amount in test_data.items():
            my_loader.create_entry(amount, [
                cache_colour.get_value_id(colour),
                cache_smell.get_value_id(smell),
            ])
        model.Session.commit()
        model.Session.remove()

    @classmethod
    def teardown_class(cls):
        # Clean up: wipe everything written to the repository and discard
        # the current session so later test classes start from scratch.
        model.repo.delete_all()
        model.Session.remove()

    def test_loader(self):
        # Read the data back. Entries are not filtered by dataset, so this
        # relies on no other test leaving entries in the repository.
        entries = model.Session.query(model.Entry).all()
        assert len(entries) == len(test_data), len(entries)
        pairs = set()
        for entry in entries:
            cis = (model.Session.query(model.ClassificationItem)
                .filter_by(entry=entry)
                .all())
            cis_as_dict = dict((ci.value.key.name, ci.value.code) for ci in cis)
            # A key with no classification item reads back as None.
            pairs.add((cis_as_dict.get('colour'), cis_as_dict.get('smell')))
        assert pairs == set(test_data.keys())

