#! /usr/bin/env python

# from __future__ import with_statement

"""\
Usage: %prog [options] <spcfile>

Make a SUSY mass spectrum plot from an SLHA or ISAWIG spectrum file. If the
filename ends with .isa, it will be assumed to be an ISAWIG file, otherwise
it will be assumed to be an SLHA file (for which the normal extension is .spc).

Output is currently available as the make-plots .dat format or as LaTeX source
using the PGF/TikZ graphics package. Both may be processed to make EPS or PDF
images.

TODOs:
  * Be able to use physical (or semi-physical) y-coords in TikZ
  * Get rid of make-plots support?
  * Allow plotting from a text string to stdin
  * Allow user to supply a (replacement) LaTeX preamble
  * Read plot details from defs file?
"""

class XEdges(object):
    """Horizontal extent of a drawn line.

    ``left`` is the nominal left edge shifted by ``offset``; ``width`` is the
    line length. ``right`` and ``centre`` are derived on access, so they track
    any later change to ``left`` or ``width``.
    """

    def __init__(self, left, offset=0.0, width=2.0):
        self.width = width
        self.offset = offset
        ## The stored left edge already includes the offset
        self.left = left + offset

    @property
    def right(self):
        """x coordinate of the right-hand end."""
        return self.left + self.width

    @property
    def centre(self):
        """x coordinate midway between the two ends."""
        lo = self.left
        hi = self.right
        return (lo + hi)/2.0


class Label(object):
    """Text label for a particle.

    ``text`` is the label string (TeX markup in this file's tables).
    ``offset`` is an optional horizontal gap between the particle line and
    the label; when it is None the consumer is expected to pick a default.
    """

    def __init__(self, text, offset=None):
        self.text = text
        ## Keep the caller's value: it used to be discarded and replaced by None
        self.offset = offset

    def __str__(self):
        return self.text


## *Details classes for representing decided positions in a way independent of output format


class ParticleDetails(object):
    """Drawing attributes for one particle's mass line.

    ``xnom`` and ``xoffset`` are handed to XEdges to fix the horizontal
    position; ``mass`` may be left as None and assigned after construction.
    """

    def __init__(self, label, xnom, xoffset, color="black", labelpos="L", mass=None):
        ## Horizontal placement
        self.xedges = XEdges(xnom, xoffset)
        ## Labelling and appearance
        self.label = label
        self.labelpos = labelpos
        self.color = color
        ## Vertical placement
        self.mass = mass


class DecayDetails(object):
    """One decay arrow: the IDs and (x, y) positions of its two ends, plus a colour."""

    ## TODO: add thickness and label attributes (e.g. thickness=1px, label=None)
    def __init__(self, pidfrom, xyfrom, pidto, xyto, color="gray"):
        ## Tail of the arrow
        self.pidfrom, self.xyfrom = pidfrom, xyfrom
        ## Head of the arrow
        self.pidto, self.xyto = pidto, xyto
        self.color = color


class LabelDetails(object):
    """A positioned label: its (x, y) point, text, anchor side and colour."""

    def __init__(self, xy, texlabel, anchor="l", color="black"):
        self.xy = xy
        self.anchor = anchor
        self.color = color
        ## textlabel is the hook for non-TeX label rendering, if that is ever
        ## needed; for now it simply mirrors the TeX string
        self.texlabel = self.textlabel = texlabel



## Nominal x positions of the four particle-family columns
XHIGGS, XSLEPTON, XGAUGINO, XSUSYQCD = 0.0, 5.0, 10.0, 15.0

