"""
Authors: www.tropofy.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.

This source file is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

The LocalSolver example this app is based on can be found at
http://www.localsolver.com/exampletour.html?file=maxcut.zip

Used with permission.

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.


Problem Description:
"MAX CUT

Given a graph G=(V,E), a cut is a partition of V into two subsets S and V-S. The size
of a cut is the number of edges with one extremity in S and the other in V-S. The
Max Cut problem is to find a cut of maximum size.

Furthermore, a MAX WEIGHTED CUT is implemented, for when not all weights equal 1.
In this version each edge has a real number, its weight, and the objective is to maximize
not the number of edges but the total weight of the edges in the cut."
[http://www.localsolver.com/exampletour.html?file=maxcut.zip]
"""

import subprocess
from sqlalchemy.types import Text, Float, Boolean, Integer
from sqlalchemy.schema import Column, UniqueConstraint, ForeignKeyConstraint
from sqlalchemy.orm import relationship
from simplekml import Kml, Style, IconStyle, Icon, LineStyle

from tropofy.app import AppWithDataSets, Step, StepGroup
from tropofy.widgets import ExecuteFunction, SimpleGrid, KMLMap, Chart
from tropofy.database.tropofy_orm import DataSetMixin
from tropofy.file_io import read_write_xl
import os


class MaxWCutNode(DataSetMixin):
    """A graph node (location) in a Max Weighted Cut data set.

    allocated_group_index holds the solver's group assignment: True or False for
    the two sides of the cut, None when no solution is currently stored.
    """
    name = Column(Text, nullable=False)
    latitude = Column(Float, nullable=False)
    longitude = Column(Float, nullable=False)
    allocated_group_index = Column(Boolean)  # Only 2 possible groups, thus boolean.

    @classmethod
    def get_table_args(cls):
        """Node names are unique within a data set (arcs reference nodes by name)."""
        unique_name_per_data_set = UniqueConstraint('data_set_id', 'name')
        return (unique_name_per_data_set,)

    @classmethod
    def get_ordered_list_of_all_nodes(cls, data_set):
        """Return every node of data_set, sorted by primary key id.

        This ordering defines the 1-based node indices written for LocalSolver and
        the order in which solution lines are mapped back onto nodes, so both sides
        must use this same method.
        """
        nodes_by_id = data_set.query(cls).order_by(cls.id)
        return nodes_by_id.all()


class MaxWCutArc(DataSetMixin):
    """A weighted, undirected edge between two MaxWCutNode rows of the same data set.

    Endpoints are referenced by (data_set_id, name) through composite foreign keys
    that cascade on node delete/update.
    """
    orig_name = Column(Text, nullable=False)
    dest_name = Column(Text, nullable=False)
    weight = Column(Integer, nullable=False)  # Must be int for Localsolver

    # The primaryjoin argument to relationship is only needed when there is ambiguity
    orig_node = relationship(MaxWCutNode, primaryjoin="and_(MaxWCutArc.data_set_id==MaxWCutNode.data_set_id, MaxWCutArc.orig_name==MaxWCutNode.name)")
    dest_node = relationship(MaxWCutNode, primaryjoin="and_(MaxWCutArc.data_set_id==MaxWCutNode.data_set_id, MaxWCutArc.dest_name==MaxWCutNode.name)")

    @classmethod
    def get_table_args(cls):
        """One arc per (orig, dest) pair per data set; both ends must be existing nodes."""
        fk_options = dict(ondelete='CASCADE', onupdate='CASCADE')
        return (
            UniqueConstraint('data_set_id', 'orig_name', 'dest_name'),
            ForeignKeyConstraint(['orig_name', 'data_set_id'], ['maxwcutnode.name', 'maxwcutnode.data_set_id'], **fk_options),
            ForeignKeyConstraint(['dest_name', 'data_set_id'], ['maxwcutnode.name', 'maxwcutnode.data_set_id'], **fk_options),
        )

    @property
    def in_cut(self):
        """True if the endpoints are in different groups, False if in the same group.

        Returns None while either endpoint has no allocated group (no solution yet).
        """
        orig_group = self.orig_node.allocated_group_index
        if orig_group is None:
            return None
        dest_group = self.dest_node.allocated_group_index
        if dest_group is None:
            return None
        return orig_group != dest_group

    @property
    def html_description(self):
        """Label used on map features: endpoint names, a <br> break, then the weight."""
        return "{0}, {1}<br>Weight: {2}".format(self.orig_name, self.dest_name, self.weight)


