#! /usr/bin/env python

# from __future__ import with_statement

"""\
Usage: %prog [options] spcfile.spc

Make a SUSY mass spectrum plot.

Output is currently only available in the make-plots .dat format; work on the
alternative PGF/TikZ LaTeX graphics language is also planned.

TODOs:
  * Provide a PGF/TikZ output format option
  * Also allow plotting direct from an ISAJET file
  * Allow plotting from a text string to stdin
  * Read plot details from defs file?
"""

class XEdges(object):
    """Horizontal extent of a single particle's mass line.

    The stored ``left`` edge is the nominal ``left`` position shifted by
    ``offset``; the line then extends ``width`` units to the right.
    """

    def __init__(self, left, offset=0.0, width=2.0):
        self.width = width
        self.offset = offset
        self.left = left + offset

    @property
    def right(self):
        """x coordinate of the right-hand end of the line."""
        return self.width + self.left

    @property
    def centre(self):
        """Midpoint between the left and right ends of the line."""
        lo, hi = self.left, self.right
        return (lo + hi) / 2.0


class Label(object):
    """Text for a particle label plus an optional offset from its mass line.

    :param text: the label text (TeX markup in this file's particle table)
    :param offset: horizontal gap between the mass line and the label, or
        None to let the label-layout code fall back to its default gap
    """
    def __init__(self, text, offset=None):
        self.text = text
        # Honour the caller's offset; None means "use the layout default".
        self.offset = offset

    def __str__(self):
        return self.text


## *Details classes for representing decided positions in a way independent of output format


class ParticleDetails(object):
    """Output-format-independent drawing details for one particle.

    ``xnom`` is the nominal column x position and ``xoffset`` the shift
    applied to it (both handed to XEdges). ``mass`` defaults to None and is
    expected to be filled in once the spectrum has been read.
    """
    def __init__(self, label, xnom, xoffset, color="black", labelpos="L", mass=None):
        self.xedges = XEdges(xnom, xoffset)
        self.label, self.mass = label, mass
        self.color, self.labelpos = color, labelpos


class DecayDetails(object):
    """A decay arrow from one particle's mass line to another's.

    ``xyfrom`` and ``xyto`` are the (x, y) plot coordinates of the arrow's
    start and end; ``pidfrom``/``pidto`` identify the two particles.
    """
    # TODO: possibly add thickness and label parameters
    def __init__(self, pidfrom, xyfrom, pidto, xyto, color="gray"):
        self.pidfrom, self.xyfrom = pidfrom, xyfrom
        self.pidto, self.xyto = pidto, xyto
        self.color = color


class LabelDetails(object):
    """A text label at a fixed (x, y) position, independent of output format.

    ``anchor`` is the alignment code written out with the label; ``textlabel``
    is a hook for non-TeX rendering and currently just mirrors ``texlabel``.
    """
    def __init__(self, xy, texlabel, anchor="l", color="black"):
        self.xy, self.anchor, self.color = xy, anchor, color
        self.texlabel = texlabel
        self.textlabel = self.texlabel



## Nominal x positions of the plot's four particle columns (Higgs bosons,
## sleptons, gauginos and SUSY-QCD particles, going by the names). Each mass
## line starts at its column position plus a per-particle offset (see XEdges).
XHIGGS = 0.0
XSLEPTON = 5.0
XGAUGINO = 10.0
XSUSYQCD = 15.0