## Drawing details for every plotted particle, keyed by particle ID.
## Each row is (ID, TeX label, nominal x, x offset, colour); the mass is left
## at its default of None here.
PDETAILS = dict(
    (pid, ParticleDetails(Label(tex), xnom, xoff, color=col))
    for pid, tex, xnom, xoff, col in (
        ## Higgs column
        (25, r"$h^0$", XHIGGS, -0.2, "blue"),
        (35, r"$H^0$", XHIGGS, -0.2, "blue"),
        (36, r"$A^0$", XHIGGS, -0.2, "blue"),
        (37, r"$H^\pm$", XHIGGS, 0.2, "red"),
        ## Slepton column
        (1000011, r"$\tilde{\ell}_\text{L}$", XSLEPTON, -0.2, "blue"),
        (2000011, r"$\tilde{\ell}_\text{R}$", XSLEPTON, -0.2, "blue"),
        (1000015, r"$\tilde{\tau}_1$", XSLEPTON, 0.2, "red"),
        (2000015, r"$\tilde{\tau}_2$", XSLEPTON, 0.2, "red"),
        (1000012, r"$\tilde{\nu}_\text{L}$", XSLEPTON, -0.2, "blue"),
        (1000016, r"$\tilde{\nu}_\tau$", XSLEPTON, 0.2, "red"),
        ## Gaugino column
        (1000022, r"$\tilde{\chi}_1^0$", XGAUGINO, -0.2, "blue"),
        (1000023, r"$\tilde{\chi}_2^0$", XGAUGINO, -0.2, "blue"),
        (1000025, r"$\tilde{\chi}_3^0$", XGAUGINO, -0.2, "blue"),
        (1000035, r"$\tilde{\chi}_4^0$", XGAUGINO, -0.2, "blue"),
        (1000024, r"$\tilde{\chi}_1^\pm$", XGAUGINO, 0.2, "red"),
        (1000037, r"$\tilde{\chi}_2^\pm$", XGAUGINO, 0.2, "red"),
        ## SUSY-QCD column
        (1000021, r"$\tilde{g}$", XSUSYQCD, -0.3, "black!50!blue!30!green"),
        (1000001, r"$\tilde{q}_\text{L}$", XSUSYQCD, -0.1, "blue"),
        (2000001, r"$\tilde{q}_\text{R}$", XSUSYQCD, -0.1, "blue"),
        (1000005, r"$\tilde{b}_1$", XSUSYQCD, 0.2, "black!50!blue!30!green"),
        (2000005, r"$\tilde{b}_2$", XSUSYQCD, 0.2, "black!50!blue!30!green"),
        (1000006, r"$\tilde{t}_1$", XSUSYQCD, 0.2, "red"),
        (2000006, r"$\tilde{t}_2$", XSUSYQCD, 0.2, "red"),
    )
)


import pyslha
import sys, optparse
## Command-line interface: options plus exactly one positional spectrum file
parser = optparse.OptionParser(usage=__doc__, version=pyslha.__version__)

parser.add_option(
    "-o", "--out",
    dest="OUTFILE", default=None, metavar="FILE",
    help="write output to FILE")

# TODO: Add tikzpdf format option which runs pdflatex/pdfcrop, etc.
# TODO: Add an option to specify a LaTeX preamble file to be inserted
parser.add_option(
    "-f", "--format",
    dest="FORMAT", default="tikz", metavar="FORMAT",
    choices=["makeplots", "tikz", "tikzfrag"],
    help=("format in which to write output. 'tikz' produces LaTeX source using the "
          "TikZ graphics package to render the plot, 'tikzfrag' produces the same but "
          "with the LaTeX preamble and document lines commented out to make it directly "
          "includeable as a code fragment in LaTeX document source, and 'makeplots' "
          "produces a .dat file for processing with the make-plots command."))

parser.add_option(
    "--minbr", "--br",
    dest="DECAYS_MINBR", default=1.0, type=float, metavar="BR",
    help="show decay lines for decays with a branching ratio of > BR (default: %default)")

parser.add_option(
    "--labels",
    dest="PARTICLES_LABELS", default="shift", metavar="MODE",
    choices=["none", "merge", "shift"],
    help="treatment of labels for particle IDs (default: shift)")

