#_PYTHON_INSERT_SAO_COPYRIGHT_HERE_(2010)_
#_PYTHON_INSERT_GPL_LICENSE_HERE_
"""
Classes for plotting and analysing astronomical data sets
"""

from sherpa.astro.data import DataPHA
from sherpa.plot import DataPlot, ModelPlot, FitPlot, DelchiPlot, ResidPlot, \
    RatioPlot, ChisqrPlot, HistogramPlot, backend, Histogram
from sherpa.plot import ComponentSourcePlot as _ComponentSourcePlot
from sherpa.astro.utils import compile_energy_grid, bounds_check
from sherpa.utils.err import PlotErr, IOErr
from sherpa.utils import parse_expr, dataspace1d, histogram1d, filter_bins
from numpy import iterable, array2string, asarray
from itertools import izip
import logging

warning = logging.getLogger(__name__).warning

# Public API of this module.  FluxHistogram and its energy/photon variants
# are defined below and were previously missing from this list.
# NOTE(review): ComponentModelPlot and ComponentSourcePlot are not listed;
# ComponentSourcePlot shares its name with the sherpa.plot class imported
# above, so exporting it would shadow that class for star-importers --
# confirm before adding them.
__all__ = ('SourcePlot', 'ARFPlot', 'BkgDataPlot', 'BkgModelPlot', 'BkgFitPlot',
           'BkgSourcePlot', 'BkgDelchiPlot', 'BkgResidPlot', 'BkgRatioPlot',
           'BkgChisqrPlot', 'OrderPlot', 'ModelHistogram', 'BkgModelHistogram',
           'FluxHistogram', 'EnergyFluxHistogram', 'PhotonFluxHistogram')


class ModelHistogram(HistogramPlot):
    """Histogram of a PHA model, computed on the ungrouped channel grid.

    Grouped data are temporarily ungrouped while the plot values are
    computed; grouping and the original filter are re-applied afterwards.
    """

    histo_prefs = backend.get_model_histo_defaults()

    def __init__(self):
        HistogramPlot.__init__(self)
        self.title = 'Model'

    def prepare(self, data, model, stat=None):
        """Fill xlo, xhi, y and the axis labels from ``data`` and ``model``.

        ``stat`` is accepted for interface compatibility and is unused.
        """
        grouped_ranges = parse_expr(data.get_filter())
        was_grouped = data.grouped
        channel_ranges = parse_expr(data.get_filter(group=False))

        try:
            if was_grouped:
                # Switch to per-channel data and re-apply the channel filter.
                data.ungroup()
                for bounds in channel_ranges:
                    data.notice(*bounds)

            (self.xlo, model_y, _, _,
             self.xlabel, self.ylabel) = data.to_plot(yfunc=model)
            self.y = model_y[1]

            if data.units == 'channel':
                lo_edges, hi_edges = data.channel, data.channel + 1.
            else:
                lo_edges, hi_edges = data._get_ebins(group=False)

            self.xlo = data.apply_filter(lo_edges, data._min)
            self.xhi = data.apply_filter(hi_edges, data._max)

            if data.units == 'wavelength':
                # Convert the bin edges with hc / x (edges from _get_ebins
                # are presumably in energy units -- confirm).
                self.xlo = data._hc / self.xlo
                self.xhi = data._hc / self.xhi
        finally:
            if was_grouped:
                # Restore grouping, then the filter captured on entry.
                data.ignore()
                data.group()
                for bounds in grouped_ranges:
                    data.notice(*bounds)



