#! /usr/bin/env python

# from __future__ import with_statement

"""\
Usage: %prog [options] <spcfile>

Make a SUSY mass spectrum plot from an SLHA or ISAWIG spectrum file. If the
filename ends with .isa, it will be assumed to be an ISAWIG file, otherwise
it will be assumed to be an SLHA file (for which the normal extension is .spc).

Output is currently available as the make-plots .dat format or as LaTeX source
using the PGF/TikZ graphics package. Both may be processed to make EPS or PDF
images.

TODOs:
  * Allow user control over aspect ratio / geometry
  * PNG output (use PIL if available?)
  * Use proper distinction between physical, plot-logical, and plot output coords
  * Use scaling to allow the y coordinates to be in units of 100 GeV in TikZ output.
  * Distribute decay arrow start/end positions along mass lines rather than always to/from their centres?
  * Drop make-plots support?
"""

class XEdges(object):
    """Horizontal extent of one mass line, in plot x coordinates.

    The nominal left position is shifted by ``offset``; the right edge and the
    centre are derived from the shifted left edge and ``width``.
    """
    def __init__(self, left, offset=0.0, width=2.0):
        self.width = width
        self.offset = offset
        self.left = offset + left

    @property
    def right(self):
        """x coordinate of the right-hand end of the line."""
        return self.width + self.left

    @property
    def centre(self):
        """x coordinate midway between the left and right ends."""
        return 0.5 * (self.left + self.right)


class Label(object):
    """Text of a particle label plus an optional horizontal offset.

    ``offset`` is the x gap between the end of the mass line and the label.
    When it is None (the default), the label layout code falls back to its
    own default gap.
    """
    def __init__(self, text, offset=None):
        self.text = text
        # Previously this was hard-wired to None, silently discarding the argument
        self.offset = offset

    def __str__(self):
        return self.text


## Details classes for representing decided positions in a way independent of output format


class ParticleDetails(object):
    """Output-format-independent drawing details for one particle's mass line.

    Holds the label object, the mass (None by default, to be filled in later),
    the horizontal extent of the line (an XEdges built from the nominal x
    position and offset), the line colour, and a label-position code
    (``labelpos`` is stored but not read by the code in this file).
    """
    def __init__(self, label, xnom, xoffset, color="black", labelpos="L", mass=None):
        self.xedges = XEdges(xnom, xoffset)
        self.label, self.mass = label, mass
        self.color, self.labelpos = color, labelpos


class DecayDetails(object):
    """Output-format-independent description of one decay arrow.

    Records the parent and daughter particle IDs, the (x, y) start and end
    points of the arrow, the branching ratio, and a fallback colour.
    """
    def __init__(self, pidfrom, xyfrom, pidto, xyto, br, color="gray"):
        # TODO: a line thickness and an optional label could be stored here too
        self.pidfrom, self.xyfrom = pidfrom, xyfrom
        self.pidto, self.xyto = pidto, xyto
        self.br = br
        self.color = color


class LabelDetails(object):
    """Output-format-independent description of one text label on the plot.

    ``xy`` is the anchor point, ``texlabel`` the TeX source of the text,
    ``anchor`` the side code ("l" or "r") interpreted by the output writers,
    and ``color`` a colour name (not read by the output code in this file).
    """
    def __init__(self, xy, texlabel, anchor="l", color="black"):
        self.xy = xy
        ## Non-TeX renderers could use ``textlabel``; for now it mirrors the TeX text
        self.texlabel = self.textlabel = texlabel
        self.anchor, self.color = anchor, color



## Nominal x positions (plot coordinates) of the four columns of mass lines.
## Each particle's line starts at its column position shifted by the particle's
## own offset (see XEdges).
XHIGGS = 0.0
XSLEPTON = 5.0
XGAUGINO = 10.0
XSUSYQCD = 15.0

