from __future__ import with_statement

# -----------------------------------------------
# -----------------------------------------------
class PlotterWarning(UserWarning):
    """Warning category issued when a plotting call cannot be honoured
    because matplotlib is unavailable."""

class DummyPlotter(object):
    """Provide a placeholder for the Plotter class. This class can be used when matplotlib
    could not be imported, or intentionally to disable plotting without changing other code.

    The class exposes the same method names as the real Plotter class, but every method
    is a no-op. When the class attribute ``send_warning`` is true, each call also issues
    a PlotterWarning. ``len()`` of a dummy plotter is 0 and iterating it yields nothing.
    """
    # When true, every stub method issues a PlotterWarning.
    send_warning = False

    def __iter__(self):
        # Without an explicit __iter__, iteration would fall back to the no-op
        # __getitem__, which never raises IndexError and would loop forever.
        return iter(())


def _make_dummy_method(name, result=None):
    """Return a no-op stand-in named *name* that always returns *result*.

    A factory is used so each generated function binds its own *result*; a
    plain closure defined in the loop below would late-bind to the last value.
    """
    def func(*args, **kwargs):
        if DummyPlotter.send_warning:
            # Imported locally: the module-level ``import warnings`` only runs
            # when matplotlib fails to import, but send_warning may also be
            # switched on by callers when matplotlib is present.
            import warnings
            warnings.warn("plotting unavailable (matplotlib not found)", PlotterWarning)
        return result
    func.__name__ = name
    return func


for name in ["__init__", "__len__", "__getitem__", "save", "update", "close",
             "layout", "active_axes", "config_axes", "add_axes", "set_title", "legend",
             "pie_chart", "bar_chart", "pareto_chart", "histogram",
             "box_plot", "run_chart", "function_plot", "simple_plot"]:
    # len() requires an int from __len__ (None raises TypeError); every other
    # stub returns None, which is also what __init__ must return.
    setattr(DummyPlotter, name, _make_dummy_method(name, 0 if name == "__len__" else None))
    
# -----------------------------------------------
# -----------------------------------------------
# Import matplotlib if possible; otherwise fall back to the no-op DummyPlotter
# so the rest of the code can keep calling Plotter methods unchanged.
try:
    from matplotlib.pyplot import figure as Figure, close as close_figure
    from matplotlib.patches import Rectangle
except ImportError:
    import warnings
    # Report every PlotterWarning occurrence, not just the first per location.
    warnings.simplefilter('always', PlotterWarning)
    warnings.warn("plotting unavailable (matplotlib not found)", PlotterWarning)
    # Plotter is an alias of DummyPlotter here, so the two assignments below
    # set attributes on DummyPlotter itself.
    Plotter = DummyPlotter
    Plotter.dummy = DummyPlotter
    # From now on every stub call on a (dummy) plotter issues a warning.
    Plotter.send_warning = True