class SourcePlot(HistogramPlot):
    "Derived class for creating plots of the unconvolved source model"

    histo_prefs = backend.get_model_histo_defaults()

    def __init__(self):
        # Extra attributes are created before the base-class __init__ runs.
        self.units = None
        self.mask = None
        HistogramPlot.__init__(self)
        self.title = 'Source'

    def prepare(self, data, src, lo=None, hi=None):
        """Evaluate the unfolded source model ``src`` on the grid of ``data``.

        Parameters
        ----------
        data : DataPHA
            Supplies the independent-axis grid, units and ``plot_fac``.
        src : callable
            Source model before folding; called as ``src(xlo, xhi)``.
        lo, hi : optional
            Plot range; bins outside it are masked out in ``plot``.

        Raises
        ------
        IOErr
            If ``data`` is not a DataPHA.
        PlotErr
            If ``data.plot_fac`` is not 0, 1 or 2.
        """
        if not isinstance(data, DataPHA):
            raise IOErr('notpha', data.name)

        lo, hi = bounds_check(lo, hi)

        self.units = data.units
        if self.units == "channel":
            warning("Channel space is unappropriate for the PHA unfolded" +
                    " source model,\nusing energy.")
            self.units = "energy"

        self.title = 'Source Model of %s' % data.name
        self.xlo, self.xhi = data._get_indep(filter=False)
        self.mask = filter_bins((lo,), (hi,), (self.xlo,))
        # The array returned by src may be shared with the model (presumably
        # an evaluation cache -- confirm), so it is never modified in place
        # below; new arrays are created instead.
        self.y = src(self.xlo, self.xhi)
        prefix_quant = 'E'
        quant = 'keV'

        if self.units == "wavelength":
            prefix_quant = '\\lambda'
            quant = '\\AA'
            (self.xlo, self.xhi) = (self.xhi, self.xlo)

        # Width of each bin; abs() makes it positive for either x ordering.
        bin_width = abs(self.xhi - self.xlo)

        self.xlabel = '%s (%s)' % (self.units.capitalize(), quant)
        self.ylabel = '%s  Photons/sec/cm^2%s'

        # NOTE(review): plot_fac 1/2 scale by the bin width, not by the bin
        # centre, although the labels read "E f(E)" / "E^2 f(E)" -- confirm
        # this is the intended normalisation.
        if data.plot_fac == 0:
            self.y = self.y / bin_width
            self.ylabel = self.ylabel % ('f(%s)' % prefix_quant, '/%s ' % quant)

        elif data.plot_fac == 1:
            self.ylabel = self.ylabel % ('%s f(%s)' % (prefix_quant, prefix_quant), '')

        elif data.plot_fac == 2:
            self.y = self.y * bin_width
            self.ylabel = self.ylabel % ('%s^{2} f(%s)' % (prefix_quant, prefix_quant),
                                         ' %s ' % quant)
        else:
            raise PlotErr('plotfac', 'Source', data.plot_fac)

    def plot(self, overplot=False, clearwindow=True):
        """Draw the histogram, restricted to ``self.mask`` when it is set."""
        xlo = self.xlo
        xhi = self.xhi
        y = self.y

        if self.mask is not None:
            xlo = self.xlo[self.mask]
            xhi = self.xhi[self.mask]
            y = self.y[self.mask]

        Histogram.plot(self, xlo, xhi, y, title=self.title,
                       xlabel=self.xlabel, ylabel=self.ylabel,
                       overplot=overplot, clearwindow=clearwindow)


class ComponentModelPlot(_ComponentSourcePlot, ModelHistogram):
    """ModelHistogram for a single model component, titled by its name.

    Every method delegates explicitly to ModelHistogram rather than going
    through the multiple-inheritance MRO.
    """

    histo_prefs = backend.get_component_histo_defaults()

    def __init__(self):
        ModelHistogram.__init__(self)

    def __str__(self):
        return ModelHistogram.__str__(self)

    def prepare(self, data, model, stat=None):
        ModelHistogram.prepare(self, data, model, stat)
        # Replace the generic title with one naming the component.
        self.title = 'Model component: %s' % model.name

    def plot(self, overplot=False, clearwindow=True):
        ModelHistogram.plot(self, overplot, clearwindow)


class ComponentSourcePlot(_ComponentSourcePlot, SourcePlot):
    """SourcePlot for a single unfolded model component, titled by its name.

    Every method delegates explicitly to SourcePlot rather than going
    through the multiple-inheritance MRO.
    """

    histo_prefs = backend.get_component_histo_defaults()

    def __init__(self):
        SourcePlot.__init__(self)

    def __str__(self):
        return SourcePlot.__str__(self)

    def prepare(self, data, model, stat=None):
        # stat is accepted for interface compatibility only; SourcePlot
        # does not take it.
        SourcePlot.prepare(self, data, model)
        self.title = 'Source model component: %s' % model.name

    def plot(self, overplot=False, clearwindow=True):
        SourcePlot.plot(self, overplot, clearwindow)


class ARFPlot(HistogramPlot):
    """Histogram of an ARF's SPECRESP values over its energy bins."""

    histo_prefs = backend.get_model_histo_defaults()

    def prepare(self, arf, data=None):
        """Fill the plot from ``arf``.

        When ``data`` is a wavelength-unit DataPHA, the bin edges are
        converted with ``data._hc / edge`` and relabelled in Angstrom.

        Raises
        ------
        PlotErr
            If ``data`` is given but is not a DataPHA.
        """
        self.xlo, self.xhi, self.y = arf.energ_lo, arf.energ_hi, arf.specresp

        self.title = arf.name
        self.xlabel = arf.get_xlabel()
        self.ylabel = arf.get_ylabel()

        if data is None:
            return

        if not isinstance(data, DataPHA):
            raise PlotErr('notpha', data.name)

        if data.units != "wavelength":
            return

        self.xlabel = 'Wavelength (Angstrom)'
        hc = data._hc
        self.xlo = hc / self.xlo
        self.xhi = hc / self.xhi