## Drawing details for every plotted particle, keyed by particle ID (the same
## IDs used to look up masses in the spectrum's MASS block).
## Per entry: TeX label, column x position, horizontal offset of the mass line,
## and line colour. A non-positive offset places the label to the left of the
## line, a positive one to the right. Masses are None here and are filled in
## after the spectrum file is read. The gravitino entry (1000039) is removed
## unless --show-gravitino is given.
PDETAILS = {
    25 : ParticleDetails(Label(r"$h^0$"), XHIGGS, -0.2, color="blue"),
    35 : ParticleDetails(Label(r"$H^0$"), XHIGGS, -0.2, color="blue"),
    36 : ParticleDetails(Label(r"$A^0$"), XHIGGS, -0.2, color="blue"),
    37 : ParticleDetails(Label(r"$H^\pm$"), XHIGGS, 0.2, color="red"),
    1000011 : ParticleDetails(Label(r"$\tilde{\ell}_\text{L}$"), XSLEPTON, -0.2, color="blue"),
    2000011 : ParticleDetails(Label(r"$\tilde{\ell}_\text{R}$"), XSLEPTON, -0.2, color="blue"),
    1000015 : ParticleDetails(Label(r"$\tilde{\tau}_1$"), XSLEPTON, 0.2, color="red"),
    2000015 : ParticleDetails(Label(r"$\tilde{\tau}_2$"), XSLEPTON, 0.2, color="red"),
    1000012 : ParticleDetails(Label(r"$\tilde{\nu}_\text{L}$"), XSLEPTON, -0.2, color="blue"),
    1000016 : ParticleDetails(Label(r"$\tilde{\nu}_\tau$"), XSLEPTON, 0.2, color="red"),
    1000022 : ParticleDetails(Label(r"$\tilde{\chi}_1^0$"), XGAUGINO, -0.2, color="blue"),
    1000023 : ParticleDetails(Label(r"$\tilde{\chi}_2^0$"), XGAUGINO, -0.2, color="blue"),
    1000025 : ParticleDetails(Label(r"$\tilde{\chi}_3^0$"), XGAUGINO, -0.2, color="blue"),
    1000035 : ParticleDetails(Label(r"$\tilde{\chi}_4^0$"), XGAUGINO, -0.2, color="blue"),
    1000024 : ParticleDetails(Label(r"$\tilde{\chi}_1^\pm$"), XGAUGINO, 0.2, color="red"),
    1000037 : ParticleDetails(Label(r"$\tilde{\chi}_2^\pm$"), XGAUGINO, 0.2, color="red"),
    1000039 : ParticleDetails(Label(r"$\tilde{G}$"), XGAUGINO,  0.15, color="black!50!blue!30!green"),
    1000021 : ParticleDetails(Label(r"$\tilde{g}$"), XSUSYQCD, -0.3, color="black!50!blue!30!green"),
    1000001 : ParticleDetails(Label(r"$\tilde{q}_\text{L}$"), XSUSYQCD, -0.1, color="blue"),
    2000001 : ParticleDetails(Label(r"$\tilde{q}_\text{R}$"), XSUSYQCD, -0.1, color="blue"),
    1000005 : ParticleDetails(Label(r"$\tilde{b}_1$"), XSUSYQCD, 0.2, color="black!50!blue!30!green"),
    2000005 : ParticleDetails(Label(r"$\tilde{b}_2$"), XSUSYQCD, 0.2, color="black!50!blue!30!green"),
    1000006 : ParticleDetails(Label(r"$\tilde{t}_1$"), XSUSYQCD, 0.2, color="red"),
    2000006 : ParticleDetails(Label(r"$\tilde{t}_2$"), XSUSYQCD, 0.2, color="red")
}


import pyslha
import sys, optparse
## Command-line interface. The module docstring doubles as the usage text.
parser = optparse.OptionParser(usage=__doc__, version=pyslha.__version__)
parser.add_option("-o", "--out", metavar="FILE",
                  help="write output to FILE",
                  dest="OUTFILE", default=None)
parser.add_option("-f", "--format", metavar="FORMAT",
                  choices=["makeplots", "tikz", "tikzfrag", "tikzpdf"],
                  help="format in which to write output. 'tikz' produces LaTeX source using the "
                  "TikZ graphics package to render the plot, 'tikzfrag' produces the same but "
                  "with the LaTeX preamble and document lines commented out to make it directly "
                  "includeable as a code fragment in LaTeX document source, 'tikzpdf' produces a "
                  "PDF file created by running pdflatex on the 'tikz' output, and "
                  "'makeplots' produces a .dat file for processing with the make-plots command. "
                  "(default: %default)",
                  dest="FORMAT", default="tikzpdf")