## Drawing details for every plotted particle, keyed by the particle ID that is
## also used to look up its mass in the spectrum's MASS block.
## The third ParticleDetails argument is the x offset from the column position:
## entries with an offset <= 0 get their label drawn to the left of the mass
## line (right-anchored), the others to the right (left-anchored).
## Masses are left unset here and filled in after the spectrum file is read.
PDETAILS = {
    ## Higgs column
    25 : ParticleDetails(Label(r"$h^0$"), XHIGGS, -0.2, color="blue"),
    35 : ParticleDetails(Label(r"$H^0$"), XHIGGS, -0.2, color="blue"),
    36 : ParticleDetails(Label(r"$A^0$"), XHIGGS, -0.2, color="blue"),
    37 : ParticleDetails(Label(r"$H^\pm$"), XHIGGS, 0.2, color="red"),
    ## Slepton column
    1000011 : ParticleDetails(Label(r"$\tilde{\ell}_\text{L}$"), XSLEPTON, -0.2, color="blue"),
    2000011 : ParticleDetails(Label(r"$\tilde{\ell}_\text{R}$"), XSLEPTON, -0.2, color="blue"),
    1000015 : ParticleDetails(Label(r"$\tilde{\tau}_1$"), XSLEPTON, 0.2, color="red"),
    2000015 : ParticleDetails(Label(r"$\tilde{\tau}_2$"), XSLEPTON, 0.2, color="red"),
    1000012 : ParticleDetails(Label(r"$\tilde{\nu}_\text{L}$"), XSLEPTON, -0.2, color="blue"),
    1000016 : ParticleDetails(Label(r"$\tilde{\nu}_\tau$"), XSLEPTON, 0.2, color="red"),
    ## Gaugino column: neutralinos on the left, charginos on the right
    1000022 : ParticleDetails(Label(r"$\tilde{\chi}_1^0$"), XGAUGINO, -0.2, color="blue"),
    1000023 : ParticleDetails(Label(r"$\tilde{\chi}_2^0$"), XGAUGINO, -0.2, color="blue"),
    1000025 : ParticleDetails(Label(r"$\tilde{\chi}_3^0$"), XGAUGINO, -0.2, color="blue"),
    1000035 : ParticleDetails(Label(r"$\tilde{\chi}_4^0$"), XGAUGINO, -0.2, color="blue"),
    1000024 : ParticleDetails(Label(r"$\tilde{\chi}_1^\pm$"), XGAUGINO, 0.2, color="red"),
    1000037 : ParticleDetails(Label(r"$\tilde{\chi}_2^\pm$"), XGAUGINO, 0.2, color="red"),
    ## SUSY-QCD column: gluino and light squarks left, third-generation squarks right
    1000021 : ParticleDetails(Label(r"$\tilde{g}$"), XSUSYQCD, -0.3, color="black!50!blue!30!green"),
    1000001 : ParticleDetails(Label(r"$\tilde{q}_\text{L}$"), XSUSYQCD, -0.1, color="blue"),
    2000001 : ParticleDetails(Label(r"$\tilde{q}_\text{R}$"), XSUSYQCD, -0.1, color="blue"),
    1000005 : ParticleDetails(Label(r"$\tilde{b}_1$"), XSUSYQCD, 0.2, color="black!50!blue!30!green"),
    2000005 : ParticleDetails(Label(r"$\tilde{b}_2$"), XSUSYQCD, 0.2, color="black!50!blue!30!green"),
    1000006 : ParticleDetails(Label(r"$\tilde{t}_1$"), XSUSYQCD, 0.2, color="red"),
    2000006 : ParticleDetails(Label(r"$\tilde{t}_2$"), XSUSYQCD, 0.2, color="red")
}


import pyslha
import sys, optparse
## Command-line interface: exactly one positional argument (the spectrum file)
parser = optparse.OptionParser(usage=__doc__, version=pyslha.__version__)
parser.add_option("-o", "--out", dest="OUTFILE", default=None, metavar="FILE",
                  help="write output to FILE")
# TODO: Add pgf/tikz format option
parser.add_option("-f", "--format", dest="FORMAT", default="dat",
                  choices=["dat"], metavar="FORMAT",
                  help="format in which to write output")
parser.add_option("--minbr", "--br", dest="DECAYS_MINBR", type=float, default=1.0,
                  metavar="BR",
                  help="show decay lines for decays with a branching ratio of > BR (default: %default)")
parser.add_option("--labels", dest="PARTICLES_LABELS", default="shift",
                  choices=["none", "merge", "shift"], metavar="MODE",
                  help="treatment of labels for particle IDs (default: shift)")
opts, args = parser.parse_args()

## Convenience flags derived from the --labels mode
opts.PARTICLES_LABELS_SHOW = opts.PARTICLES_LABELS not in ("none",)
opts.PARTICLES_LABELS_MERGE = opts.PARTICLES_LABELS in ("merge",)
opts.PARTICLES_LABELS_SHIFT = opts.PARTICLES_LABELS in ("shift",)

## Anything other than a single input file is a usage error
if len(args) == 1:
    opts.INFILE = args[0]
else:
    parser.print_help()
    sys.exit(1)

## Choose output file: unless given explicitly, use the input file's basename
## with its last extension (if any) replaced by ".dat"
if opts.OUTFILE is None:
    import os
    base = os.path.basename(opts.INFILE)
    head, sep, _tail = base.rpartition(".")
    # TODO: Add format-specific suffix
    opts.OUTFILE = (head if sep else base) + ".dat"
out = ""


