from datetime import date
import os, sys, csv
import pkg_resources

import wdmmg.model as model
import wdmmg.lib.aggregator as aggregator
from wdmmg.tests import Fixtures

class TestAggregator(object):
    """Tests for ``wdmmg.lib.aggregator`` against the shared ``Fixtures`` data.

    ``setup_class``/``teardown_class`` call ``Fixtures.setup()`` and
    ``Fixtures.teardown()`` once around the whole class, so every test reads
    the same fixture dataset.
    """

    @classmethod
    def setup_class(cls):
        """Set up the shared fixtures once before the tests in this class."""
        Fixtures.setup()

    @classmethod
    def teardown_class(cls):
        """Tear down the shared fixtures once after the tests in this class."""
        Fixtures.teardown()

    @staticmethod
    def _assert_amounts(index, expected, tolerance):
        """Assert that ``index`` contains each expected amount.

        :param index: dict mapping a coordinate tuple to a number.
        :param expected: iterable of ``(amount, coords)`` pairs.
        :param tolerance: largest absolute difference accepted between the
            expected and the actual value, to tolerate floating-point rounding.
        """
        for amount, coords in expected:
            # ``in`` rather than ``dict.has_key``, which Python 3 removed.
            assert coords in index, coords
            actual = index[coords]
            assert abs(actual - amount) < tolerance, (coords, amount, actual)

    def test_aggregate(self):
        """Aggregate the South West region along cap_or_cur and cofog1."""
        ans = aggregator.aggregate(
            Fixtures.dataset_,
            include=[(Fixtures.region, u'ENGLAND_South West')],
            axes=[
                Fixtures.cap_or_cur, Fixtures.cofog1,
                # Omit Fixtures.pog, Fixtures.region,
            ])
        print(ans)
        assert ans.dates == [u'2003', u'2004', u'2005',
            u'2006', u'2007', u'2008', u'2009',
            u'2010'], ans.dates
        assert ans.axes == [u'cap_or_cur', u'cofog1'], ans.axes
        # Collapse each cell's per-date series into its total over all dates.
        index = dict((coords, sum(amounts))
                     for coords, amounts in ans.matrix.items())
        for coords, total in index.items():
            print('%s %s' % (coords, total))
        assert len(index) == 3, index
        self._assert_amounts(index, [
            (70700000.0, (u'CUR', u'10')),
            (500000.0, (u'CAP', u'03')),
            (-608900000.0, (u'CAP', u'06')),
        ], tolerance=1)

    def test_make_aggregate_query(self):
        """Check the exact SQL text built for a filtered two-axis aggregate."""
        query, params = aggregator._make_aggregate_query(
            Fixtures.dataset_,
            include=[(Fixtures.region, u'ENGLAND_South West')],
            axes=[Fixtures.cap_or_cur, Fixtures.cofog1],
        )
        print(query)
        print(params)
        assert query == '''\
SELECT
    (SELECT ev.code FROM classification_item ci, enumeration_value ev
        WHERE ev.key_id = :ak_0
        AND ci.entry_id = t.id AND ci.value_id = ev.id) AS axis_0,
    (SELECT ev.code FROM classification_item ci, enumeration_value ev
        WHERE ev.key_id = :ak_1
        AND ci.entry_id = t.id AND ci.value_id = ev.id) AS axis_1,
    SUM(t.amount) as amount,
    (SELECT ev.code FROM classification_item ci, enumeration_value ev
        WHERE ev.key_id = :key_time_id
        AND ci.entry_id = t.id AND ci.value_id = ev.id) AS time
FROM "entry" t
WHERE t.dataset_id = :dataset_id
AND t.id IN (SELECT ci.entry_id FROM classification_item ci, enumeration_value ev
    WHERE ev.key_id = :k_0
    AND ev.code = :v_0 AND ci.value_id = ev.id)
GROUP BY time, axis_0, axis_1''', query

    def test_aggregate_per(self):
        """Divide a cofog1 x region aggregate by the ``population2006`` key."""
        ans = aggregator.aggregate(
            Fixtures.dataset_,
            axes=[Fixtures.cofog1, Fixtures.region],
        )
        print(ans)
        key_population = (model.Session.query(model.Key)
            .filter_by(name=u'population2006')
            ).one()
        ans.divide_by_statistic(Fixtures.region, key_population)
        print(ans)
        assert ans.axes == [u'cofog1', u'region'], ans.axes
        # Only the second-to-last date of each series is checked.
        # NOTE(review): confirm which year that column is meant to be.
        index = dict((coords, amounts[-2])
                     for coords, amounts in ans.matrix.items())
        for coords, value in index.items():
            print('%s %s' % (coords, value))
        assert len(index) == 7, index
        self._assert_amounts(index, [
            (2.365, (u'04', u'SCOTLAND')),
            (5.795, (u'10', u'ENGLAND_West Midlands')),
            (0.106, (u'04', u'ENGLAND_London')),
            (3.083, (u'10', u'ENGLAND_South West')),
            (0.020, (u'03', u'ENGLAND_South West')),
            (4.356, (u'04', u'ENGLAND_Yorkshire and The Humber')),
            (-4.879, (u'06', u'ENGLAND_South West')),
        ], tolerance=1e-3)

    def test_aggregate_per_time(self):
        """Divide the no-axes total by the ``gdp_deflator2006`` statistic."""
        ans = aggregator.aggregate(
            Fixtures.dataset_,
            axes=[],
        )
        print(ans)
        ans.divide_by_time_statistic('gdp_deflator2006')
        print(ans)
        # With an empty axis list the cell is keyed by the empty tuple.
        data = ans.matrix[()]
        assert len(data) == 8, data
        # Expected values are written in millions; tolerance is 0.1 million.
        expected = [-180.8, -122.5, -88.0, -82.4, -19.6, 21.9, 52.4, 18.4]
        for i, amount in enumerate(expected):
            assert abs(data[i] - amount * 1e6) < 1e5, (i, amount, data[i])

# TODO: Test filtering on dataset.
# TODO: Test with some breakdown KeyValues missing (i.e. coordinate is NULL).
# TODO: Test per-statistic division (as in test_aggregate_per) without any breakdown axes.

