'''
Authors: www.tropofy.com and www.gurobi.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.
Copyright 2013, Gurobi Optimization, Inc.

This source file (where not indicated as under the copyright of Gurobi)
is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

Parts of the formulation provided by Gurobi have been modified.
The original example is in the Gurobi installation in the example file workforce4.py

Used with permission.

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.
'''

import gurobipy # Note eclipse has problems importing quicksum so we dont import selectively
from sqlalchemy.types import Integer, Text, Float
from sqlalchemy.schema import Column, ForeignKeyConstraint, UniqueConstraint

from tropofy.app import AppWithDataSets, Step, StepGroup
from tropofy.widgets import ExecuteFunction, SimpleGrid, Chart
from tropofy.database.tropofy_orm import DataSetMixin
from tropofy.database import read_write_xl


class ShiftRequirements(DataSetMixin):
    """Staffing demand: how many staff must work on a given day (shift).

    ``day`` is unique within a data set (see get_table_args). Availability and
    the ShiftsWorked* solution tables hold foreign keys to
    ``shiftrequirements.day`` -- presumably this model's table (table naming
    is done by DataSetMixin; confirm).
    """
    # Day / shift label; used as the shift identifier in the solver model.
    day = Column(Text, nullable=False)
    # Number of staff required on ``day``; this is the right-hand side of the
    # day's coverage constraint in formulate_and_solve_staff_rostering_problem.
    staff = Column(Integer, nullable=False)

    def __init__(self, day, staff):
        self.day = day
        self.staff = staff

    @classmethod
    def get_table_args(cls):
        # A day may appear at most once per data set. ``data_set_id`` is not
        # declared here; presumably it is supplied by DataSetMixin.
        return (UniqueConstraint('data_set_id', 'day'),)


class Staff(DataSetMixin):
    """A staff member who can be rostered, together with their pay rate.

    ``name`` is unique within a data set (see get_table_args). Availability
    and the ShiftsWorked* solution tables hold foreign keys to ``staff.name``
    -- presumably this model's table (table naming is done by DataSetMixin;
    confirm).
    """
    # Staff member's name; used as the worker identifier in the solver model.
    name = Column(Text, nullable=False)
    # Cost of rostering this person for one shift: the objective coefficient
    # of each of their assignment variables, and the amount added to the
    # reported total cost per shift they work.
    rate = Column(Float, nullable=False)

    def __init__(self, name, rate):
        self.name = name
        self.rate = rate

    @classmethod
    def get_table_args(cls):
        # A name may appear at most once per data set (``data_set_id`` is
        # presumably supplied by DataSetMixin).
        return (UniqueConstraint('data_set_id', 'name'),)


class Availability(DataSetMixin):
    """A day on which a named staff member is available to be rostered.

    formulate_and_solve_staff_rostering_problem creates a binary assignment
    variable for each (staff_name, day) row here; staff/day combinations
    without a row can never be assigned.
    """
    # Staff member's name (foreign key to staff.name within the data set).
    staff_name = Column(Text, nullable=False)
    # Day they are available (foreign key to shiftrequirements.day).
    day = Column(Text, nullable=False)

    def __init__(self, staff_name, day):
        self.staff_name = staff_name
        self.day = day

    @classmethod
    def get_table_args(cls):
        # Composite foreign keys scoped by data_set_id, declared with
        # ON DELETE / ON UPDATE CASCADE so that removing or renaming a staff
        # member or day is meant to carry through to these rows (enforcement
        # depends on the database backend).
        # NOTE(review): there is no uniqueness constraint on
        # (data_set_id, staff_name, day), so duplicate rows are possible.
        return (
            ForeignKeyConstraint(['staff_name', 'data_set_id'], ['staff.name', 'staff.data_set_id'], ondelete='CASCADE', onupdate='CASCADE'),
            ForeignKeyConstraint(['day', 'data_set_id'], ['shiftrequirements.day', 'shiftrequirements.data_set_id'], ondelete='CASCADE', onupdate='CASCADE')
        )