else:
    # Dependencies used only by the real, matplotlib-backed Plotter below.
    from contextlib import contextmanager
    from collections import deque
    from math import sqrt
    from khronos.utils import Namespace
    
    class Plotter(object):
        """This class is responsible for managing the layout of a figure, and also implementing
        the plotting commands for simple statistics plots.

        Axes are arranged on a grid of ``rows`` x ``cols`` cells. A value <= 0 for
        either dimension means that dimension is computed from the number of axes
        in the figure; the computed value is then stored back on the plotter.
        ``hspace``/``vspace`` give the padding on each side of an axes as a fraction
        of the axes width/height. The matplotlib figure is created lazily by the
        first add_axes() call.
        """
        # No-op class with the same method names, usable to disable plotting.
        dummy = DummyPlotter
        __slots__ = ["title", "figure", "rows", "cols", "hspace", "vspace"]

        def __init__(self, title=None, rows=0, cols=0, hspace=1.0/8.0, vspace=1.0/8.0):
            """Store the layout settings; no figure is created here.

            title  -- figure super-title, or None
            rows   -- number of grid rows (<= 0: automatic)
            cols   -- number of grid columns (<= 0: automatic)
            hspace -- horizontal padding on each side of an axes, as a fraction
                      of the axes width
            vspace -- vertical padding above and below an axes, as a fraction
                      of the axes height
            """
            self.title = title
            self.figure = None
            self.rows = rows
            self.cols = cols
            self.hspace = hspace
            self.vspace = vspace

        def __len__(self):
            """Return the number of axes in the figure (0 when no figure exists)."""
            if self.figure is None:
                return 0
            return len(self.figure.axes)

        def __getitem__(self, index):
            """Return the axes at *index* in the figure's axes list.

            Raises IndexError when no figure exists yet, so iterating over an empty
            plotter stops immediately instead of failing with AttributeError.
            """
            if self.figure is None:
                raise IndexError("plotter has no axes")
            return self.figure.axes[index]

        def save(self, *args, **kwargs):
            """Save the figure; all arguments are forwarded to ``Figure.savefig``."""
            self.figure.savefig(*args, **kwargs)

        def update(self):
            """Show the figure, if one exists."""
            if self.figure is not None:
                self.figure.show()

        def close(self):
            """Close the figure (if any) and forget it; the next add_axes() call
            creates a new figure."""
            if self.figure is not None:
                close_figure(self.figure)
                self.figure = None

        def layout(self, rows=None, cols=None, hspace=None, vspace=None, update=True):
            """Recompute the grid shape and reposition every axes.

            Arguments left as None keep their current values; *hspace*/*vspace*
            outside [0, 1] are ignored. If *update* is true the figure is shown
            afterwards.
            """
            self.__update_shape(rows, cols)
            self.__reposition_axes(hspace, vspace)
            if update:
                self.update()

        def __update_shape(self, rows, cols):
            """Choose a grid shape with at least one cell per axes.

            A stored dimension <= 0 is computed: both free gives a near-square grid,
            one free gives the smallest count that fits all axes. If both are fixed
            but the grid is too small, both are recomputed automatically.
            """
            if rows is not None: self.rows = rows
            if cols is not None: self.cols = cols

            n = len(self)
            if self.rows > 0 and self.cols > 0 and self.rows * self.cols < n:
                # fixed grid too small for the current axes: fall back to automatic
                self.rows = 0
                self.cols = 0
            if self.rows <= 0 and self.cols <= 0:
                self.rows = int(sqrt(n) + 0.5)
                # round(sqrt(n)) * (round(sqrt(n)) + 1) >= n always holds
                self.cols = self.rows + (1 if n > self.rows ** 2 else 0)
            elif self.rows <= 0:
                # ceiling division, so that rows * cols >= n
                self.rows = -(-n // self.cols)
            elif self.cols <= 0:
                self.cols = -(-n // self.rows)

        def __reposition_axes(self, hspace, vspace):
            """Place every axes on the grid, row by row starting at the top-left.

            *hspace*/*vspace* replace the stored padding only when they lie in
            [0, 1]. Does nothing (apart from storing padding) when there are no axes.
            """
            if hspace is not None and 0.0 <= hspace <= 1.0: self.hspace = hspace
            if vspace is not None and 0.0 <= vspace <= 1.0: self.vspace = vspace

            n = len(self)
            if n == 0:
                # nothing to place; also avoids dividing by an empty 0x0 grid
                return
            plot_width, plot_height = self.__axes_dimensions()
            space_width = self.hspace * plot_width
            space_height = self.vspace * plot_height
            cell_width = plot_width + 2 * space_width
            cell_height = plot_height + 2 * space_height

            all_axes = self.figure.axes
            for index in range(n):
                row, col = divmod(index, self.cols)
                x_pos = space_width + col * cell_width
                y_pos = 1.0 - plot_height - space_height - row * cell_height
                all_axes[index].set_position([x_pos, y_pos, plot_width, plot_height])

        def __axes_dimensions(self):
            """Return (width, height) of one axes in figure-fraction units, such that
            the grid of axes plus padding spans the whole figure."""
            htotal = self.cols * (1 + 2 * self.hspace)
            vtotal = self.rows * (1 + 2 * self.vspace)
            plot_width = 1.0 / htotal
            plot_height = 1.0 / vtotal
            return plot_width, plot_height

        @contextmanager
        def active_axes(self, axes=None, title=None, xlabel=None, ylabel=None,
                        xlimit=None, ylimit=None, show=True):
            """Context manager yielding the axes to draw on.

            If *axes* is None a new axes is added; otherwise *axes* is relabelled.
            After the body runs, the given axis limits are applied (after plotting,
            so autoscaling cannot override them) and the figure is shown if *show*.
            """
            if axes is None:
                axes = self.add_axes(title, xlabel, ylabel)
            else:
                self.config_axes(axes, title, xlabel, ylabel)
            yield axes
            # Only touch limits that were requested: set_xlim(None) would freeze the
            # current limits and turn autoscaling off for later plots on this axes.
            if xlimit is not None:
                axes.set_xlim(xlimit)
            if ylimit is not None:
                axes.set_ylim(ylimit)
            if show:
                axes.figure.show()

        def config_axes(self, axes=None, title=None, xlabel=None, ylabel=None,
                        xlimit=None, ylimit=None):
            """Apply each non-None setting to *axes* (default: the last axes added)."""
            if axes   is     None: axes = self.figure.axes[-1]
            if title  is not None: axes.set_title(title)
            if xlabel is not None: axes.set_xlabel(xlabel)
            if ylabel is not None: axes.set_ylabel(ylabel)
            if xlimit is not None: axes.set_xlim(xlimit)
            if ylimit is not None: axes.set_ylim(ylimit)

        def add_axes(self, title=None, xlabel=None, ylabel=None, xlimit=None, ylimit=None):
            """Add a new axes (creating the figure on first use), configure it,
            re-layout the grid without showing, and return the axes."""
            if self.figure is None:
                self.figure = Figure()
                # back-reference used by get_plotter() to recover the plotter from an axes
                self.figure.plotter = self
                if self.title is not None:
                    self.figure.suptitle(self.title)
            # A distinct label per axes: older matplotlib versions return an existing
            # axes when add_subplot is called again with identical arguments.
            axes = self.figure.add_subplot(1, 1, 1, label=str(len(self.figure.axes)))
            self.config_axes(axes, title, xlabel, ylabel, xlimit, ylimit)
            self.layout(update=False)
            return axes

        def set_title(self, title, update=True):
            """Set the figure super-title (stored even before the figure exists)."""
            self.title = title
            if self.figure is not None:
                self.figure.suptitle(title)
            if update:
                self.update()

        def legend(self, axes, *args, **kwargs):
            """Draw a legend on *axes* (arguments forwarded) and show the figure."""
            axes.legend(*args, **kwargs)
            self.update()

        # ++++++++++++++++++++++++++++++++++++++++++++++
        def pie_chart(self, values, freqs, axes=None, title=None, show=True):
            """Pie chart of *freqs* with slices labelled by *values* and percentages."""
            with self.active_axes(axes, title, show=show) as axes:
                axes.pie(freqs, labels=values, autopct="%1.1f%%", shadow=True)
            return axes

        def bar_chart(self, values, freqs, axes=None, color="#00FF00",
                      title=None, xlabel="", ylabel="Abs frequency",
                      xlimit=None, ylimit=None, show=True):
            """Bar chart of *freqs*, one bar per entry, ticks labelled with *values*."""
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                data = self.__prepare_bar_chart(values, freqs)
                axes.xaxis.set_ticks(data.tick_locations)
                axes.xaxis.set_ticklabels(values)
                # tick locations assume bars anchored at their left edge; matplotlib
                # >= 2.0 centres bars by default, so request edge alignment explicitly
                axes.bar(data.left, freqs, width=data.bar_width, color=color, align="edge")
            return axes

        def pareto_chart(self, values, freqs, axes=None, color="#00FF00",
                         title=None, xlabel="", ylabel="Rel frequency",
                         xlimit=None, ylimit=(0.0, 1.0), show=True):
            """Pareto chart: relative frequencies sorted in decreasing order, plus a
            red line with the cumulative relative frequency."""
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                data = self.__prepare_pareto_chart(values, freqs)
                axes.xaxis.set_ticks(data.tick_locations)
                axes.xaxis.set_ticklabels(data.value_list)
                # edge alignment keeps the bars aligned with the cumulative-line x values
                axes.bar(data.left, data.freq_list, width=data.bar_width, color=color,
                         align="edge")
                axes.plot(data.xs, data.ys, "r-", label="Cumulative frequency")
            return axes

        def histogram(self, values, freqs=None, bins=10, axes=None, color="#0000FF",
                      title=None, xlabel="", ylabel="Frequency",
                      xlimit=None, ylimit=None, show=True):
            """Histogram of *values* (optionally weighted by *freqs*) over *bins*
            equal-width bins.

            Bars are drawn at bin indices 0..bins-1 with heights normalised so that
            the sum of height * bin width is 1. A text box reports the data minimum,
            maximum and bin width.
            """
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                data = self.__prepare_histogram(values, freqs, bins)
                axes.bar(range(bins), height=data.height, width=1, color=color, align="edge")
                axes.text(bins, max(data.height),
                          "Min = %f\nMax = %f\nBin width = %f" % (data.minimum,
                                                                  data.maximum,
                                                                  data.width),
                          va="top", ha="right",
                          bbox=dict(boxstyle="square",
                                    ec=(0.0, 0.0, 0.0),
                                    fc=(0.9, 0.9, 0.9)))
            return axes

        def box_plot(self, values, axes=None,
                     title=None, xlabel="", ylabel="",
                     xlimit=None, ylimit=None, show=True):
            """Box plot of *values* (passed straight to ``Axes.boxplot``)."""
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                axes.boxplot(values)
            return axes

        def run_chart(self, values, times, numeric=True, mean=None, stddev=None, fill=False,
                      axes=None, color="#FF0000", label="_nolegend_", title=None, legend=False,
                      xlabel="Time", ylabel="", xlimit=None, ylimit=None, show=True):
            """Step plot of *values* over *times* (each value held until the next time).

            Non-numeric values are mapped to integer levels and labelled on the y axis.
            For numeric data, *mean* draws a dashed horizontal line and *stddev* adds
            translucent bands at mean +/- 1, 2 and 3 standard deviations. *color* is
            used as the matplotlib format string. ``fill=True`` is not implemented.
            """
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                data = self.__prepare_run_chart(values, times, numeric)
                if not numeric:
                    # label the y axis if the data is not numeric
                    axes.yaxis.set_ticks(data.y_ticklocs)
                    axes.yaxis.set_ticklabels(data.y_ticklabels)
                elif mean is not None:
                    # plot a thick horizontal line at the mean
                    label_mean = "_nolegend_"
                    if label != "_nolegend_":
                        label_mean = label + " (wmean)"
                    axes.hlines([mean], xmin=data.xmin, xmax=data.xmax, colors=color,
                                linestyles="dashed", linewidth=2, label=label_mean, zorder=2)
                    # draw transparent rectangles between [mean - i * stddev, mean + i * stddev]
                    if stddev is not None:
                        for i in range(1, 4):
                            rect = Rectangle((data.xmin, mean - i * stddev),
                                             width=(data.xmax - data.xmin),
                                             height=(2 * i * stddev), facecolor=color,
                                             alpha=0.2, zorder=1)
                            axes.add_patch(rect)
                # the actual run chart plot
                if fill:
                    raise NotImplementedError() # TODO: implement fill graph
                axes.plot(data.xs, data.ys, color, label=label, zorder=3)
                axes.grid(True)
                if legend:
                    axes.legend(loc="best")
            return axes

        def function_plot(self, function, start, stop, step=1.0, axes=None,
                          color="#000000", label="_nolegend_", title=None, legend=False,
                          xlabel="Time", ylabel="", xlimit=None, ylimit=None, show=True):
            """Plot ``function(x)`` for x = start, start + step, ... while x <= stop.

            *color* is used as the matplotlib format string. Raises ValueError if
            *step* is not positive (the sampling loop would never terminate).
            """
            if step <= 0:
                raise ValueError("step must be positive, got %r" % (step,))
            xs = []
            count = 0
            x = start
            while x <= stop:
                xs.append(x)
                count += 1
                # multiply instead of accumulating to avoid floating-point drift
                x = start + count * step
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                ys = [function(point) for point in xs]
                axes.plot(xs, ys, color, label=label)
                if legend:
                    axes.legend(loc="best")
            return axes

        def simple_plot(self, xs, ys, axes=None, color="#000000",
                        linestyle="-", linewidth=2, marker="",
                        label="_nolegend_", title=None, legend=False,
                        xlabel="", ylabel="", xlimit=None, ylimit=None, show=True):
            """Line plot of *ys* against *xs* (any iterables; copied into deques)."""
            with self.active_axes(axes, title, xlabel, ylabel, xlimit, ylimit, show) as axes:
                axes.plot(deque(xs), deque(ys), color=color, marker=marker, label=label,
                          linestyle=linestyle, linewidth=linewidth)
                if legend:
                    axes.legend(loc="best")
            return axes

        # ++++++++++++++++++++++++++++++++++++++++++++++
        def __prepare_bar_chart(self, values, freqs, bar_width=1.0, bar_space=0.5):
            """Compute left edges and centred tick locations for len(freqs) bars."""
            half_col = bar_width / 2.0
            left = [(bar_width + bar_space) * x for x in range(len(freqs))]
            tick_locations = [l + half_col for l in left]
            return Namespace(left=left, tick_locations=tick_locations, bar_width=bar_width)

        def __prepare_pareto_chart(self, values, freqs):
            """Normalise *freqs*, sort (frequency, value) pairs in decreasing order and
            compute the cumulative curve sampled at the bar edges."""
            total = float(sum(freqs))
            normalized = [(f / total, v) for v, f in zip(values, freqs)]
            sorted_items = sorted(normalized, reverse=True)
            cumulative = [0.0]
            for f, _ in sorted_items:
                cumulative.append(cumulative[-1] + f)

            data = self.__prepare_bar_chart(values, freqs, bar_space=0.0)
            data.freq_list  = [f for f, _ in sorted_items]
            data.value_list = [v for _, v in sorted_items]
            data.xs = [data.bar_width * x for x in range(len(freqs) + 1)]
            data.ys = cumulative
            return data

        def __prepare_histogram(self, values, freqs, bins):
            """Bin (value, freq) pairs into *bins* equal-width bins spanning
            [min, max] and normalise the bin heights into a density."""
            if freqs is None:
                freqs = [1.0] * len(values)
            items = sorted(zip(values, freqs))
            minimum = items[ 0][0]
            maximum = items[-1][0]
            # sum the paired frequencies: *freqs* may be a one-shot iterator that
            # zip() has already consumed
            total_freq = sum(f for _, f in items)
            bin_span = float(maximum - minimum) / bins
            if bin_span == 0.0:
                # every value is identical: fall back to unit-width bins so the
                # normalisation below does not divide by zero
                bin_span = 1.0
            bins_start = [minimum + bin_span * x for x in range(bins)]
            freq_normalizer = 1.0 / (bin_span * total_freq)

            bin_freq = [0.0] * bins
            cur_bin = bins - 1
            # pop values from largest to smallest, moving the bin cursor down as needed
            while items:
                v, f = items.pop()
                while v < bins_start[cur_bin]:
                    cur_bin -= 1
                bin_freq[cur_bin] += f * freq_normalizer

            return Namespace(height=bin_freq,
                             minimum=minimum,
                             maximum=maximum,
                             width=bin_span)

        def __prepare_run_chart(self, values, times, numeric):
            """Build step-plot coordinates: each value is held until the next time.

            For non-numeric data, values are replaced by their index in the sorted
            set of distinct values, and matching tick locations/labels are returned.
            """
            # copy iterables to keep the original lists unharmed
            values = deque(values)
            times = deque(times)

            xs = deque([times.popleft()])
            ys = deque([values.popleft()])
            prev_y = ys[0]
            for y, t in zip(values, times):
                xs.extend((t, t))
                ys.extend((prev_y, y))
                prev_y = y

            data = Namespace(xs=xs, ys=ys, xmin=xs[0], xmax=xs[-1])
            if not numeric:
                # map objects to integer levels, ordered by the objects' sort order
                levels = sorted(set(ys))
                level_of = dict((y, i) for i, y in enumerate(levels))
                data.ys = deque([level_of[y] for y in ys])
                # tick i is labelled with the object mapped to level i
                data.y_ticklabels = levels
                data.y_ticklocs   = list(range(len(levels)))
            return data
            
def get_plotter(target=None):
    """get_plotter(plotter_or_axes) -> (plotter, axes)

    Normalise *target* into a ``(plotter, axes)`` pair:
      * None                      -> a new Plotter and no axes;
      * a Plotter or DummyPlotter -> that plotter and no axes;
      * anything else is treated as an Axes object whose figure carries a
        ``plotter`` back-reference -> that plotter and the axes itself.
    """
    if target is None:
        plotter, axes = Plotter(), None
    elif isinstance(target, (Plotter, DummyPlotter)):
        plotter, axes = target, None
    else:
        plotter, axes = target.figure.plotter, target
    return plotter, axes
    