class MapArcInput(KMLMap):
    """KML map of the raw input: every node and every arc, with no solution styling."""

    def get_kml(self, data_set):
        """Return the KML document text for all nodes and arcs of data_set.

        Nodes go into a "Nodes" folder and arcs into an "Arcs" folder. Each arc is
        drawn as a straight line between its endpoints' (longitude, latitude).
        """
        kml = Kml()

        node_style = Style(iconstyle=IconStyle(scale=0.8, icon=Icon(href='http://maps.google.com/mapfiles/kml/paddle/blu-circle-lv.png')))
        node_folder = kml.newfolder(name="Nodes")
        for node in data_set.query(MaxWCutNode).all():
            point = node_folder.newpoint(name=node.name, coords=[(node.longitude, node.latitude)])
            point.style = node_style

        arc_folder = kml.newfolder(name="Arcs")
        # Every arc looks the same, so build the style once and share it, rather than
        # constructing an identical Style object per arc inside the loop.
        arc_style = Style(linestyle=LineStyle(color='FFC86602', width=4))
        for arc in data_set.query(MaxWCutArc).all():
            line = arc_folder.newlinestring(
                name=arc.html_description,
                coords=[
                    (arc.orig_node.longitude, arc.orig_node.latitude),
                    (arc.dest_node.longitude, arc.dest_node.latitude),
                ],
            )
            line.style = arc_style

        return kml.kml()


class MapArcOutput(KMLMap):
    """KML map of a solution: nodes split by allocated group, arcs split by cut membership."""

    # (folder name, icon scale, icon URL, allocated_group_index value) per node group.
    _NODE_GROUPS = (
        ("Nodes - Group 1", 0.8, 'http://maps.google.com/mapfiles/kml/paddle/blu-circle-lv.png', True),
        ("Nodes - Group 2", 0.4, 'http://maps.google.com/mapfiles/kml/paddle/red-circle-lv.png', False),
    )

    def get_kml(self, data_set):
        """Return the KML document text for the current solution of data_set.

        Nodes whose allocated_group_index is None are not drawn. Arcs whose in_cut
        is falsy, including arcs with an unallocated endpoint, are placed in the
        "Arcs not in Cut" folder.
        """
        kml = Kml()

        for folder_name, icon_scale, icon_href, group_flag in self._NODE_GROUPS:
            group_style = Style(iconstyle=IconStyle(scale=icon_scale, icon=Icon(href=icon_href)))
            group_folder = kml.newfolder(name=folder_name)
            group_nodes = data_set.query(MaxWCutNode).filter_by(allocated_group_index=group_flag).all()
            for node in group_nodes:
                placemark = group_folder.newpoint(name=node.name, coords=[(node.longitude, node.latitude)])
                placemark.style = group_style

        cut_style = Style(linestyle=LineStyle(color='FF0101F9', width=4))
        uncut_style = Style(linestyle=LineStyle(color='FFC86602', width=3))

        cut_folder = kml.newfolder(name="Arcs in Cut")
        uncut_folder = kml.newfolder(name="Arcs not in Cut")

        for arc in data_set.query(MaxWCutArc).all():
            if arc.in_cut:
                target_folder, target_style = cut_folder, cut_style
            else:
                target_folder, target_style = uncut_folder, uncut_style
            label = arc.html_description
            endpoints = [
                (arc.orig_node.longitude, arc.orig_node.latitude),
                (arc.dest_node.longitude, arc.dest_node.latitude),
            ]
            segment = target_folder.newlinestring(name=label, coords=endpoints)
            segment.style = target_style
        return kml.kml()


class WeightInCutPieChart(Chart):
    """Pie chart comparing arc weight inside the current cut with the remaining arc weight.

    Arcs whose in_cut is falsy, including arcs with an unallocated endpoint,
    count as "Not In Cut".
    """

    def get_chart_type(self, data_set):
        return Chart.PIECHART

    def get_table_schema(self, data_set):
        return {"item": ("string", "Item"), "weight": ("number", "Weight")}

    def get_table_data(self, data_set):
        """Return one row for the weight in the cut and one for the rest."""
        arc_weight_total, arc_weight_in_cut = self._get_arc_weight_totals(data_set)
        return [
            {"item": "In Cut", "weight": arc_weight_in_cut},
            {"item": "Not In Cut", "weight": arc_weight_total - arc_weight_in_cut},
        ]

    def get_column_ordering(self, data_set):
        return ["item", "weight"]

    def get_order_by_column(self, data_set):
        return "weight"

    def get_chart_options(self, data_set):
        """Return the chart title showing the percentage and the absolute weight in the cut.

        When the total arc weight is zero (e.g. a data set with no arcs), the
        percentage is shown as "N/A" instead of raising ZeroDivisionError.
        """
        arc_weight_total, arc_weight_in_cut = self._get_arc_weight_totals(data_set)
        if arc_weight_total:
            percent = "{0:.0f}%".format(float(arc_weight_in_cut) / arc_weight_total * 100)
        else:
            percent = "N/A"
        return {'title': 'Weight in Cut: {percent} ({weight} / {max})'.format(
            percent=percent,
            weight=arc_weight_in_cut,
            max=arc_weight_total,
        )}

    @staticmethod
    def _get_arc_weight_totals(data_set):
        """Return (total weight of all arcs, total weight of arcs currently in the cut)."""
        arcs = data_set.query(MaxWCutArc).all()
        total = sum(arc.weight for arc in arcs)
        in_cut = sum(arc.weight for arc in arcs if arc.in_cut)
        return total, in_cut


