from datetime import datetime

import wdmmg.model as model

class TestORM(object):
    """Round-trip tests for the wdmmg ORM model.

    ``setup_class`` persists a Dataset, two EnumerationValues under an
    ``account`` Key and a single Entry. The tests then read these rows back
    and exercise KeyValue handling and the cascading deletes of
    EnumerationValues and KeyValues when their Key is deleted.
    """

    @classmethod
    def setup_class(cls):
        """Create the fixture rows shared by every test in this class.

        The identifying values are stored as class attributes so the tests
        can query for the same rows after the session has been removed.
        """
        cls.dataset_ = u'test'
        cls.accsrc = u'acc1'
        cls.accdst = u'acc2'
        cls.amount = 47865
        dataset_ = model.Dataset(name=cls.dataset_)
        key_account = model.Key(name=u'account')
        acc_src = model.EnumerationValue(key=key_account, code=cls.accsrc)
        acc_dst = model.EnumerationValue(key=key_account, code=cls.accdst)
        txn = model.Entry(
            dataset_=dataset_,
            amount=cls.amount)
        # TODO: Create some ClassificationItems.
        # key_account is not added explicitly; presumably it is persisted
        # through the EnumerationValue.key relationship cascade -- confirm.
        model.Session.add_all([dataset_, acc_src, acc_dst, txn])
        model.Session.commit()
        # Discard the session so that later queries do not reuse these
        # in-memory objects (assumes model.Session is a scoped_session).
        model.Session.remove()

    @classmethod
    def teardown_class(cls):
        """Delete all data via ``model.repo.delete_all()``."""
        model.repo.delete_all()

    def test_01(self):
        """The Entry created in setup_class is found by dataset and amount."""
        dataset_ = model.Session.query(model.Dataset).filter_by(name=self.dataset_).one()
        txn = (model.Session.query(model.Entry)
            .filter_by(dataset_=dataset_)
            .filter_by(amount=self.amount)).one()
        assert txn
        # TODO: Test ClassificationItems can be retrieved.

    # TODO: Factor this into multiple tests?
    def test_02(self):
        """Exercise Keys, EnumerationValues and KeyValues end to end.

        Phases:
          1. Create Keys, EnumerationValues and KeyValues (both explicitly
             and through the ``keyvalues`` mapping), then commit.
          2. Read everything back in a fresh session and check the values.
          3. Clearing ``keyvalues`` removes KeyValue rows.
          4. Deleting a Key removes its EnumerationValues.
          5. Deleting a Key removes KeyValues that use it.
        """
        # --- Phase 1: build the data. ---
        dataset_ = model.Session.query(model.Dataset).filter_by(name=self.dataset_).one()
        acc_src = (model.Session.query(model.EnumerationValue)
            .filter_by(code=self.accsrc)).one()
        acc_dst = (model.Session.query(model.EnumerationValue)
            .filter_by(code=self.accdst)).one()
        region = model.Key(name=u'region', notes=u'Area for which money was spent')
        pog = model.Key(name=u'pog', notes=u'Programme Object Group')
        randomkey = model.Key(name=u'randomkey')

        northwest = model.EnumerationValue(code=u'Northwest', name=u'North west', key=region)
        northeast = model.EnumerationValue(code=u'Northeast', name=u'North east', key=region)
        pog1 = model.EnumerationValue(code=u'surestart', name=u'Sure start', key=pog)
        pog2 = model.EnumerationValue(code=u'surestart2', name=u'Another start', key=pog)

        kv1 = model.KeyValue(ns=u'enumeration_value', ns_enumerationvalue=acc_src, key=region,
                value=northwest.code)
        acc_src.keyvalues[region] = northeast.code # This should overwrite the KeyValue explicitly constructed above.
        acc_src.keyvalues[randomkey]= u'annakarenina' # This should create a new KeyValue.

        kv2 = model.KeyValue(ns=u'enumeration_value', ns_enumerationvalue=acc_dst, key=randomkey,
                value=u'orangesarenottheonlyfruit') # This one should not get overwritten.
        acc_dst.keyvalues[region] = northwest.code # This should create a new KeyValue.

        model.Session.add_all([region, pog, randomkey, kv1, kv2])
        model.Session.commit()
        model.Session.remove()
        # Drop the local references so that everything below has to be
        # re-queried rather than reusing these objects by accident.
        del acc_src, acc_dst, region, pog, randomkey, kv1, kv2

        # --- Phase 2: read it all back again. ---

        pog = model.Session.query(model.Key).filter_by(name=u'pog').one()
        assert pog.notes.startswith(u'Programme')
        assert len(pog.enumeration_values) == 2, pog

        region = model.Session.query(model.Key).filter_by(name=u'region').one()
        assert region
        # Expect one region KeyValue each for acc_src and acc_dst: kv1 was
        # overwritten rather than duplicated.
        region_kvs = model.Session.query(model.KeyValue).filter_by(key=region).all()
        assert len(region_kvs) == 2, region_kvs

        randomkey = model.Session.query(model.Key).filter_by(name=u'randomkey').one()
        assert randomkey

        acc_src = model.Session.query(model.EnumerationValue).filter_by(code=self.accsrc).one()
        assert acc_src
        acc_dst = model.Session.query(model.EnumerationValue).filter_by(code=self.accdst).one()
        assert acc_dst

        northeast = model.Session.query(model.EnumerationValue).filter_by(key=region).filter_by(code=u'Northeast').one()
        assert northeast.key == region, northeast.key
        assert northeast.code == u'Northeast', northeast.code

        acc_src_region_kv = (model.Session.query(model.KeyValue)
            .filter_by(ns_enumerationvalue=acc_src)
            .filter_by(key=region)
            ).one()
        assert acc_src_region_kv.value == u'Northeast', acc_src_region_kv.value
        assert acc_src_region_kv.enumeration_value == northeast, acc_src_region_kv.enumeration_value
        # keyvalues maps Key -> value string; _keyvalues maps Key -> the
        # underlying KeyValue row.
        assert acc_src._keyvalues[region] == acc_src_region_kv
        assert acc_src.keyvalues[region] == u'Northeast'
        assert acc_src.keyvalues[randomkey] == u'annakarenina'

        assert acc_dst.keyvalues[region] == u'Northwest'
        assert acc_dst.keyvalues[randomkey] == u'orangesarenottheonlyfruit'

        # --- Phase 3: cascading deletion of KeyValues when removed from objects. ---
        print([str(x) for x in model.Session.query(model.KeyValue).all()])
        before_count = model.Session.query(model.KeyValue).count()
        acc_src.keyvalues.clear()
        model.Session.commit()
        model.Session.remove()

        acc_src = (model.Session.query(model.EnumerationValue)
            .filter_by(code=self.accsrc)
            ).one()
        assert not acc_src.keyvalues
        print([str(x) for x in model.Session.query(model.KeyValue).all()])
        after_count = model.Session.query(model.KeyValue).count()
        assert before_count > after_count, (before_count, after_count)

        # --- Phase 4: cascading deletion of EnumerationValues when Keys are deleted. ---
        print([str(x) for x in model.Session.query(model.EnumerationValue).all()])
        before_count = model.Session.query(model.EnumerationValue).count()
        region = model.Session.query(model.Key).filter_by(name=u'region').one()
        print('region = %s' % (region,))
        model.Session.delete(region)
        model.Session.commit()
        model.Session.remove()

        print([str(x) for x in model.Session.query(model.EnumerationValue).all()])
        after_count = model.Session.query(model.EnumerationValue).count()
        assert before_count > after_count, (before_count, after_count)

        # --- Phase 5: cascading deletion of KeyValues when Keys are deleted. ---
        print([str(x) for x in model.Session.query(model.KeyValue).all()])
        before_count = model.Session.query(model.KeyValue).count()
        randomkey = model.Session.query(model.Key).filter_by(name=u'randomkey').one()
        print('randomkey = %s' % (randomkey,))
        model.Session.delete(randomkey)
        model.Session.commit()
        model.Session.remove()

        print([str(x) for x in model.Session.query(model.KeyValue).all()])
        after_count = model.Session.query(model.KeyValue).count()
        assert before_count > after_count, (before_count, after_count)

