from sqlalchemy.types import Text, Float, Integer
from sqlalchemy.schema import Column, ForeignKeyConstraint, UniqueConstraint
from tropofy.database.tropofy_orm import DataSetMixin
from tropofy.app import AppWithDataSets, Step, StepGroup
from tropofy.widgets import SimpleGrid, KMLMap, Chart
from simplekml import Kml


class Store(DataSetMixin):
    """A retail store with a name and a geographic position.

    Rows are scoped to a data set. The ``data_set_id`` column referenced in
    the constraint below is presumably supplied by ``DataSetMixin`` (confirm
    against tropofy). ``Performance`` rows reference a store by
    ``(name, data_set_id)``.
    """
    # Store name; also the key that Performance.store_name points at.
    name = Column(Text, nullable=False)
    # Position in decimal degrees; the map widget plots (longitude, latitude).
    latitude = Column(Float, nullable=False)
    longitude = Column(Float, nullable=False)

    # Store names must be unique within a single data set, which makes
    # (name, data_set_id) usable as the target of Performance's foreign key.
    __table_args__ = (UniqueConstraint('name', 'data_set_id'),)


class Performance(DataSetMixin):
    """Yearly sales and expenses figures for one store within a data set."""
    # Name of the owning Store (see the composite foreign key below).
    store_name = Column(Text, nullable=False)
    year = Column(Integer, nullable=False)
    # Both figures default to 0 when not supplied.
    sales = Column(Float, nullable=False, default=0)
    expenses = Column(Float, nullable=False, default=0)

    __table_args__ = (
        # At most one row per store per year within a data set.
        UniqueConstraint('store_name', 'year', 'data_set_id'),
        # Composite FK to the store table (name assumed to be derived from the
        # Store class by DataSetMixin — confirm). Deleting a store deletes its
        # performances; renaming a store propagates to store_name.
        ForeignKeyConstraint(['store_name', 'data_set_id'], ['store.name', 'store.data_set_id'], ondelete='CASCADE', onupdate='CASCADE'),
    )


class MyKMLMap(KMLMap):
    """Map widget showing one named placemark per store in the data set."""

    def get_kml(self, data_set):
        """Build and return the KML document (as a string) for every Store."""
        document = Kml()
        all_stores = data_set.query(Store).all()
        for shop in all_stores:
            # Coordinates are passed as (longitude, latitude) pairs.
            point = (shop.longitude, shop.latitude)
            document.newpoint(name=shop.name, coords=[point])
        return document.kml()


class PerformanceBarChart(Chart):
    """Bar chart of total sales and total expenses per year, summed over all stores."""

    def get_chart_type(self, data_set):
        return Chart.BARCHART

    def get_table_schema(self, data_set):
        # "year" is declared as a string column although Performance.year is an
        # Integer, presumably so years render as discrete category labels
        # (assumes the chart layer coerces the int value — confirm).
        return {
            "year": ("string", "Year"),
            "sales": ("number", "Sales"),
            "expenses": ("number", "Expenses")
        }

    def get_table_data(self, data_set):
        """Return one row per year with summed sales and expenses.

        All performances are loaded with a single query and aggregated in
        Python. This replaces the previous distinct-years query plus one extra
        query per year (N+1). Rows are returned in ascending year order.
        """
        totals = {}  # year -> (sales total, expenses total)
        for performance in data_set.query(Performance).all():
            sales, expenses = totals.get(performance.year, (0, 0))
            totals[performance.year] = (sales + performance.sales, expenses + performance.expenses)
        return [
            {"year": year, "sales": sales, "expenses": expenses}
            for year, (sales, expenses) in sorted(totals.items())
        ]

    def get_column_ordering(self, data_set):
        return ["year", "sales", "expenses"]

    def get_order_by_column(self, data_set):
        return "year"

    def get_chart_options(self, data_set):
        return {
            'title': 'Company Performance',
            'vAxis': {
                'title': 'Year',
                'titleTextStyle': {'color': 'red'}
            }
        }


class StoreExpensesPieChart(Chart):
    """Pie chart of each store's total expenses across all years."""

    def get_chart_type(self, data_set):
        return Chart.PIECHART

    def get_table_schema(self, data_set):
        return {
            "store": ("string", "Store"),
            "expenses": ("number", "Expenses")
        }

    def get_table_data(self, data_set):
        """Return one row per Store with the sum of its performances' expenses.

        Performances are loaded once and grouped by store name. This replaces
        the previous per-store query (N+1). Stores without any performance rows
        are still listed with expenses 0.
        """
        expenses_by_store = {}
        for performance in data_set.query(Performance).all():
            expenses_by_store[performance.store_name] = (
                expenses_by_store.get(performance.store_name, 0) + performance.expenses
            )
        return [
            {"store": store.name, "expenses": expenses_by_store.get(store.name, 0)}
            for store in data_set.query(Store).all()
        ]

    def get_column_ordering(self, data_set):
        return ["store", "expenses"]

    def get_chart_options(self, data_set):
        # The title shows the grand total of expenses over every performance row.
        total_expense = sum(p.expenses for p in data_set.query(Performance).all())
        return {
            'title': 'Company Expenses: Total = ${expense}'.format(expense=str(total_expense)),
        }


class MyFirstApp(AppWithDataSets):
    """Example Tropofy app: edit stores and their yearly figures, then visualise them."""

    def get_name(self):
        return "My First App"

    def get_gui(self):
        """Return the app's two step groups: data input grids, then output visualisations."""
        input_group = StepGroup(name='Input')
        # One editable grid per model, in the same order as before.
        for step_title, model in (('Stores', Store), ('Performances', Performance)):
            input_group.add_step(Step(name=step_title, widgets=[SimpleGrid(model)]))

        output_group = StepGroup(name='Output')
        # Two half-width charts side by side, with a full-width map beneath.
        layout = [
            {"widget": PerformanceBarChart(), "cols": 6},
            {"widget": StoreExpensesPieChart(), "cols": 6},
            {"widget": MyKMLMap(), "cols": 12},
        ]
        output_group.add_step(Step(name='Visualisations', widgets=layout))

        return [input_group, output_group]

    def get_examples(self):
        """Map each example's display label to the function that loads it."""
        return {"Demo data for Brisbane North": load_example_data}


def load_example_data(data_set):
    """Populate *data_set* with two demo stores and their 2011-2012 figures.

    Stores are added before their performances so the rows referenced by
    Performance's foreign key are added first.
    """
    store_rows = (
        ("CLAYFIELD", -27.417536, 153.056677),
        ("SANDGATE", -27.321538, 153.069267),
    )
    data_set.add_all([
        Store(name=name, latitude=lat, longitude=lon)
        for name, lat, lon in store_rows
    ])

    performance_rows = (
        ("CLAYFIELD", 2011, 1000, 400),
        ("CLAYFIELD", 2012, 1170, 460),
        ("SANDGATE", 2011, 660, 1120),
        ("SANDGATE", 2012, 1030, 540),
    )
    data_set.add_all([
        Performance(store_name=store, year=yr, sales=income, expenses=cost)
        for store, yr, income, cost in performance_rows
    ])