## Read spectrum file
BLOCKS, DECAYS = pyslha.readSLHAFile(opts.INFILE)


## Set mass values in PDETAILS. Particles with no entry in the MASS block are
## dropped (with a warning) so a partial spectrum can still be plotted, rather
## than aborting the whole run with a bare KeyError.
if "MASS" not in BLOCKS:
    sys.stderr.write("No MASS block found in %s\n" % opts.INFILE)
    sys.exit(1)
massblock = BLOCKS["MASS"]
## Iterate over a sorted copy of the keys: entries may be deleted in the loop
for pid in sorted(PDETAILS.keys()):
    if pid in massblock.entries:
        ## abs(): any sign on the stored mass is discarded for plotting
        PDETAILS[pid].mass = abs(massblock.entries[pid])
    else:
        sys.stderr.write("Warning: no mass for particle %d in %s; not plotting it\n" % (pid, opts.INFILE))
        del PDETAILS[pid]


## Decays: for every plotted particle, collect an arrow to each plotted
## daughter of each decay whose branching ratio exceeds the threshold.
DDETAILS = {}
for pid, detail in sorted(PDETAILS.items()):
    DDETAILS.setdefault(pid, {})
    if pid in DECAYS:
        xyfrom = (detail.xedges.centre, detail.mass)
        for d in DECAYS[pid].decays:
            if d.br > opts.DECAYS_MINBR:
                for pid2 in d.ids:
                    ## Daughters may be listed with negative (antiparticle) IDs,
                    ## but PDETAILS only has positive keys: match on |ID| so
                    ## e.g. a decay to an anti-stop still gets its arrow.
                    apid2 = abs(pid2)
                    if apid2 in PDETAILS:
                        xyto = (PDETAILS[apid2].xedges.centre, PDETAILS[apid2].mass)
                        # TODO: Color/thickness by branching ratio
                        DDETAILS[pid][apid2] = DecayDetails(pid, xyfrom, apid2, xyto)


## Labels
PLABELS = []
if opts.PARTICLES_LABELS_SHOW:
    class MultiLabel(object):
        """A group of particle labels that would overlap if drawn separately.

        Each entry of ``labels`` is a ``(text, x, y)`` tuple; the list is
        re-sorted by y whenever a label is added. All members share ``anchor``.
        """
        def __init__(self, label=None, x=None, y=None, anchor=None):
            self.labels = [(label, x, y)]
            self.anchor = anchor or "l"

        def __len__(self):
            return len(self.labels)

        @property
        def joinedlabel(self):
            return r",\,".join(l[0] for l in self.labels)

        ## Averages divide by the number of labels; minima and maxima must not
        ## (doing so shifted the collision window and the shift calculation).
        @property
        def avgx(self):
            return sum(l[1] for l in self.labels)/float(len(self))
        @property
        def minx(self):
            return min(l[1] for l in self.labels)
        @property
        def maxx(self):
            return max(l[1] for l in self.labels)

        @property
        def avgy(self):
            return sum(l[2] for l in self.labels)/float(len(self))
        @property
        def miny(self):
            return min(l[2] for l in self.labels)
        @property
        def maxy(self):
            return max(l[2] for l in self.labels)

        def add(self, label, x, y):
            self.labels.append((label, x, y))
            self.labels.sort(key=lambda l : l[2])
            return self
        def get(self):
            for i in self.labels:
                yield i

    def rel_err(a, b):
        """Return |a - b| relative to the mean of a and b (currently unused)."""
        return abs((a - b)/((a + b)/2.0))

    ## Use max mass to work out the height of a text line in mass units
    ## (22 is a tuning constant -- presumably about that many text lines fit
    ## in the plot height). Falls back to 0 if there are no particles at all.
    maxmass = max([pdetail.mass for pdetail in PDETAILS.values()] or [0.0])
    text_height_in_mass_units = maxmass/22.0

    ## Merge colliding labels into MultiLabel groups
    reallabels = []
    for pid, pdetail in sorted(PDETAILS.items()):
        offset = pdetail.label.offset
        if offset is None:
            offset = 0.2
        if pdetail.xedges.offset <= 0:
            ## Non-positive x offset: right-anchored label left of the line
            labelx = pdetail.xedges.left - offset
            anchor = "r"
        else:
            labelx = pdetail.xedges.right + offset
            anchor = "l"
        labely = pdetail.mass
        text = pdetail.label.text
        ## Check for collisions with an existing group on the same side:
        ## close in x and within one text-line height of the group's y span
        collision = False
        if opts.PARTICLES_LABELS_SHIFT or opts.PARTICLES_LABELS_MERGE:
            for rl in reallabels:
                if anchor == rl.anchor and abs(labelx - rl.avgx) < 0.5:
                    if rl.miny - text_height_in_mass_units < labely < rl.maxy + text_height_in_mass_units:
                        rl.add(text, labelx, labely)
                        collision = True
                        break
        if not collision:
            reallabels.append(MultiLabel(text, labelx, labely, anchor))

    ## Calculate position shifts and fill PLABELS
    for rl in reallabels:
        if len(rl) == 1 or opts.PARTICLES_LABELS_MERGE:
            PLABELS.append(LabelDetails((rl.avgx, rl.avgy), rl.joinedlabel, anchor=rl.anchor))
        else:
            ## Spread the group so consecutive labels are at least one text
            ## height apart, symmetrically about their original positions
            num_gaps = len(rl)-1
            yrange_old = rl.maxy - rl.miny
            yrange_nom = num_gaps * text_height_in_mass_units
            ydiff = max(yrange_old, yrange_nom) - yrange_old
            ydiff_per_line = ydiff/num_gaps
            for i, (t, x, y) in enumerate(rl.get()):
                # TODO: Further improvement using relative or average positions?
                newy = y + (i - num_gaps/2.0) * ydiff_per_line
                PLABELS.append(LabelDetails((x, newy), t, anchor=rl.anchor))