parser.add_option("--preamble", metavar="FILE",
                  help="specify a file to be inserted into LaTeX output as a special preamble",
                  dest="PREAMBLE", default=None)
## Kept as a string here: it may carry a trailing "%" and is converted after parsing
parser.add_option("--minbr", "--br", metavar="BR",
                  help="show decay lines for decays with a branching ratio of > BR, as either a "
                  "fraction or percentage (default: show none)",
                  dest="DECAYS_MINBR", default="1.1")
parser.add_option("--decaystyle", choices=["const", "brwidth", "brcolor", "brwidth+brcolor"], metavar="STYLE",
                  help="drawing style of decay arrows, from const/brwidth/brcolor/brwidth+brcolor. "
                  "The 'const' style draws all decay lines with the same width and colour, "
                  "'brwidth' linearly scales the width of the decay arrow according to the decay "
                  "branching ratio, 'brcolor' shades the arrow according to the branching ratio, "
                  "and 'brwidth+brcolor' does both. (default: %default)",
                  dest="DECAYS_STYLE", default="brwidth+brcolor")
parser.add_option("--labels", choices=["none", "merge", "shift"], metavar="MODE",
                  help="treatment of labels for particle IDs, from none/merge/shift. 'none' shows "
                  "no labels at all, 'merge' combines would-be-overlapping labels into a single "
                  "comma-separated list, and 'shift' vertically shifts the clashing labels to avoid "
                  "collisions (default: %default)",
                  dest="PARTICLES_LABELS", default="shift")
parser.add_option("--show-gravitino", action="store_true",
                  help="show the gravitino if available",
                  dest="SHOW_GRAVITINO", default=False)


## Run parser and sort out a few resulting details
opts, args = parser.parse_args()
## Derived boolean flags for the label-handling mode
opts.PARTICLES_LABELS_SHOW = (opts.PARTICLES_LABELS != "none")
opts.PARTICLES_LABELS_MERGE = (opts.PARTICLES_LABELS == "merge")
opts.PARTICLES_LABELS_SHIFT = (opts.PARTICLES_LABELS == "shift")
#
## The gravitino is only plotted on request
if not opts.SHOW_GRAVITINO:
    del PDETAILS[1000039]
#
## The minimum BR may be given as a fraction ("0.05") or a percentage ("5%").
## Malformed values are reported as a usage error rather than a traceback.
_minbr = opts.DECAYS_MINBR.strip()
try:
    if _minbr.endswith("%"):
        opts.DECAYS_MINBR = float(_minbr[:-1]) / 100
    else:
        opts.DECAYS_MINBR = float(_minbr)
except ValueError:
    parser.error("invalid branching ratio threshold '%s'" % _minbr)

## Check non-optional arguments
if len(args) != 1:
    parser.print_help()
    sys.exit(1)
opts.INFILE = args[0]


## Choose output file when none was given with -o/--out
if opts.OUTFILE is None:
    import os
    ## Start from the input file's basename; "-" (stdin) becomes "out"
    stem = os.path.basename(opts.INFILE)
    if stem == "-":
        stem = "out"
    ## Drop the last extension, if there is one
    stem = stem.rsplit(".", 1)[0]
    ## Format-specific suffix: PDF output, make-plots data, or LaTeX source
    if "pdf" in opts.FORMAT:
        opts.OUTFILE = stem + ".pdf"
    elif opts.FORMAT == "makeplots":
        opts.OUTFILE = stem + ".dat"
    else:
        opts.OUTFILE = stem + ".tex"


## Read spectrum file: "-" means SLHA text on stdin, a ".isa" suffix means an
## ISAWIG file, and anything else is read as an SLHA file
if opts.INFILE == "-":
    BLOCKS, DECAYS = pyslha.readSLHA(sys.stdin.read())
elif opts.INFILE.endswith(".isa"):
    BLOCKS, DECAYS = pyslha.readISAWIGFile(opts.INFILE)
else:
    BLOCKS, DECAYS = pyslha.readSLHAFile(opts.INFILE)