class BkgDataPlot(DataPlot):
    """DataPlot used for background counts; no extra behaviour."""

    def __init__(self):
        DataPlot.__init__(self)


class BkgModelPlot(ModelPlot):
    """ModelPlot for the background model, with a background-specific title."""

    def __init__(self):
        ModelPlot.__init__(self)
        # Override the default title set by ModelPlot.
        self.title = 'Background Model Contribution'

class BkgFitPlot(FitPlot):
    """FitPlot for background counts overlaid with the fitted model."""

    def __init__(self):
        FitPlot.__init__(self)

class BkgDelchiPlot(DelchiPlot):
    """DelchiPlot for the background: (data - model) / error."""

    def __init__(self):
        DelchiPlot.__init__(self)

class BkgResidPlot(ResidPlot):
    """ResidPlot for the background: data - model."""

    def __init__(self):
        ResidPlot.__init__(self)

    def prepare(self, data, model, stat):
        """Compute residuals via ResidPlot, then retitle for the background."""
        ResidPlot.prepare(self, data, model, stat)
        template = 'Residuals of %s - Bkg Model'
        self.title = template % data.name

class BkgRatioPlot(RatioPlot):
    """RatioPlot for the background: data / model."""

    def __init__(self):
        RatioPlot.__init__(self)

    def prepare(self, data, model, stat):
        """Compute the ratio via RatioPlot, then retitle for the background."""
        RatioPlot.prepare(self, data, model, stat)
        template = 'Ratio of %s : Bkg Model'
        self.title = template % data.name

class BkgChisqrPlot(ChisqrPlot):
    """ChisqrPlot for the background: ((data - model) / error) ** 2."""

    def __init__(self):
        ChisqrPlot.__init__(self)

class BkgSourcePlot(SourcePlot):
    """SourcePlot for the unconvolved background source model."""

    def __init__(self):
        SourcePlot.__init__(self)


class OrderPlot(ModelHistogram):
    """
    Derived class for creating plots of the convolved source model using
    selected multiple responses
    """

    def __init__(self):
        # Created before the base-class __init__, as the other plot classes
        # in this module do; presumably the base class rejects attributes
        # added after initialisation -- confirm.
        self.orders = None
        self.colors = None
        self.use_default_colors = True
        ModelHistogram.__init__(self)

    @staticmethod
    def _as_list(value):
        """Return ``value`` as a list, wrapping scalars and strings.

        Strings are treated as a single item (e.g. one colour name) rather
        than being split into their characters.
        """
        if isinstance(value, (str, type(u''))) or not iterable(value):
            return [value]
        return list(value)

    @staticmethod
    def _default_colors(num):
        """Return ``num`` hex colour strings stepping down from 0xffffff towards 0x0000bf."""
        top = int('0xffffff', 16)
        bottom = int('0x0000bf', 16)
        # Floor division keeps the step an integer on every Python version;
        # hex() does not accept floats.
        step = (top - bottom) // (num + 1)
        return [hex(top - ii * step) for ii in range(num)]

    def prepare(self, data, model, orders=None, colors=None):
        """Compute one model histogram per selected response order.

        Parameters
        ----------
        data : DataPHA
            The data set; its grouping/filter is temporarily changed and
            restored before returning.
        model :
            The model expression; when more than two responses are present,
            ``model.rhs.orders`` is indexed by (1-based) order.
        orders : int or sequence of int, optional
            Orders to plot. Defaults to ``data.response_ids``.
        colors : str or sequence, optional
            One line colour per order. When omitted, a default gradient is
            generated.

        Raises
        ------
        PlotErr
            If the number of colours differs from the number of orders, if an
            order is out of range, or if the per-order arrays are inconsistent.
        """
        if orders is None:
            self.orders = data.response_ids
        else:
            self.orders = self._as_list(orders)

        # Re-derived on every call so that explicit colours from a previous
        # call do not leak into this one.
        self.use_default_colors = colors is None
        if self.use_default_colors:
            self.colors = self._default_colors(len(self.orders))
        else:
            self.colors = self._as_list(colors)
            if len(self.colors) != len(self.orders):
                raise PlotErr('ordercolors', len(self.orders),
                              len(self.colors))

        old_filter = parse_expr(data.get_filter())
        old_group = data.grouped
        # Channel-level filter to apply while the data are ungrouped
        # (the same approach as ModelHistogram.prepare).
        new_filter = parse_expr(data.get_filter(group=False))

        try:
            if old_group:
                data.ungroup()
                for interval in new_filter:
                    data.notice(*interval)

            self.xlo = []
            self.xhi = []
            self.y = []
            (xlo, y, _yerr, _xerr,
             self.xlabel, self.ylabel) = data.to_plot(model)
            y = y[1]
            if data.units != 'channel':
                elo, ehi = data._get_ebins(group=False)
                xlo = data.apply_filter(elo, data._min)
                xhi = data.apply_filter(ehi, data._max)
                if data.units == 'wavelength':
                    xlo = data._hc / xlo
                    xhi = data._hc / xhi
            else:
                xhi = xlo + 1.

            for order in self.orders:
                self.xlo.append(xlo)
                self.xhi.append(xhi)
                # NOTE(review): per-order values are only extracted when there
                # are more than two responses; otherwise the y from
                # data.to_plot(model) is reused for every order. Confirm
                # whether this threshold should be "> 1".
                if len(data.response_ids) > 2:
                    if order < 1 or order > len(model.rhs.orders):
                        raise PlotErr('notorder', order)
                    y = data.apply_filter(model.rhs.orders[order - 1])
                    y = data._fix_y_units(y, True)
                    if data.exposure:
                        y = data.exposure * y
                self.y.append(y)

        finally:
            if old_group:
                data.ignore()
                data.group()
                for interval in old_filter:
                    data.notice(*interval)

        self.title = 'Model Orders %s' % str(self.orders)

        if len(self.xlo) != len(self.y):
            raise PlotErr("orderarrfail")

    def plot(self, overplot=False, clearwindow=True):
        """Draw one histogram per order, overplotting all but the first.

        ``histo_prefs`` is a class-level dict (inherited from ModelHistogram),
        so the temporary line colour is restored even if plotting fails.
        """
        default_color = self.histo_prefs['linecolor']
        try:
            for count, (xlo, xhi, y, color) in enumerate(
                    izip(self.xlo, self.xhi, self.y, self.colors)):
                if count != 0:
                    overplot = True
                # The first generated default colour is not applied (the
                # current linecolor preference is used); user-supplied
                # colours are always honoured, including the first.
                if count != 0 or not self.use_default_colors:
                    self.histo_prefs['linecolor'] = color
                Histogram.plot(self, xlo, xhi, y, title=self.title,
                               xlabel=self.xlabel, ylabel=self.ylabel,
                               overplot=overplot, clearwindow=clearwindow)
        finally:
            self.histo_prefs['linecolor'] = default_color