class ShiftsWorkedForLeastCost(DataSetMixin):
    """Solution table for the first optimisation phase.

    This phase minimises understaffing, then cost. Each row is one
    (staff member, day) assignment. The table is cleared and rewritten by
    formulate_and_solve_staff_rostering_problem each time that phase is
    solved.
    """
    # Staff member rostered (foreign key to staff.name within the data set).
    staff_name = Column(Text, nullable=False)
    # Day they are rostered on (foreign key to shiftrequirements.day).
    day = Column(Text, nullable=False)

    def __init__(self, staff_name, day):
        self.staff_name = staff_name
        self.day = day

    @classmethod
    def get_table_args(cls):
        # Same data-set-scoped, cascading foreign keys as Availability.
        return (
            ForeignKeyConstraint(['staff_name', 'data_set_id'], ['staff.name', 'staff.data_set_id'], ondelete='CASCADE', onupdate='CASCADE'),
            ForeignKeyConstraint(['day', 'data_set_id'], ['shiftrequirements.day', 'shiftrequirements.data_set_id'], ondelete='CASCADE', onupdate='CASCADE')
        )


class ShiftsWorkedForEvenWorkings(DataSetMixin):
    """Solution table for the second optimisation phase.

    This phase balances the number of shifts per staff member. Each row is
    one (staff member, day) assignment. The table is cleared and rewritten by
    formulate_and_solve_staff_rostering_problem each time that phase is
    solved.
    """
    # Staff member rostered (foreign key to staff.name within the data set).
    staff_name = Column(Text, nullable=False)
    # Day they are rostered on (foreign key to shiftrequirements.day).
    day = Column(Text, nullable=False)

    def __init__(self, staff_name, day):
        self.staff_name = staff_name
        self.day = day

    @classmethod
    def get_table_args(cls):
        # Same data-set-scoped, cascading foreign keys as Availability.
        return (
            ForeignKeyConstraint(['staff_name', 'data_set_id'], ['staff.name', 'staff.data_set_id'], ondelete='CASCADE', onupdate='CASCADE'),
            ForeignKeyConstraint(['day', 'data_set_id'], ['shiftrequirements.day', 'shiftrequirements.data_set_id'], ondelete='CASCADE', onupdate='CASCADE')
        )


class ExecuteGurobiSolver(ExecuteFunction):
    """Button widget that formulates and solves the rostering model."""

    def get_button_text(self):
        """Label displayed on the solve button."""
        label = "Solve Staff Rostering Problem"
        return label

    def execute_function(self, data_set):
        """Solve the rostering problem for ``data_set`` and store the rosters."""
        formulate_and_solve_staff_rostering_problem(data_set)


class RequirementsAndAvailabilityChart(Chart):
    """Column chart comparing staff required with staff available on each day.

    A day's shortfall is the number of required staff that cannot be met by
    the staff marked available for that day, floored at zero.
    """

    def get_chart_type(self, data_set):
        return Chart.COLUMNCHART

    def get_table_schema(self, data_set):
        return {
            "day": ("string", "Day"),
            "required": ("number", "Staff Required"),
            "available": ("number", "Staff Available"),
            "shortfall": ("number", "Shortfall")
        }

    @staticmethod
    def _count_available_by_day(data_set):
        """Return a dict mapping each day to its number of Availability rows.

        Availability is read with a single query instead of one (previously
        two) queries per ShiftRequirements row.
        """
        counts = {}
        for availability in data_set.query(Availability).all():
            counts[availability.day] = counts.get(availability.day, 0) + 1
        return counts

    def get_table_data(self, data_set):
        """Return one row per requirement with required, available and shortfall counts."""
        available_by_day = self._count_available_by_day(data_set)
        rows = []
        for requirement in data_set.query(ShiftRequirements).all():
            available = available_by_day.get(requirement.day, 0)
            rows.append({
                'day': requirement.day,
                'required': requirement.staff,
                'available': available,
                'shortfall': max(0, requirement.staff - available),
            })
        return rows

    def get_column_ordering(self, data_set):
        return ["day", "required", "available", "shortfall"]

    def get_chart_options(self, data_set):
        """Return the chart title, including the total shortfall when it is positive."""
        available_by_day = self._count_available_by_day(data_set)
        shortfall = sum(
            max(0, r.staff - available_by_day.get(r.day, 0))
            for r in data_set.query(ShiftRequirements).all()
        )
        if shortfall > 0:
            return {'title': 'Staff Required and Available with Shortfall - Shortfall of %s' % (shortfall)}
        return {'title': 'Staff Required and Available with Shortfall'}