## Set mass values in PDETAILS. A particle with no entry in the MASS block
## (e.g. a gravitino requested with --show-gravitino but absent from the file)
## is dropped with a warning instead of aborting with a KeyError.
massblock = BLOCKS["MASS"]
## Iterate over a sorted copy of the keys so entries can be deleted in the loop
for pid in sorted(PDETAILS.keys()):
    if pid in massblock.entries:
        ## Only the magnitude of the mass is plotted
        PDETAILS[pid].mass = abs(massblock.entries[pid])
    else:
        sys.stderr.write("No mass given for particle %d: it will not be plotted\n" % pid)
        del PDETAILS[pid]


## Decays: for each plotted parent, collect one arrow to every plotted daughter
## that appears in a decay channel whose BR exceeds the threshold.
## DDETAILS maps parent ID -> {daughter ID -> DecayDetails}.
DDETAILS = {}
for pid, detail in sorted(PDETAILS.items()):
    if pid not in DECAYS:
        continue
    xyfrom = (detail.xedges.centre, detail.mass)
    todict = {}
    for d in DECAYS[pid].decays:
        if not d.br > opts.DECAYS_MINBR:
            continue
        for pid2 in d.ids:
            ## Daughters may be listed as antiparticles (negative IDs); they share
            ## the particle's (positive-ID) mass line
            apid2 = abs(pid2)
            if apid2 not in PDETAILS:
                continue
            ## Several channels can contain the same daughter: keep the largest BR
            ## so the arrow does not depend on the ordering of the decay table
            if apid2 in todict and todict[apid2].br >= d.br:
                continue
            xyto = (PDETAILS[apid2].xedges.centre, PDETAILS[apid2].mass)
            todict[apid2] = DecayDetails(pid, xyfrom, apid2, xyto, d.br)
    ## Only record parents that have at least one arrow to draw
    if todict:
        DDETAILS[pid] = todict


## Labels
PLABELS = []
if opts.PARTICLES_LABELS_SHOW:
    class MultiLabel(object):
        """A cluster of particle labels that would collide if drawn separately.

        Entries are (text, x, y) tuples, kept sorted by y as labels are added.
        """
        def __init__(self, label=None, x=None, y=None, anchor=None):
            self.labels = [(label, x, y)]
            self.anchor = anchor or "l"

        def __len__(self):
            return len(self.labels)

        @property
        def joinedlabel(self):
            """All label texts joined into one comma-separated TeX string."""
            return r",\,".join(l[0] for l in self.labels)

        @property
        def avgx(self):
            return sum(l[1] for l in self.labels)/float(len(self))
        @property
        def minx(self):
            return min(l[1] for l in self.labels)
        @property
        def maxx(self):
            return max(l[1] for l in self.labels)

        @property
        def avgy(self):
            return sum(l[2] for l in self.labels)/float(len(self))
        @property
        def miny(self):
            return min(l[2] for l in self.labels)
        @property
        def maxy(self):
            return max(l[2] for l in self.labels)

        def add(self, label, x, y):
            self.labels.append((label, x, y))
            self.labels.sort(key=lambda l: l[2])
            return self
        def get(self):
            for i in self.labels:
                yield i

    ## Use max mass to work out the height of a text line in mass units
    maxmass = max([pd.mass for pd in PDETAILS.values()] or [0.0])
    text_height_in_mass_units = maxmass/22.0
    ##
    ## Merge colliding labels into clusters
    reallabels = []
    for pid, pdetail in sorted(PDETAILS.items()):
        offset = pdetail.label.offset or 0.2
        ## Non-positive line offsets put the label on the left, others on the right
        if pdetail.xedges.offset <= 0:
            labelx = pdetail.xedges.left - offset
            anchor = "r"
        else:
            labelx = pdetail.xedges.right + offset
            anchor = "l"
        ## Avoid hitting the 0 mass line/border
        labely = pdetail.mass
        if labely < 0.6*text_height_in_mass_units:
            labely = 0.6*text_height_in_mass_units

        text = pdetail.label.text
        ## A collision is a cluster on the same side, in the same column, whose
        ## vertical span (padded by one text line) contains this label
        collision = False
        if opts.PARTICLES_LABELS_SHIFT or opts.PARTICLES_LABELS_MERGE:
            for rl in reallabels:
                if anchor == rl.anchor and abs(labelx - rl.avgx) < 0.5 and \
                        rl.miny - text_height_in_mass_units < labely < rl.maxy + text_height_in_mass_units:
                    rl.add(text, labelx, labely)
                    collision = True
                    break
        if not collision:
            reallabels.append(MultiLabel(text, labelx, labely, anchor))

    ## Calculate position shifts and fill PLABELS
    for rl in reallabels:
        if len(rl) == 1 or opts.PARTICLES_LABELS_MERGE:
            PLABELS.append(LabelDetails((rl.avgx, rl.avgy), rl.joinedlabel, anchor=rl.anchor))
        else:
            ## Stretch the cluster's vertical span to at least one text line per
            ## gap, distributing the extra space evenly about the cluster centre
            num_gaps = len(rl)-1
            yrange_old = rl.maxy - rl.miny
            yrange_nom = num_gaps * text_height_in_mass_units
            yrange = max(yrange_old, yrange_nom)
            ydiff_per_line = (yrange - yrange_old)/num_gaps
            for i, (t, x, y) in enumerate(rl.get()):
                # TODO: Further improvement using relative or average positions?
                newy = y + (i - num_gaps/2.0) * ydiff_per_line
                PLABELS.append(LabelDetails((x, newy), t, anchor=rl.anchor))