class BkgModelHistogram(ModelHistogram):
    """ModelHistogram used for the background model of a PHA data set."""

    def __init__(self):
        ModelHistogram.__init__(self)


class FluxHistogram(ModelHistogram):
    "Derived class for creating 1D flux distribution plots"

    def __init__(self):
        # Created before the base-class __init__, matching the other plot
        # classes in this module.
        self.modelvals = None
        self.flux = None
        ModelHistogram.__init__(self)

    def __str__(self):
        vals = self.modelvals
        if vals is not None:
            vals = array2string(asarray(vals), separator=',', precision=4,
                                suppress_small=False)

        flux = self.flux
        if flux is not None:
            flux = array2string(asarray(flux), separator=',', precision=4,
                                suppress_small=False)

        return '\n'.join(['modelvals = %s' % vals, 'flux = %s' % flux,
                          ModelHistogram.__str__(self)])

    def prepare(self, fluxes, bins):
        """Build a peak-normalised histogram of the flux samples.

        Parameters
        ----------
        fluxes : array_like, 2D
            Column 0 holds the flux of each sample; the remaining columns
            are stored in ``modelvals`` (presumably the model parameter
            values of each sample -- confirm against the caller).
        bins : int
            Number of histogram bins.

        Notes
        -----
        ``y`` is the bin count divided by the largest bin count, so the
        tallest bin has height 1.
        """
        # asarray lets nested lists be passed as well as ndarrays.
        fluxes = asarray(fluxes)
        y = asarray(fluxes[:, 0])
        self.flux = y
        self.modelvals = asarray(fluxes[:, 1:])
        self.xlo, self.xhi = dataspace1d(y.min(), y.max(), numbins=bins + 1)[:2]
        y = histogram1d(y, self.xlo, self.xhi)
        self.y = y / float(y.max())


class EnergyFluxHistogram(FluxHistogram):
    """FluxHistogram labelled for energy-flux samples."""

    def __init__(self):
        FluxHistogram.__init__(self)
        self.title, self.xlabel, self.ylabel = (
            "Energy flux distribution",
            "Energy flux (ergs cm^{-2} sec^{-1})",
            "Frequency",
        )


class PhotonFluxHistogram(FluxHistogram):
    """FluxHistogram labelled for photon-flux samples."""

    def __init__(self):
        FluxHistogram.__init__(self)
        self.title, self.xlabel, self.ylabel = (
            "Photon flux distribution",
            "Photon flux (Photons cm^{-2} sec^{-1})",
            "Frequency",
        )