class AvailabilityByStaffChart(Chart):
    """Column chart of the number of days each staff member is available."""

    def get_chart_type(self, data_set):
        return Chart.COLUMNCHART

    def get_table_schema(self, data_set):
        return {"name": ("string", "Name"), "available": ("number", "Availability over Period")}

    def get_table_data(self, data_set):
        """Return one row per staff member with their count of Availability rows.

        Availability is loaded with one query and counted in memory, rather
        than issuing a separate query for every staff member.
        """
        available_by_name = {}
        for availability in data_set.query(Availability).all():
            available_by_name[availability.staff_name] = available_by_name.get(availability.staff_name, 0) + 1
        return [{
            'name': s.name,
            'available': available_by_name.get(s.name, 0)
            } for s in data_set.query(Staff).all()]

    def get_column_ordering(self, data_set):
        return ["name", "available"]

    def get_chart_options(self, data_set):
        return {'title': 'Staff Availability over Period'}


class ShiftsWorkedPieChart(Chart):
    """Pie chart of shifts worked per staff member in one solution table.

    ``type_`` is the solution model to chart; in this app it is either
    ShiftsWorkedForLeastCost or ShiftsWorkedForEvenWorkings.
    """

    def __init__(self, type_):
        self.type_ = type_

    def get_chart_type(self, data_set):
        return Chart.PIECHART

    def get_table_schema(self, data_set):
        return {"name": ("string", "Name"), "shifts": ("number", "Shifts Worked")}

    def get_table_data(self, data_set):
        """Return one row per staff member with the number of shifts they work.

        The solution table is read once and counted in memory, rather than
        queried separately for each staff member.
        """
        shifts_by_name = {}
        for shift in data_set.query(self.type_).all():
            shifts_by_name[shift.staff_name] = shifts_by_name.get(shift.staff_name, 0) + 1
        return [{
            'name': s.name,
            'shifts': shifts_by_name.get(s.name, 0)
            } for s in data_set.query(Staff).all()]

    def get_column_ordering(self, data_set):
        return ["name", "shifts"]

    def get_chart_options(self, data_set):
        """Return the title (total pay for the charted roster) and slice labelling."""
        rates = {s.name: s.rate for s in data_set.query(Staff).all()}
        total_cost = sum(rates[s.staff_name] for s in data_set.query(self.type_).all())
        return {
            'title': 'Shifts Worked, Total Cost = %s' % (total_cost),
            'pieSliceText': 'value'
        }