## Function for writing out the generated source
def writeout(out, outfile=None):
    """Write the string ``out`` to ``outfile``, or to stdout if it is "-".

    ``outfile`` defaults to the --out option value (``opts.OUTFILE``), looked up
    when the function is called. A file opened here is always closed, even if
    the write fails.
    """
    if outfile is None:
        outfile = opts.OUTFILE
    if outfile == "-":
        sys.stdout.write(out)
        return
    f = open(outfile, "w")
    try:
        f.write(out)
    finally:
        f.close()

out = ""
## MAKE-PLOTS FORMAT
if opts.FORMAT == "makeplots":

    ## Write plot header
    out += "# SUSY mass/decay spectrum plot, created by pyslha/slhaplot from %s\n" % opts.INFILE
    out += "# http://pypi.python.org/pypi/pyslha\n"
    out += "\n"
    out += "# BEGIN PLOT\n"
    if opts.PARTICLES_LABELS_MERGE:
        ## Need more space if labels are to be merged horizontally
        out += "XMin=-4\n"
        out += "XMax=20\n"
    else:
        out += "XMin=-3\n"
        out += "XMax=19\n"
    ## The mass axis always starts at zero, whichever label mode is used
    ## (there is no log-scale option yet)
    out += "YMin=0\n"
    out += "#XCustomTicks=1.0	Higgs	6.0	Sleptons	11.0	Gauginos	16.0	Squarks\n"
    out += "XCustomTicks=-10.0	~\n"
    out += "YLabel=Mass / GeV\n"
    out += "DrawSpecialFirst=1\n"
    out += "# END PLOT\n\n"


    ## Mass lines, each drawn as a one-bin histogram spanning the line's x extent
    for pid, pdetail in sorted(PDETAILS.items()):
        out += """
# BEGIN HISTOGRAM %s
ErrorBars=1
LineWidth=1pt
LineColor=%s
%f	%f	%e	0
# END HISTOGRAM\n""" % ("pid"+str(pid), pdetail.color,
                      pdetail.xedges.left, pdetail.xedges.right,
                      pdetail.mass)
    out += "\n"


    ## Decay arrows, as PSTricks specials in physical coordinates
    for pidfrom, todict in sorted(DDETAILS.items()):
        for pidto, dd in sorted(todict.items()):
            out += r"""
# BEGIN SPECIAL decay_%d_%d
\psset{arrowsize=0.1}
\psline[linestyle=dashed,dash=3px 2px,linecolor=%s]{->}\physicscoor(%f,%f)\physicscoor(%f,%f)
# END SPECIAL
""" % (dd.pidfrom, dd.pidto, dd.color, dd.xyfrom[0], dd.xyfrom[1], dd.xyto[0], dd.xyto[1])


    ## Particle labels
    out += "\n\n"
    out += "# BEGIN SPECIAL labels\n"
    for ld in PLABELS:
        out += r"\rput[%s]\physicscoor(%f,%f){\small %s}" % (ld.anchor, ld.xy[0], ld.xy[1], ld.texlabel) + "\n"
    out += "# END SPECIAL\n"

    ## Write it out
    writeout(out)