out = ""
if opts.FORMAT == "dat":

    ## Write plot header
    out += "# SUSY mass/decay spectrum plot, created by pyslha/slhaplot from %s\n" % opts.INFILE
    out += "# http://pypi.python.org/pypi/pyslha\n"
    out += "\n"
    out += "# BEGIN PLOT\n"
    if opts.PARTICLES_LABELS_MERGE:
        ## Need more space if labels are to be merged horizontally
        out += "XMin=-4\n"
        out += "XMax=20\n"
    else:
        out += "XMin=-3\n"
        out += "XMax=19\n"
    # if opts.LOGSCALE:
    #     out += "LogY=1\n"
    # else:
    #     out += "LogY=0\n"
    ## The mass axis always starts at zero: this must not be nested inside the
    ## XMin/XMax branch above, or merged-label plots lose their YMin setting.
    out += "YMin=0\n"
    out += "#XCustomTicks=1.0	Higgs	6.0	Sleptons	11.0	Gauginos	16.0	Squarks\n"
    out += "XCustomTicks=-10.0	~\n"
    out += "YLabel=Mass / GeV\n"
    out += "DrawSpecialFirst=1\n"
    out += "# END PLOT\n\n"


    ## Mass lines: one horizontal error-bar "histogram" per particle
    for pid, pdetail in sorted(PDETAILS.items()):
        out += """
# BEGIN HISTOGRAM %s
ErrorBars=1
LineWidth=1pt
LineColor=%s
%f	%f	%e	0
# END HISTOGRAM\n""" % ("pid"+str(pid), pdetail.color,
                      pdetail.xedges.left, pdetail.xedges.right,
                      pdetail.mass)
    out += "\n"


    ## Decay arrows
    for pidfrom, todict in sorted(DDETAILS.items()):
        for pidto, dd in sorted(todict.items()):
            out += r"""
# BEGIN SPECIAL decay_%d_%d
\psset{arrowsize=0.1}
\psline[linestyle=dashed,dash=3px 2px,linecolor=%s]{->}\physicscoor(%f,%f)\physicscoor(%f,%f)
# END SPECIAL
""" % (dd.pidfrom, dd.pidto, dd.color, dd.xyfrom[0], dd.xyfrom[1], dd.xyto[0], dd.xyto[1])


    ## Particle labels
    out += "\n\n"
    out += "# BEGIN SPECIAL labels\n"
    for ld in PLABELS:
        out += r"\rput[%s]\physicscoor(%f,%f){\small %s}" % (ld.anchor, ld.xy[0], ld.xy[1], ld.texlabel) + "\n"
    out += "# END SPECIAL\n"

else:
    ## Unreachable while optparse restricts --format to "dat"
    print("Other formats not currently supported! How did we even get here?!")
    sys.exit(1)


## Write it out: "-" means standard output, anything else is a file path
if opts.OUTFILE == "-":
    sys.stdout.write(out)
else:
    f = open(opts.OUTFILE, "w")
    f.write(out)
    f.close()
