"""
Author:      www.tropofy.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.

This source file is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.
"""

from tropofy.app import AppWithDataSets, Parameter, Step, StepGroup
from tropofy.widgets import StaticImage, ExecuteFunction, ParameterForm
import io
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pylab as pl


class OutputPlot(StaticImage):
    """Static image widget that displays the data set's 'output.png' image.

    The image itself is written by ``CreatePlot.execute_function``.
    """

    def get_file_path(self, data_set):
        """Return the path of the image named 'output.png' stored on *data_set*."""
        image_path = data_set.get_image_path('output.png')
        return image_path


def validate_value_g_zero(value):
    """Parameter validator requiring a strictly positive value.

    Returns ``True`` when ``value > 0``; otherwise returns the error message
    string (so NaN, zero and negatives are all rejected).
    """
    is_positive = value > 0
    if is_positive:
        return True
    return "Value must be > 0."


class MyApp(AppWithDataSets):
    """Tropofy app that plots the surface Z = sin(sqrt(x^2 + y^2)).

    Users set the mesh spacing and x/y ranges as parameters, trigger the plot
    with the ``CreatePlot`` widget, and view the result via ``OutputPlot``.
    """

    def get_name(self):
        """Return the app's display name."""
        return "Plotting in 3d"

    def get_gui(self):
        """Return the step groups: parameter entry + plot creation, then output."""
        return [
            StepGroup(
                name='Input',
                steps=[
                    Step(name='Parameters', widgets=[
                        {"widget": ParameterForm(), "cols": 6},
                    ]),
                    Step(name='Create Plot', widgets=[CreatePlot()]),
                ]
            ),
            StepGroup(
                name='Output',
                steps=[
                    Step(name='Output', widgets=[OutputPlot()])
                ]
            ),
        ]

    def get_parameters(self):
        """Return the user-editable parameters read by ``CreatePlot``.

        Only ``mesh_size`` has a per-parameter validator; the ordering of the
        min/max pairs is checked when the plot is created.
        """
        return [
            Parameter(name='mesh_size', label='Mesh Size', default=0.25, allowed_type=float, validator=validate_value_g_zero),
            Parameter(name='x_min', label='X-Min', default=-4, allowed_type=float),
            Parameter(name='x_max', label='X-Max', default=4, allowed_type=float),
            Parameter(name='y_min', label='Y-Min', default=-4, allowed_type=float),
            Parameter(name='y_max', label='Y-Max', default=4, allowed_type=float),
        ]

    def get_examples(self):
        """Return the example data sets, keyed by display name."""
        return {"Demo data set": load_example_data}

    def get_icon_url(self):
        """Return the URL of the app's icon image."""
        return 'http://www.tropofy.com/static/css/img/tropofy_example_app_icons/histogram_3d.png'

    def get_home_page_content(self):
        """Return the HTML fragments shown on the app's home page."""
        return {
            'content_app_name_header': '''
            <div>
            <span style="vertical-align: middle;">Plotting in 3D</span>
            <img src="http://www.tropofy.com/static/css/img/tropofy_example_app_icons/histogram_3d.png" alt="main logo" style="width:15%">
            </div>''',

            'content_single_column_app_description': '''

            <p>Into scientific computing? Dying to see that 3D plot of Z=sin(sqrt(x^2+y^2)) you've always heard about? Need the power to vary up the mesh? What about those x and y ranges?</p>
            <p>Big mesh, little mesh, we've got them all!</p>

            <p>Check out how Tropofy integrates with <a href="http://www.scipy.org/">SciPy</a>, <a href="http://www.numpy.org/">NumPy</a> and <a href="http://matplotlib.org/">Matplotlib</a>
            in this simple app.</p>
            ''',

            'content_row_4_col_1_content': '''
            This app was created using the <a href="http://www.tropofy.com" target="_blank">Tropofy platform</a>.
            '''
        }


class CreatePlot(ExecuteFunction):
    """Button widget that renders Z = sin(sqrt(x^2 + y^2)) and saves it as 'output.png'.

    The x/y ranges and grid spacing come from the data set parameters
    ``x_min``, ``x_max``, ``y_min``, ``y_max`` and ``mesh_size``. The rendered
    PNG is stored on the data set for the ``OutputPlot`` widget to display.
    """

    def get_button_text(self):
        """Return the label shown on the execute button."""
        return "Create plot of Z = sin(sqrt(x^2 + y^2))"

    @staticmethod
    def _compute_surface(x_min, x_max, y_min, y_max, mesh_size):
        """Return meshgrid arrays (X, Y, Z) with Z = sin(sqrt(X**2 + Y**2)).

        Each axis is sampled with ``np.arange``: it starts at the minimum, steps
        by ``mesh_size`` and excludes the maximum. An inverted range therefore
        yields an empty axis.
        """
        xs = np.arange(x_min, x_max, mesh_size)
        ys = np.arange(y_min, y_max, mesh_size)
        X, Y = np.meshgrid(xs, ys)
        Z = np.sin(np.sqrt(X ** 2 + Y ** 2))
        return X, Y, Z

    def execute_function(self, data_set):
        """Render the surface plot and save it on *data_set* as 'output.png'.

        If the parameters produce fewer than two sample points along either
        axis, no image is created and a progress message explains why.
        """
        X, Y, Z = self._compute_surface(
            data_set.get_param('x_min'),
            data_set.get_param('x_max'),
            data_set.get_param('y_min'),
            data_set.get_param('y_max'),
            data_set.get_param('mesh_size'),
        )

        # A surface needs at least a 2x2 grid. An inverted range, or one no
        # wider than mesh_size, would otherwise fail inside plot_surface.
        if min(X.shape) < 2:
            data_set.send_progress_message(
                "Not enough points to plot: make sure X-Min < X-Max and Y-Min < Y-Max, "
                "and that each range spans more than one Mesh Size step."
            )
            return

        fig = pl.figure()
        try:
            # add_subplot(projection='3d') attaches the 3D axes to the figure.
            # The '3d' projection is registered by the mpl_toolkits.mplot3d
            # import. Calling Axes3D(fig) directly no longer adds the axes to
            # the figure in matplotlib 3.6+, which would leave the image blank.
            ax = fig.add_subplot(111, projection='3d')
            ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=pl.cm.hot)
            # Project filled contours onto the bottom of the z-range (z = -2).
            ax.contourf(X, Y, Z, zdir='z', offset=-2, cmap=pl.cm.hot)
            ax.set_zlim(-2, 2)
            # Raw string: '\s' is not a valid escape in a regular literal.
            ax.set_title(r'Plot of $Z = \sin\sqrt{x^2+y^2}$')

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches=0)
        finally:
            # pyplot keeps every figure alive until it is closed. Without this,
            # each click leaks a figure in the long-running server process.
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)
        data_set.save_image(name="output.png", image=img)
        data_set.send_progress_message("Plot successfully created. Go to next step to view it.")


def load_example_data(data_set):
    """Populate the demo data set; intentionally a no-op.

    This app keeps all of its inputs as parameters with defaults, so there is
    nothing to load into *data_set*.
    """
