"""
Author:      www.tropofy.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.

This source file is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.
"""

import math
from datetime import datetime
from collections import namedtuple

from sqlalchemy.types import Float, DateTime
from sqlalchemy.schema import Column
from sqlalchemy.sql.expression import desc
from tropofy.database.tropofy_orm import DataSetMixin
from tropofy.app import AppWithDataSets, Step, StepGroup
from tropofy.widgets import SimpleGrid, ExecuteFunction, Chart
import rpy2.robjects as robjects


class Data(DataSetMixin):
    """One (x, y) observation that is fed to the regression."""
    x = Column(Float, nullable=False)
    y = Column(Float, nullable=False)

    @classmethod
    def get_all(cls, data_set):
        """Return a list of every row of this model queried through *data_set*."""
        query = data_set.query(cls)
        return query.all()


class RegressionResult(DataSetMixin):
    """Coefficients of a fitted line ``y = gradiant * x + yintercept``.

    A new row is stored each time the regression is run; ``get_latest_result``
    returns the most recent one by ``solved_datetime``.
    """
    # NOTE: 'gradiant' (sic) is a database column name; it is kept as-is so the
    # schema does not change.
    solved_datetime = Column(DateTime)
    gradiant = Column(Float, nullable=False)
    yintercept = Column(Float, nullable=False)

    def __init__(self, r_lm_result):
        """Build a result from an R ``lm`` fit of the formula ``y ~ x``.

        :param r_lm_result: rpy2 object returned by R's ``lm``. Its
            ``coefficients`` element is read as ``[intercept, slope]``.
        """
        self.solved_datetime = datetime.now()
        coefficients = list(r_lm_result.rx('coefficients')[0])
        self.gradiant = coefficients[1]
        self.yintercept = coefficients[0]

    @classmethod
    def get_latest_result(cls, data_set):
        """Return the most recently solved result in *data_set*, or None if there is none."""
        return data_set.query(cls).order_by(desc(cls.solved_datetime)).first()

    def get_x_from_y(self, y):
        """Return the x at which the line reaches *y*.

        Raises ZeroDivisionError for a horizontal line (zero gradient).
        """
        return (y - self.yintercept)/self.gradiant

    def get_y_from_x(self, x):
        """Return the y value of the line at *x*."""
        return self.gradiant*x + self.yintercept

    def get_intercept_points_with_view_window(self, view_window):
        """Return the points where the trend line meets the edges of the view window.

        :param view_window: object with ``xmin``, ``xmax``, ``ymin`` and ``ymax``.
        :returns: a set of ``Point(x, y)`` namedtuples, normally two. A set is
            used so that a line passing exactly through a corner, which meets
            two edges there, does not produce the same point twice. The set is
            empty when the line misses the window.
        """
        vw = view_window  # shorthand
        Point = namedtuple('Point', 'x, y')

        candidates = [
            Point(vw.xmin, self.get_y_from_x(vw.xmin)),  # left edge
            Point(vw.xmax, self.get_y_from_x(vw.xmax)),  # right edge
        ]
        # A horizontal line never crosses the top/bottom edges at a single x,
        # and solving for x there would divide by zero, so only check them
        # when the gradient is non-zero.
        if self.gradiant != 0:
            candidates.append(Point(self.get_x_from_y(vw.ymin), vw.ymin))  # bottom edge
            candidates.append(Point(self.get_x_from_y(vw.ymax), vw.ymax))  # top edge

        return set(
            point for point in candidates
            if vw.xmin <= point.x <= vw.xmax and vw.ymin <= point.y <= vw.ymax
        )

    def get_formatted_equation(self):
        """Return the line as text with two decimals, e.g. ``'1.50x + 2.00'``.

        A negative intercept is written with a minus sign (``'1.50x - 2.00'``)
        instead of ``'1.50x + -2.00'``.
        """
        sign = '-' if self.yintercept < 0 else '+'
        return '%0.2fx %s %0.2f' % (self.gradiant, sign, abs(self.yintercept))