class Roster(Chart):
    """Timeline chart showing, for every staff member and day, whether they work.

    ``type_`` is the solution model whose rows mark the shifts worked. Days are
    laid out in the order the ShiftRequirements query returns them, each in a
    fixed-width slot on the numeric time axis.
    """

    def __init__(self, type_):
        self.type_ = type_

    def get_chart_type(self, data_set):
        return Chart.TIMELINE

    def get_table_schema(self, data_set):
        return {"name": ("string", "Name"),
                "state": ("string", "Shifts Worked"),
                "start": ("number", "Start"),
                "end": ("number", "End")}

    def get_table_data(self, data_set):
        """Return an 'on'/'off' bar for every (day, staff member), grouped by day.

        The solution table is fetched once into a set, instead of being queried
        for every (staff, day) cell. Each day's slot comes from enumerate()
        rather than a linear ``days.index`` lookup. The two are equivalent
        because days are unique per data set (ShiftRequirements constraint).
        """
        slot_width = 1000  # width of one day's slot on the timeline axis
        days = [r.day for r in data_set.query(ShiftRequirements).all()]
        staff = [s.name for s in data_set.query(Staff).all()]
        worked = set((row.staff_name, row.day) for row in data_set.query(self.type_).all())
        data = []
        for slot, day in enumerate(days):
            for name in staff:
                data.append({
                    'name': name,
                    'state': 'on' if (name, day) in worked else 'off',
                    'start': slot * slot_width,
                    'end': (slot + 1) * slot_width,
                })
        return data

    def get_column_ordering(self, data_set):
        return ["name", "state", "start", "end"]

    def get_chart_options(self, data_set):
        # NOTE: the original author reports that none of the tooltip-related
        # options below actually turn tooltips off in the rendered timeline.
        return {
            'title': 'Roster',
            'tooltip.isHtml': 'true',
            'showRowLabels': False,
            'trigger': 'none',
            'enableInteractivity': False,
            'tooltip': {'trigger': 'none', 'enabled': 'false', 'isHtml': 'true'},
            'series': [{'color': '#FFB82C'}, {'color': '#CB0F32'}]
        }


class GurobiRosteringApp(AppWithDataSets):
    """Tropofy app: enter rostering data, solve it with Gurobi, and view the rosters."""

    def get_name(self):
        return 'Staff Rostering Optimiser'

    def get_examples(self):
        return {"Demo data set from Gurobi": load_gurobi_data}

    def get_gui(self):
        """Return the step groups for data entry, solving, and viewing solutions."""
        step_group1 = StepGroup(name='Enter your data')
        step_group1.add_step(Step(name='Enter your shift requirements', widgets=[SimpleGrid(ShiftRequirements)]))
        step_group1.add_step(Step(name='Enter your staff names and rates', widgets=[SimpleGrid(Staff)]))
        step_group1.add_step(Step(name='Enter your staff availability', widgets=[SimpleGrid(Availability)]))
        step_group1.add_step(Step(name='Review your input data', widgets=[RequirementsAndAvailabilityChart(), AvailabilityByStaffChart()]))

        step_group2 = StepGroup(name='Solve')
        step_group2.add_step(Step(name='Solve staff rostering problem using Gurobi', widgets=[ExecuteGurobiSolver()]))

        # Both solution tables share one layout: the grid and the pie chart
        # side by side (cols 6 + 6), followed by the roster timeline (cols 12).
        step_group3 = StepGroup(name='View the Solution')
        for step_name, solution_table in (
                ('Shifts worked for min cost', ShiftsWorkedForLeastCost),
                ('Shifts worked for even workings', ShiftsWorkedForEvenWorkings)):
            step_group3.add_step(Step(
                name=step_name,
                widgets=[
                    {"widget": SimpleGrid(solution_table), "cols": 6},
                    {"widget": ShiftsWorkedPieChart(solution_table), "cols": 6},
                    {"widget": Roster(solution_table), "cols": 12},
                ]
            ))

        return [step_group1, step_group2, step_group3]

    def get_icon_url(self):
        return 'http://www.tropofy.com/static/css/img/tropofy_example_app_icons/staff_rostering.png'

    def get_home_page_content(self):
        """Return the HTML fragments shown on the app's home page."""
        return {
            'content_app_name_header': '''
            <div>
            <span style="vertical-align: middle;">Staff Rostering Optimiser</span>
            <img src="http://www.tropofy.com/static/css/img/tropofy_example_app_icons/staff_rostering.png" alt="main logo" style="width:15%">
            </div>''',

            'content_single_column_app_description': '''
            <p>Do you have to roster staff, with specific availability to cover a given workload and want the minimum cost solution?</p>
            <p>This app might be what you are looking for! Sign up and give it a go.</p>
            <p>Need help or wish this app had more features, contact us at <b>info@tropofy.com</b> to see if we can help</p>''',

            'content_row_4_col_1_content': '''
            This app was created using the <a href="http://www.tropofy.com" target="_blank">Tropofy platform</a> and is powered by <a href="http://www.gurobi.com" target="_blank">Gurobi</a>.
            '''
        }