opts, args = parser.parse_args()

## Exactly one input file is required; otherwise show the help and give up
if len(args) != 1:
    parser.print_help()
    sys.exit(1)
opts.INFILE = args[0]

## Expand the label mode into the boolean flags tested further down
opts.PARTICLES_LABELS_SHOW = (opts.PARTICLES_LABELS != "none")
opts.PARTICLES_LABELS_MERGE = (opts.PARTICLES_LABELS == "merge")
opts.PARTICLES_LABELS_SHIFT = (opts.PARTICLES_LABELS == "shift")

## Choose output file: if none was given, derive it from the input file name
## by replacing the last extension with a format-specific suffix. The result
## is a bare file name, i.e. relative to the current directory.
if opts.OUTFILE is None:
    import os
    o = os.path.basename(opts.INFILE)
    if "." in o:
        o = o[:o.rindex(".")]
    ## Add format-specific suffix. Every value accepted by --format needs an
    ## entry here: 'tikzfrag' was missing, which raised KeyError.
    format_suffix = { "makeplots" : ".dat",  "tikz" : ".tex", "tikzfrag" : ".tex" }
    opts.OUTFILE = o + format_suffix[opts.FORMAT]


## Read spectrum file: names ending in .isa go to the ISAWIG reader,
## everything else is read as SLHA
if opts.INFILE.endswith(".isa"):
    reader = pyslha.readISAWIGFile
else:
    reader = pyslha.readSLHAFile
BLOCKS, DECAYS = reader(opts.INFILE)

## Set mass values in PDETAILS, using the absolute value of each MASS entry.
## A particle ID without an entry in the MASS block raises here.
massblock = BLOCKS["MASS"]
for pid, pdetail in PDETAILS.items():
    pdetail.mass = abs(massblock.entries[pid])


## Decays
## DDETAILS maps each plotted parent ID to {daughter ID: DecayDetails}. Every
## plotted particle gets an entry, possibly empty. An arrow is recorded for
## each plotted daughter of a decay whose branching ratio exceeds --minbr.
## NOTE(review): the comparison is strict, so with the default threshold of
## 1.0 presumably no decay qualifies unless --minbr is lowered -- confirm.
DDETAILS = {}
for pid, detail in sorted(PDETAILS.items()):
    arrows = DDETAILS.setdefault(pid, {})
    ## Membership test via 'in' rather than the Python-2-only dict.has_key
    if pid not in DECAYS:
        continue
    ## All arrows from this particle start at the centre of its mass line
    xyfrom = (detail.xedges.centre, detail.mass)
    for d in DECAYS[pid].decays:
        if d.br > opts.DECAYS_MINBR:
            for pid2 in d.ids:
                ## Daughters that are not plotted get no arrow
                if pid2 in PDETAILS:
                    xyto = (PDETAILS[pid2].xedges.centre, PDETAILS[pid2].mass)
                    # TODO: Color/thickness by branching ratio
                    arrows[pid2] = DecayDetails(pid, xyfrom, pid2, xyto)