class SimpleLinearRegressionWithR(AppWithDataSets):
    """Tropofy app: view (x, y) data, fit a straight line to it in R, and chart the fit."""

    def get_name(self):
        app_name = "Simple Linear Regression With R"
        return app_name

    def get_gui(self):
        """Return the Input, Calculate and Output step groups, in that order."""
        input_group = StepGroup(name='Input', steps=[
            Step(name='Data', widgets=[NonEditableSimpleGrid(Data)]),
        ])
        calculate_group = StepGroup(name='Calculate', steps=[
            Step(name='Calculate', widgets=[ExecuteSolverFunction()]),
        ])
        regression_step = Step(
            name='Regression',
            widgets=[RegressionResultChart()],
            help_text='Output of simple linear regression. Coefficients calculated with R through rpy2 interface.'
        )
        output_group = StepGroup(name='Output', steps=[regression_step])
        return [input_group, calculate_group, output_group]

    def get_examples(self):
        """Map each example's display name to the function that loads it."""
        examples = {"Demo Data set": load_example_1}
        return examples

    def get_icon_url(self):
        icon_url = 'http://www.tropofy.com/static/css/img/tropofy_example_app_icons/linear_regression.png'
        return icon_url

    def get_home_page_content(self):
        """Return the HTML fragments shown on the app's home page, keyed by slot name."""
        app_name_header = '''
            <div>
            <span style="vertical-align: middle;">Simple Linear Regression With R</span>
            <img src="http://www.tropofy.com/static/css/img/tropofy_example_app_icons/linear_regression.png" alt="main logo" style="width:15%">
            </div>'''

        app_description = '''
            <p><i>This app was created using the Tropofy platform and serves as a worked example for Tropofy problem solvers.
            The code that runs this app with the Tropofy Platform can be viewed at <a href="http://www.tropofy.com/docs/examples/simple_linear_regression_with_r.html" target="_blank">
            http://www.tropofy.com/docs/examples/simple_linear_regression_with_r.html</a>.
            This app is an example app and as such is under heavy usage. Your data may not persist between logins if we decide to clean our example apps database.</i></p><br>

            <p>This app demonstrates integrating with <a href="http://www.r-project.org/" target="_blank">R</a> to perform a simple linear regression</p>
            '''

        row_4_col_1 = '''
            This app was created using the <a href="http://www.tropofy.com" target="_blank">Tropofy platform</a> and serves as a worked example for Tropofy problem solvers.
            '''

        return {
            'content_app_name_header': app_name_header,
            'content_single_column_app_description': app_description,
            'content_row_4_col_1_content': row_4_col_1,
        }


class NonEditableSimpleGrid(SimpleGrid):
    """A SimpleGrid that reports itself as not editable."""

    def grid_is_editable(self):
        editable = False
        return editable