## TIKZ FORMAT
if "tikz" in opts.FORMAT:
    ## Needed for path handling below; the earlier output-file step only imports
    ## os when no -o/--out was given
    import os

    ## Comment out the preamble etc. if only the TikZ fragment is wanted
    c = ""
    if opts.FORMAT == "tikzfrag":
        c = "%"

    ## Write LaTeX header
    out += "%% http://pypi.python.org/pypi/pyslha\n\n"
    out += c + "\\documentclass[11pt]{article}\n"
    out += c + "\\usepackage{amsmath,amssymb}\n"
    out += c + "\\usepackage[margin=0cm,paperwidth=15.2cm,paperheight=9.8cm]{geometry}\n"
    out += c + "\\usepackage{tikz}\n"
    out += c + "\\pagestyle{empty}\n"
    out += c + "\n"
    ## Insert user-specified preamble file
    if opts.PREAMBLE is not None:
        out += c + "%% User-supplied preamble\n"
        ## Read the whole file before emitting anything, so a failed read cannot
        ## leave a partial preamble followed by the \input fallback
        try:
            fpre = open(opts.PREAMBLE, "r")
            try:
                prelines = fpre.readlines()
            finally:
                fpre.close()
        except EnvironmentError:
            sys.stderr.write("Could not read preamble file %s -- fallback to using \\input\n" % opts.PREAMBLE)
            out += c + "\\input{%s}\n" % opts.PREAMBLE.replace(".tex", "")
        else:
            for line in prelines:
                out += c + line
    else:
        out += c + "%% Default preamble\n"
        out += c + "\\usepackage[osf]{mathpazo}\n"
    #
    out += c + "\n"
    out += c + "\\begin{document}\n"
    out += c + "\\thispagestyle{empty}\n\n"

    ## Get coord space size: horizontal range is fixed by make-plots
    xmin = -3.0
    xmax = 19.0
    if opts.PARTICLES_LABELS_MERGE:
        ## Need more space if labels are to be merged horizontally
        xmin -= 1.0
        xmax += 1.0
    xdiff = xmax - xmin
    XWIDTH = 22.0
    def scalex(x):
        """Map a plot-logical x coordinate onto a fixed TikZ width of XWIDTH."""
        return x * XWIDTH/xdiff

    ASPECTRATIO = 0.7 #0.618
    ydiff = ASPECTRATIO * XWIDTH
    ymin = 0.0
    ymax = ymin + ydiff
    ## Get range of masses needed, rounded up to the next multiple of 100
    maxmass = max(pd.mass for pd in PDETAILS.values())
    maxdisplaymass = maxmass * 1.1
    if maxdisplaymass % 100 != 0:
        maxdisplaymass = ((maxdisplaymass + 100) // 100) * 100
    yscale = (ymax-ymin)/maxdisplaymass

    ## Write TikZ header
    out += "\\centering\n"
    out += "\\begin{tikzpicture}[scale=0.6]\n"

    out += "  %% y-scalefactor (GeV -> coords) = %e\n\n" % yscale

    ## Draw the plot boundary and y-ticks (every 100 GeV)
    out += "  %% Frame\n"
    out += "  \\draw (%f,%f) rectangle (%f,%f);\n" % (scalex(xmin), ymin, scalex(xmax), ymax)
    out += "  %% y-ticks\n"
    for mtick in range(0, int(maxdisplaymass) + 100, 100):
        ytick = mtick * yscale
        out += "  \\draw (%f,%f) node[left] {%d};\n" % (scalex(xmin), ytick, mtick)
        if mtick > 0 and mtick < maxdisplaymass:
            ## The 0.3 needs to be in the plot coords
            out += "  \\draw (%f,%f) -- (%f,%f);\n" % (scalex(xmin+0.3), ytick, scalex(xmin), ytick)
    out += "  \\draw (%f,%f) node[left,rotate=90] {Mass / GeV};\n" % (scalex(xmin-2.0), ymax)

    ## Decay-arrow styling as a function of the branching ratio. The "brwidth"
    ## and "brcolor" styles are independent, so each works alone or combined;
    ## "const" uses a fixed width and the arrow's own colour.
    def scalethickness(br):
        if "brwidth" in opts.DECAYS_STYLE:
            return 1.0 * br
        return 0.8

    def scalecolor(br):
        if "brcolor" in opts.DECAYS_STYLE:
            return "black!"+str(60*br + 10)
        return None

    ## Decay arrows
    if DDETAILS:
        out += "\n  %% Decay arrows\n"
        for pidfrom, todict in sorted(DDETAILS.items()):
            for pidto, dd in sorted(todict.items()):
                out += "  %% decay_%d_%d, BR=%0.1f%%\n" % (dd.pidfrom, dd.pidto, dd.br*100)
                out += "  \\draw[-stealth,line width=%0.2fpt,dashed,color=%s] (%f,%f) -- (%f,%f);\n" % \
                    (scalethickness(dd.br), scalecolor(dd.br) or dd.color,
                     scalex(dd.xyfrom[0]), yscale*dd.xyfrom[1], scalex(dd.xyto[0]), yscale*dd.xyto[1])


    ## Draw mass lines
    if PDETAILS:
        out += "\n  %% Particle lines\n"
        for pid, pdetail in sorted(PDETAILS.items()):
            y = pdetail.mass*yscale
            out += "  %% pid%s\n" % str(pid)
            out += "  \\draw[color=%s,thick] (%f,%f) -- (%f,%f);\n" % \
                (pdetail.color, scalex(pdetail.xedges.left), y, scalex(pdetail.xedges.right), y)

    ## Particle labels; the "r"/"l" anchor codes name the side of the text that
    ## touches the point, so they map to TikZ's opposite placement keywords
    if PLABELS:
        out += "\n  %% Particle labels\n"
        anchors_pstricks_tikz = { "r" : "left", "l" : "right" }
        for ld in PLABELS:
            out += "  \\draw (%f,%f) node[%s] {\\small %s};\n" % \
                (scalex(ld.xy[0]), yscale*ld.xy[1], anchors_pstricks_tikz[ld.anchor], ld.texlabel)

    ## Write TikZ footer
    out += "\\end{tikzpicture}\n\n"

    ## Write LaTeX footer
    out += c + "\\end{document}\n"


    ## Write output
    if opts.FORMAT != "tikzpdf":
        writeout(out)
    else:
        ## Run pdflatex on the generated source in a scratch directory and copy
        ## the resulting PDF to the output file
        import tempfile, shutil, subprocess

        def _succeeds(cmd):
            """Return True if `cmd` runs and exits with status 0, else False."""
            try:
                p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except OSError:
                return False
            p.communicate()
            return p.returncode == 0

        tmpdir = tempfile.mkdtemp()
        try:
            writeout(out, os.path.join(tmpdir, "mytmp.tex"))
            ok = True
            ## Test for pdflatex
            if not _succeeds(["which", "pdflatex"]):
                sys.stderr.write("pdflatex could not be found: tikzpdf format mode cannot work\n")
                ok = False
            ## Test for tikz package
            if ok:
                if not _succeeds(["which", "kpsewhich"]):
                    sys.stderr.write("WARNING: kpsewhich could not be found: check for tikz package cannot be run\n")
                elif not _succeeds(["kpsewhich", "tikz.sty"]):
                    sys.stderr.write("tikz.sty could not be found: tikzpdf format mode cannot work\n")
                    ok = False
            ## Only try to build the PDF if the prerequisite checks passed
            if ok:
                try:
                    p = subprocess.Popen(["pdflatex", "\\scrollmode\\input", "mytmp.tex"],
                                         stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=tmpdir)
                    ## communicate() drains both pipes; wait() alone can deadlock
                    ## once pdflatex's output fills the pipe buffer
                    p.communicate()
                    shutil.copyfile(os.path.join(tmpdir, "mytmp.pdf"), opts.OUTFILE)
                except EnvironmentError:
                    sys.stderr.write("pdflatex could not be run: tikzpdf format mode cannot work\n")
                    ok = False
        finally:
            ## Always remove the scratch directory, even on failure
            shutil.rmtree(tmpdir)
        if not ok:
            sys.exit(3)


## UNRECOGNISED FORMAT!
## (make-plots output has already been written above, so it is excluded here)
elif opts.FORMAT != "makeplots":
    sys.stderr.write("Other formats not currently supported! How did we even get here?!\n")
    sys.exit(2)