## Labels
PLABELS = []
if opts.PARTICLES_LABELS_SHOW:
    class MultiLabel(object):
        """A group of labels that share an anchor side and are laid out together.

        Each member is a (label, x, y) tuple; members are kept sorted by y.
        Constructing without a label gives an empty group. The avg*/min*/max*
        properties need at least one member.
        """

        def __init__(self, label=None, x=None, y=None, anchor=None):
            ## Only seed the group when a label was actually supplied
            self.labels = [] if label is None else [(label, x, y)]
            self.anchor = anchor or "l"

        def __len__(self):
            return len(self.labels)

        @property
        def joinedlabel(self):
            """All member texts joined with a comma and a TeX thin space."""
            return r",\,".join(l[0] for l in self.labels)

        @property
        def avgx(self):
            """Mean x of the members."""
            return sum(l[1] for l in self.labels)/float(len(self))
        @property
        def minx(self):
            """Smallest member x (an extreme, so not divided by the count)."""
            return min(l[1] for l in self.labels)
        @property
        def maxx(self):
            """Largest member x."""
            return max(l[1] for l in self.labels)

        @property
        def avgy(self):
            """Mean y of the members."""
            return sum(l[2] for l in self.labels)/float(len(self))
        @property
        def miny(self):
            """Smallest member y."""
            return min(l[2] for l in self.labels)
        @property
        def maxy(self):
            """Largest member y."""
            return max(l[2] for l in self.labels)

        def add(self, label, x, y):
            """Add a member, keep the group sorted by y, and return self."""
            self.labels.append((label, x, y))
            self.labels.sort(key=lambda l : l[2])
            return self

        def get(self):
            """Iterate over the (label, x, y) members in increasing y."""
            return iter(self.labels)

    def rel_err(a, b):
        """Return the absolute value of (a-b)/(a+b)/2.0.

        Raises ZeroDivisionError when a + b is zero.
        NOTE(review): no caller is visible in this file, and dividing (rather
        than multiplying) by 2 looks unusual for a relative difference --
        confirm the intended formula before relying on it.
        """
        ratio = (a - b) / (a + b)
        return abs(ratio / 2.0)

    ## Use max mass to work out the height of a text line in mass units
    ## (computed directly, without seeding the maximum with None)
    maxmass = max(pdetail.mass for pdetail in PDETAILS.values())
    text_height_in_mass_units = maxmass/22.0
    ##
    ## Merge colliding labels: walk the particles in ID order and either add
    ## each label to an existing group it collides with or start a new group
    check_collisions = opts.PARTICLES_LABELS_SHIFT or opts.PARTICLES_LABELS_MERGE
    reallabels = []
    for pid, pdetail in sorted(PDETAILS.items()):
        offset = pdetail.label.offset or 0.2
        ## Lines shifted left (or unshifted) are labelled on their left,
        ## lines shifted right on their right
        if pdetail.xedges.offset <= 0:
            labelx = pdetail.xedges.left - offset
            anchor = "r"
        else:
            labelx = pdetail.xedges.right + offset
            anchor = "l"
        labely = pdetail.mass
        text = pdetail.label.text
        ## Check for collisions: same side, within 0.5 in x, and within one
        ## text line of the group's current y range
        collision = False
        if check_collisions:
            for rl in reallabels:
                if anchor != rl.anchor or abs(labelx - rl.avgx) >= 0.5:
                    continue
                if rl.miny - text_height_in_mass_units < labely < rl.maxy + text_height_in_mass_units:
                    rl.add(text, labelx, labely)
                    collision = True
                    break
        if not collision:
            reallabels.append(MultiLabel(text, labelx, labely, anchor))
    ## Calculate position shifts and fill PLABELS
    for rl in reallabels:
        if len(rl) == 1 or opts.PARTICLES_LABELS_MERGE:
            ## Single labels, and merged groups, become one label at the mean position
            PLABELS.append(LabelDetails((rl.avgx, rl.avgy), rl.joinedlabel, anchor=rl.anchor))
            continue
        ## Shift mode: stretch the group symmetrically so that it spans at
        ## least one text line per gap between neighbouring labels
        num_gaps = len(rl)-1
        yrange_old = rl.maxy - rl.miny
        yrange_nom = num_gaps * text_height_in_mass_units
        ydiff = max(yrange_old, yrange_nom) - yrange_old
        ydiff_per_line = ydiff/float(num_gaps)
        for i, (t, x, y) in enumerate(rl.get()):
            # TODO: Further improvement using relative or average positions?
            newy = y + (i - num_gaps/2.0) * ydiff_per_line
            PLABELS.append(LabelDetails((x, newy), t, anchor=rl.anchor))