def load_gurobi_data(data_set):
    """Fill ``data_set`` from the Gurobi demo workbook shipped in the app folder."""
    workbook_path = data_set.app.get_path_of_file_in_app_folder('gurobi_rostering_data.xlsx')
    read_write_xl.load_data_from_excel_file_on_disk(data_set, workbook_path)


def formulate_and_solve_staff_rostering_problem(data_set):
    """Build and solve the staff rostering model for ``data_set`` in two phases.

    Workers (Staff) are assigned to shifts (ShiftRequirements days). A worker
    can only be assigned to days listed for them in Availability.

    Phase 1 minimises the pay of the assigned workers plus a penalty for each
    unit of understaffing. The penalty equals the sum of all staff rates, so
    leaving a shift uncovered is never cheaper than covering it. The chosen
    assignments are written to ShiftsWorkedForLeastCost.

    Phase 2 fixes the total understaffing at its phase-1 value. It then
    minimises the sum of squared differences between each worker's shift
    count and the average shift count, writing the result to
    ShiftsWorkedForEvenWorkings.

    Progress is reported through ``data_set.send_progress_message``. If a phase
    finds no solution at all (e.g. the model is infeasible because a
    requirement is negative), its previous results are cleared and solving
    stops. Phase 2's results are also cleared when phase 1 fails.
    """
    # Copyright 2013, Gurobi Optimization, Inc.
    # Adapted by Tropofy Pty Ltd to integrate with the Tropofy Platform

    shifts = []
    shiftRequirements = {}
    for demand in data_set.query(ShiftRequirements).all():
        shifts.append(demand.day)
        shiftRequirements[demand.day] = demand.staff

    workers = []
    pay = {}
    for s in data_set.query(Staff).all():
        workers.append(s.name)
        pay[s.name] = s.rate

    # Per-unit understaffing penalty (see docstring).
    sum_of_pay_for_slack_cost = sum(pay.values())

    # (worker, day) pairs that may be assigned. Duplicate Availability rows are
    # dropped: each pair must map to exactly one variable, otherwise
    # tuplelist.select() would count that variable twice in the constraints.
    availability_pairs = []
    seen_pairs = set()
    for a in data_set.query(Availability).all():
        pair = (a.staff_name, a.day)
        if pair not in seen_pairs:
            seen_pairs.add(pair)
            availability_pairs.append(pair)
    availability = gurobipy.tuplelist(availability_pairs)

    # Model
    m = gurobipy.Model("assignment")

    # Assignment variables: x[w,s] == 1 if worker w is assigned to shift s.
    # Each carries the worker's pay rate as its objective coefficient.
    x = {}
    for w, s in availability:
        x[w, s] = m.addVar(vtype=gurobipy.GRB.BINARY, obj=pay[w], name=w + "." + s)

    # Slack (understaffing) variable per shift, so every shift constraint can
    # be satisfied even when too few workers are available.
    slacks = {}
    for s in shifts:
        slacks[s] = m.addVar(obj=sum_of_pay_for_slack_cost, name=s + "Slack")

    # Variable to represent the total slack
    totSlack = m.addVar(name="totSlack")

    # Variables to count the total shifts worked by each worker
    totShifts = {}
    for w in workers:
        totShifts[w] = m.addVar(name=w + "TotShifts")

    # Update model to integrate new variables
    m.update()

    # Constraint: assigned workers plus slack equal the requirement of each shift.
    for s in shifts:
        m.addConstr(slacks[s] + gurobipy.quicksum(x[worker, day] for worker, day in availability.select('*', s)) == shiftRequirements[s], s)

    # Constraint: set totSlack equal to the total slack
    m.addConstr(totSlack == gurobipy.quicksum(slacks[s] for s in shifts), "totSlack")

    # Constraint: compute the total number of shifts for each worker
    for w in workers:
        m.addConstr(totShifts[w] == gurobipy.quicksum(x[worker, day] for worker, day in availability.select(w, '*')), "totShifts" + w)

    # Unlike Gurobi's workforce4 example, totSlack is NOT made the sole
    # objective here. The per-variable ``obj`` coefficients set above (pay and
    # the slack penalty) remain, so cost is minimised as a secondary goal to
    # covering the shifts.

    def solveAndPrint(msg, type_, data_set):
        """Optimise ``m``, report the result, and store assignments as ``type_`` rows.

        Returns True when a solution was available and recorded, else False.
        """
        data_set.send_progress_message(msg)
        m.optimize()
        status = m.status
        if status == gurobipy.GRB.status.INF_OR_UNBD or status == gurobipy.GRB.status.INFEASIBLE or status == gurobipy.GRB.status.UNBOUNDED:
            data_set.send_progress_message('The model cannot be solved because it is infeasible or unbounded')

        if status != gurobipy.GRB.status.OPTIMAL:
            data_set.send_progress_message('Optimization was stopped with status %s' % (str(status)))

        # Results from an earlier run no longer describe the current data, so
        # clear them whether or not this solve succeeds.
        data_set.query(type_).delete()

        if m.SolCount == 0:
            # No incumbent solution: reading Var.x would raise a GurobiError.
            data_set.send_progress_message('No solution is available, so no roster has been produced')
            return False

        # Print total slack and the number of shifts worked for each worker
        data_set.send_progress_message('Total number of shifts not covered : %s' % (str(totSlack.x)))
        for w in workers:
            data_set.send_progress_message('%s worked %s shifts' % (str(w), str(totShifts[w].x)))

        cost = 0
        for worker, day in availability:
            # Binary values are only integral up to the integrality tolerance
            # (e.g. 0.99999), so threshold at 0.5 instead of testing >= 1.
            if x[worker, day].x > 0.5:
                cost += pay[worker]
                data_set.add(type_(worker, day))
        data_set.send_progress_message('Total cost = %s' % (str(cost)))
        return True

    if not solveAndPrint('<br>Minimising the total amount of understaffing then cost', ShiftsWorkedForLeastCost, data_set):
        # Phase 2 needs phase 1's slack value; its old results are stale too.
        data_set.query(ShiftsWorkedForEvenWorkings).delete()
        return

    # Fix the total slack at its phase-1 optimum for the balancing phase.
    totSlack.ub = totSlack.x
    totSlack.lb = totSlack.x

    # Variable to count the average number of shifts worked
    avgShifts = m.addVar(name="avgShifts")

    # Variables to count the difference from average for each worker;
    # note that these variables can take negative values.
    diffShifts = {}
    for w in workers:
        diffShifts[w] = m.addVar(lb=-gurobipy.GRB.INFINITY, ub=gurobipy.GRB.INFINITY, name=w + "Diff")

    # Update model to integrate new variables
    m.update()

    # Constraint: compute the average number of shifts worked
    m.addConstr(len(workers) * avgShifts == gurobipy.quicksum(totShifts[w] for w in workers), "avgShifts")

    # Constraint: compute the difference from the average number of shifts
    for w in workers:
        m.addConstr(diffShifts[w] == totShifts[w] - avgShifts, w + "Diff")

    # Objective: minimize the sum of the square of the difference from the
    # average number of shifts worked (this replaces the pay objective).
    m.setObjective(gurobipy.quicksum(diffShifts[w] * diffShifts[w] for w in workers))

    # Optimize
    solveAndPrint('<br>Minimising the square of the differences from the average number of shifts worked', ShiftsWorkedForEvenWorkings, data_set)