class RegressionResultChart(Chart):
    """Scatter chart of the input data with the latest regression trend line overlaid.

    Column ``y1`` (series 0) holds the data points; column ``y2`` (series 1)
    holds the trend line's end points, styled in ``get_chart_options`` as a
    red line with no point markers.
    """

    def get_chart_type(self, data_set):
        return Chart.SCATTERCHART

    def get_table_schema(self, data_set):
        return {
            "x": ("number", "x"),
            "y1": ("number", "y1"),  # series 1: Scatter of points
            "y2": ("number", "y2"),  # series 2: Trend line
        }

    def get_table_data(self, data_set):
        """Return one row per data point, then one row per trend-line end point.

        The trend-line rows are omitted when no regression has been calculated yet.
        """
        data = Data.get_all(data_set)
        # Data points on the scatter plot.
        table_data = [{'x': row.x, 'y1': row.y, 'y2': None} for row in data]

        result = RegressionResult.get_latest_result(data_set)
        if result is not None:
            view_window = self._get_view_window(data_set, data)
            # Sorted so the output order is deterministic (the helper returns a set).
            trend_points = sorted(result.get_intercept_points_with_view_window(view_window))
            table_data.extend({'x': point.x, 'y1': None, 'y2': point.y} for point in trend_points)
        return table_data

    def get_column_ordering(self, data_set):
        return ["x", "y1", "y2"]

    def get_chart_options(self, data_set):
        """Return the chart options: title, fixed axis windows and trend-line styling.

        The title is the fitted equation, or a placeholder when no regression
        has been calculated yet.
        """
        result = RegressionResult.get_latest_result(data_set)
        if result is not None:
            title = result.get_formatted_equation()
        else:
            title = 'No regression calculated yet'
        view_window = self._get_view_window(data_set)

        return {
            'title': title,
            'hAxis': {
                'title': 'x',
                'viewWindow': {
                    'min': view_window.xmin,
                    'max': view_window.xmax,
                }
            },
            'vAxis': {
                'title': 'y',
                'viewWindow': {
                    'min': view_window.ymin,
                    'max': view_window.ymax,
                }
            },
            'legend': 'none',
            'lineWidth': 0,  # no lines between the scatter points...
            'series': {1: {  # ...except for the trend-line series
                'color': 'red',
                'lineWidth': 1,
                'pointSize': 0,
            }}
        }

    def _get_view_window(self, data_set, data=None):
        """Return a ``ViewWindow(xmin, xmax, ymin, ymax)`` namedtuple framing the data.

        Each axis is padded by 25% of the data range on both sides. The bounds
        are then widened to whole numbers (min rounded down, max rounded up),
        so no data point falls outside the window. An axis with zero range (all
        values equal) gets a margin of 1 instead. With no data at all, the
        window defaults to 0..1 on both axes.

        :param data: optional pre-fetched ``Data`` rows, to avoid querying again.
        """
        ViewWindow = namedtuple('ViewWindow', 'xmin, xmax, ymin, ymax')
        if data is None:
            data = Data.get_all(data_set)
        if not data:
            return ViewWindow(0, 1, 0, 1)

        def padded_bounds(values):
            """Return (lower, upper) integer bounds enclosing *values* plus a margin."""
            low, high = min(values), max(values)
            margin = 0.25 * (high - low) or 1  # fall back to 1 for a zero range
            # floor/ceil rather than int(): int() truncates toward zero and can
            # pull a bound inside the data (e.g. max 1.1 + margin 0.1 -> 1).
            return int(math.floor(low - margin)), int(math.ceil(high + margin))

        xmin, xmax = padded_bounds([row.x for row in data])
        ymin, ymax = padded_bounds([row.y for row in data])
        return ViewWindow(xmin, xmax, ymin, ymax)


class ExecuteSolverFunction(ExecuteFunction):
    """Button widget that fits ``y ~ x`` in R and stores the resulting coefficients."""

    def get_button_text(self):
        return "Calculate Linear Regression"

    def execute_function(self, data_set):
        """Fit a simple linear regression of y on x in R (via rpy2) and save the result.

        Each successful run adds a new ``RegressionResult`` to *data_set* and
        reports the fitted equation through ``send_progress_message``.

        A slope needs at least two distinct x values. With no data, R's ``lm``
        raises an error; with a single distinct x value, the slope comes back
        as NA. In either case no result is stored and the user is told why.
        """
        data = Data.get_all(data_set)
        if len(set(row.x for row in data)) < 2:
            data_set.send_progress_message(
                "Simple linear regression needs data with at least two different x values.")
            return

        lm = robjects.r['lm']
        fmla = robjects.Formula('y ~ x')
        # Bind x and y in the formula's environment so lm can resolve the
        # variable names used in 'y ~ x'.
        env = fmla.environment
        env['x'] = robjects.vectors.FloatVector([row.x for row in data])
        env['y'] = robjects.vectors.FloatVector([row.y for row in data])

        result = RegressionResult(lm(fmla))
        data_set.add(result)
        data_set.send_progress_message("Simple linear regression completed. Equation found: " + result.get_formatted_equation())


def load_example_1(data_set):
    """Add the 14 demo (x, y) observations to *data_set*."""
    demo_points = [
        (15, 7), (18, 3), (25, -2), (15.6, 9), (17, 6.3),
        (18, 9.5), (23, 0), (23, 1.5), (23.5, 1.25), (22, 5),
        (20, 7.2), (16, 8.7), (19, 8.7), (20, 4.5),
    ]
    data_set.add_all([Data(x=px, y=py) for px, py in demo_points])