class ExecuteLocalSolver(ExecuteFunction):
    """Button widget that solves the current data set with LocalSolver."""

    def get_button_text(self):
        """Label shown on the solve button."""
        label = "Solve Max Weighted Cut"
        return label

    def execute_function(self, data_set):
        """Export data_set, run LocalSolver, and store the group allocation on its nodes."""
        call_local_solver(data_set)


class LocalSolverMaxWeightedCut(AppWithDataSets):
    """Tropofy app definition: input/solve/output steps, examples, and home page content."""

    def get_name(self):
        return 'LocalSolver Max Weighted Cut'

    def get_gui(self):
        """Return the three step groups: Input (nodes, arcs, review map), Solve, and Output."""
        step_group1 = StepGroup(name='Input')
        step_group1.add_step(Step(
            name='Nodes',
            widgets=[SimpleGrid(MaxWCutNode)],
        ))
        step_group1.add_step(Step(
            name='Arcs',
            widgets=[SimpleGrid(MaxWCutArc)],
            help_text="Enter arcs connecting locations. Note that the direction of each arc does not matter. Furthermore, to solve Max Cut instead of Max Weighted Cut, enter all weights as 1.",
        ))
        step_group1.add_step(Step(name='Review Inputs', widgets=[MapArcInput()]))

        step_group2 = StepGroup(name='Solve')
        step_group2.add_step(Step(name='Solve Max Weighted cut using LocalSolver', widgets=[ExecuteLocalSolver()]))

        step_group3 = StepGroup(name='Output')
        step_group3.add_step(Step(name='Max Cut Solution', widgets=[
            {"widget": MapArcOutput(), "cols": 12},
            {"widget": WeightInCutPieChart(), "cols": 5},
            {"widget": SimpleGrid(MaxWCutNode), "cols": 7},
        ]))

        return [step_group1, step_group2, step_group3]

    def get_examples(self):
        """Map example names to the loader functions that populate a data set."""
        return {
            "Demo European data": load_euro_data,
            "Demo Brisbane data": load_brisbane_data,
        }

    def get_default_example_data_set_name(self):
        return "Demo European data"

    def get_icon_url(self):
        return 'http://www.tropofy.com/static/css/img/tropofy_example_app_icons/max_cut.png'

    def get_home_page_content(self):
        """Return the HTML fragments for the app's home page."""
        # A <ul> may not appear inside a <p> (browsers implicitly close the <p>),
        # so the list sits between two paragraphs.
        return {
            'content_app_name_header': '''
            <div>
            <span style="vertical-align: middle;">Max Weighted Cut</span>
            <img src="http://www.tropofy.com/static/css/img/tropofy_example_app_icons/max_cut.png" alt="main logo" style="width:15%">
            </div>''',

            'content_single_column_app_description': '''
            <p>This app solves the Max Weighted Cut problem, in a geographic setting. It is defined as follows:</p>
            <ul>
                <li>Given a set of locations (nodes)</li>
                <li>And a set of numerical relationships between some of these locations (weighted arcs).</li>
                <li>Allocate the locations into two groups, such that the sum of the relationships between locations in different groups is maximised.</li>
            </ul>
            <p>The Max Weighted Cut problem is known to be NP-Complete.</p>
            <p>Need help or wish this app had more features? Contact us at <b>info@tropofy.com</b> to see if we can help.</p>''',

            'content_row_4_col_1_content': '''
            This app was created using the <a href="http://www.tropofy.com" target="_blank">Tropofy platform</a> and is powered by <a href="http://www.localsolver.com/" target="_blank">LocalSolver</a>.
            '''
        }


def call_local_solver(data_set):
    """Write the LocalSolver input file for data_set, then run the solver on it."""
    written_input_name = write_localsolver_input_file(data_set)
    invoke_localsolver_using_lsp_file(data_set, written_input_name)