## Text of the output file, built up by whichever format section applies
out = ""
## MAKE-PLOTS FORMAT
if opts.FORMAT == "makeplots":

    ## Write plot header
    out += "# SUSY mass/decay spectrum plot, created by pyslha/slhaplot from %s\n" % opts.INFILE
    out += "# http://pypi.python.org/pypi/pyslha\n"
    out += "\n"
    out += "# BEGIN PLOT\n"
    if opts.PARTICLES_LABELS_MERGE:
        ## Need more space if labels are to be merged horizontally
        out += "XMin=-4\n"
        out += "XMax=20\n"
    else:
        out += "XMin=-3\n"
        out += "XMax=19\n"
    # TODO: emit LogY=1 / LogY=0 here once a log-scale option exists
    ## The y-axis floor is independent of the x range, so write it in both
    ## cases (it used to be emitted only in the non-merged branch)
    out += "YMin=0\n"
    out += "#XCustomTicks=1.0	Higgs	6.0	Sleptons	11.0	Gauginos	16.0	Squarks\n"
    out += "XCustomTicks=-10.0	~\n"
    out += "YLabel=Mass / GeV\n"
    out += "DrawSpecialFirst=1\n"
    out += "# END PLOT\n\n"


    ## Mass lines: one single-bin histogram per particle, spanning its x
    ## edges at a height equal to its mass
    for pid, pdetail in sorted(PDETAILS.items()):
        out += """
# BEGIN HISTOGRAM %s
ErrorBars=1
LineWidth=1pt
LineColor=%s
%f	%f	%e	0
# END HISTOGRAM\n""" % ("pid"+str(pid), pdetail.color,
                      pdetail.xedges.left, pdetail.xedges.right,
                      pdetail.mass)
    out += "\n"


    ## Decay arrows: one dashed PSTricks line per recorded decay
    for pidfrom, todict in sorted(DDETAILS.items()):
        for pidto, dd in sorted(todict.items()):
            out += r"""
# BEGIN SPECIAL decay_%d_%d
\psset{arrowsize=0.1}
\psline[linestyle=dashed,dash=3px 2px,linecolor=%s]{->}\physicscoor(%f,%f)\physicscoor(%f,%f)
# END SPECIAL
""" % (dd.pidfrom, dd.pidto, dd.color, dd.xyfrom[0], dd.xyfrom[1], dd.xyto[0], dd.xyto[1])


    ## Particle labels
    out += "\n\n"
    out += "# BEGIN SPECIAL labels\n"
    for ld in PLABELS:
        out += r"\rput[%s]\physicscoor(%f,%f){\small %s}" % (ld.anchor, ld.xy[0], ld.xy[1], ld.texlabel) + "\n"
    out += "# END SPECIAL\n"