def write_localsolver_input_file(data_set):
    """Write the graph of data_set to 'input.in' in its data set folder for LocalSolver.

    File format: a header line "<node count> <arc count>", then one line per arc,
    "<orig index> <dest index> <weight>". Node indices are 1-based positions in
    MaxWCutNode.get_ordered_list_of_all_nodes order.

    :param data_set: the Tropofy data set to export.
    :return: the bare file name 'input.in' (not the full path). The caller passes it
        to LocalSolver, which is run with cwd=data_set.file_save_folder.
    """
    input_file_name = 'input.in'
    input_file_path = data_set.get_path_of_file_in_data_set_folder(input_file_name)

    # Gather everything before touching the file, so a failing query does not leave
    # a truncated input file behind.
    nodes = MaxWCutNode.get_ordered_list_of_all_nodes(data_set)
    node_name_to_index = {node.name: i + 1 for i, node in enumerate(nodes)}  # This LocalSolver example indexes from 1
    arcs = data_set.query(MaxWCutArc).all()

    # `with` closes the file even if a write fails part way through.
    with open(input_file_path, 'w') as f:
        f.write('%s %s\n' % (len(nodes), len(arcs)))
        for arc in arcs:
            f.write('{orig} {dest} {weight}\n'.format(
                orig=node_name_to_index[arc.orig_name],
                dest=node_name_to_index[arc.dest_name],
                weight=arc.weight,
            ))

    return input_file_name


def invoke_localsolver_using_lsp_file(data_set, input_file_name):
    """Run LocalSolver on the app's max_cut.lsp model and store the node allocation.

    Any previous allocation is cleared first. LocalSolver is run in
    data_set.file_save_folder with a 2 second time limit. It reads input_file_name
    and writes 'output.txt'. If the solution file is non-empty, solver stdout is
    forwarded as a progress message. Then line i of the solution is parsed as
    "<index> <value>", and its value becomes the group (bool) of the i-th node in
    get_ordered_list_of_all_nodes order. If the solution file is empty, a message
    about trial-version limits is sent instead.

    :param data_set: the Tropofy data set being solved.
    :param input_file_name: input file name relative to the data set folder, as
        returned by write_localsolver_input_file.
    """
    nodes = MaxWCutNode.get_ordered_list_of_all_nodes(data_set)
    for node in nodes:
        node.allocated_group_index = None  # Reset solution

    lsp_file_path = data_set.app.get_path_of_file_in_app_folder('max_cut.lsp')
    solution_file_name = 'output.txt'
    solution_file_path = data_set.get_path_of_file_in_data_set_folder(solution_file_name)
    open(solution_file_path, 'w').close()  # clear the solution file if it exists
    process = subprocess.Popen(
        ["localsolver", lsp_file_path, "inFileName=%s" % input_file_name, "solFileName=%s" % solution_file_name, "lsTimeLimit=2"],
        cwd=data_set.file_save_folder,
        stdout=subprocess.PIPE,
        # Decode stdout to text: on Python 3 it would otherwise be bytes, and the
        # str.replace below would raise TypeError.
        universal_newlines=True,
    )
    out, _ = process.communicate()

    with open(solution_file_path) as solution_file:
        content = solution_file.readlines()

    if not content:
        data_set.send_progress_message(
            '''The data you have entered exceeds the limits of the trial version of LocalSolver used to run this app.
                LocalSolver's Trial Version does not allow more than 1000 expressions and 100 decisions.'''
        )
        return

    data_set.send_progress_message(out.replace("\n", "<br>"))
    for i, node in enumerate(nodes):
        # Store a real bool in the Boolean column. The raw token (e.g. "0\n") is a
        # str, and any non-empty str is truthy. Assumes the value is written as an
        # integer (0/1) — TODO confirm against max_cut.lsp's output function.
        node.allocated_group_index = bool(int(content[i].split()[1]))


def _load_example_workbook(data_set, file_name):
    """Load an example Excel workbook shipped in the app folder into data_set."""
    read_write_xl.ExcelReader.load_data_from_excel_file_on_disk(data_set, data_set.app.get_path_of_file_in_app_folder(file_name))


def load_brisbane_data(data_set):
    """Load the Brisbane example data set.

    Post code geocode data sourced from
    http://blog.orite.com.au/wp-content/uploads/2009/01/aupcgeo.7z
    """
    _load_example_workbook(data_set, 'max_cut_brisbane_example_data.xlsx')


def load_euro_data(data_set):
    """Load the European example data set."""
    # NOTE(review): this loader previously repeated the Brisbane loader's comment
    # citing an Australian post code geocode source, which looks like a copy-paste.
    # The actual source of the European data is not recorded here — TODO confirm.
    _load_example_workbook(data_set, 'max_cut_euro_example_data.xlsx')