## TIKZ FORMAT
## Handles both 'tikz' (a complete LaTeX document) and 'tikzfrag' (the same
## text with the preamble and document lines commented out)
if "tikz" in opts.FORMAT:

    ## Prefix for the LaTeX document lines: "%" comments them out for tikzfrag
    c = ""
    if opts.FORMAT == "tikzfrag":
        c = "%"

    ## Write LaTeX header
    out += "%% http://pypi.python.org/pypi/pyslha\n"
    out += c + "\\documentclass[11pt]{article}\n"
    # TODO: Insert user-specified preamble file
    out += c + "\\usepackage[osf]{mathpazo}\n"
    out += c + "\\usepackage{amsmath,amssymb}\n"
    out += c + "\\usepackage[landscape]{geometry}\n"
    #
    out += c + "\\usepackage{tikz}\n"
    out += c + "\\pagestyle{empty}\n"
    out += c + "\n"
    out += c + "\\begin{document}\n"
    out += c + "\\thispagestyle{empty}\n\n"

    ## Get coord space size: horizontal range is fixed by make-plots
    xmin = -3.0
    xmax = 19.0
    if opts.PARTICLES_LABELS_MERGE:
        ## Need more space if labels are to be merged horizontally
        xmin -= 1.0
        xmax += 1.0
    xdiff = xmax - xmin

    aspectratio = 0.7 #0.618
    ydiff = aspectratio * xdiff
    ymin = 0.0
    ymax = ymin + ydiff
    ## Get range of masses needed: 10% headroom above the heaviest particle,
    ## rounded up to a multiple of 100
    maxmass = max(pd.mass for pd in PDETAILS.values())
    maxdisplaymass = maxmass * 1.1
    if maxdisplaymass % 100 != 0:
        maxdisplaymass = ((maxdisplaymass + 100) // 100) * 100
    ## Conversion factor from mass units to plot y coordinates
    yscale = (ymax-ymin)/maxdisplaymass

    ## Write TikZ header
    out += "\\begin{tikzpicture}[scale=0.6]\n"

    ## Draw the plot boundary and y-ticks
    out += "  %% Frame\n"
    out += "  \\draw (%f,%f) rectangle (%f,%f);\n" % (xmin, ymin, xmax, ymax)
    out += "  %% y-ticks\n"
    for mtick in range(0, int(maxdisplaymass) + 100, 100):
        ytick = mtick * yscale
        out += "  \\draw (%f,%f) node[left] {%d};\n" % (xmin, ytick, mtick)
        ## No tick mark on the frame's bottom and top edges
        if mtick > 0 and mtick < maxdisplaymass:
            ## The 0.3 needs to be in the plot coords
            out += "  \\draw (%f,%f) -- (%f,%f);\n" % (xmin+0.3, ytick, xmin, ytick)
    out += "  \\draw (%f,%f) node[left,rotate=90] {Mass / GeV};\n" % (xmin-2.0, ymax)

    ## Decay arrows
    out += "\n  %% Decay arrows\n"
    for pidfrom, todict in sorted(DDETAILS.items()):
        for pidto, dd in sorted(todict.items()):
            out += "  %% decay_%d_%d\n" % (dd.pidfrom, dd.pidto)
            out += "  \\draw[-stealth,thick,dashed,color=%s] (%f,%f) -- (%f,%f);\n" % \
                (dd.color, dd.xyfrom[0], yscale*dd.xyfrom[1], dd.xyto[0], yscale*dd.xyto[1])

    ## Draw mass lines
    out += "\n  %% Particle lines\n"
    for pid, pdetail in sorted(PDETAILS.items()):
        y = pdetail.mass*yscale
        out += "  %% pid%s\n" % str(pid)
        out += "  \\draw[color=%s,thick] (%f,%f) -- (%f,%f);\n" % \
            (pdetail.color, pdetail.xedges.left, y, pdetail.xedges.right, y)

    ## Particle labels: a label anchored by its right edge ("r") is placed
    ## with TikZ's "left" node option, and vice versa
    out += "\n  %% Particle labels\n"
    anchors_pstricks_tikz = { "r" : "left", "l" : "right" }
    for ld in PLABELS:
        out += "  \\draw (%f,%f) node[%s] {\\small %s};\n" % \
            (ld.xy[0], yscale*ld.xy[1], anchors_pstricks_tikz[ld.anchor], ld.texlabel)

    ## Write TikZ footer
    out += "\\end{tikzpicture}\n\n"

    ## Write LaTeX footer
    out += c + "\\end{document}\n"



## UNRECOGNISED FORMAT!
## 'makeplots' is excluded explicitly: as a plain 'else' on the TikZ test,
## this branch also fired for make-plots output and exited before it was written
elif opts.FORMAT != "makeplots":
    sys.stderr.write("Other formats not currently supported! How did we even get here?!\n")
    sys.exit(1)


## Write it out: "-" means standard output, anything else is a file path
if opts.OUTFILE == "-":
    sys.stdout.write(out)
else:
    f = open(opts.OUTFILE, "w")
    ## Close the file even if the write fails
    try:
        f.write(out)
    finally:
        f.close()
