#! /usr/bin/env python

"""\
Usage: %prog [options] <spcfile> [<spcfile2> ...]

Make a SUSY mass spectrum plot from an SLHA or ISAWIG spectrum file. If the
filename ends with .isa, it will be assumed to be an ISAWIG file, otherwise
it will be assumed to be an SLHA file (for which the normal extension is .spc).

Output is currently rendered via the LaTeX PGF/TikZ graphics package: this may
be obtained as PDF (by default) or as LaTeX source which can be edited or
compiled into any LaTeX-supported form.

TODOs:
  * EPS and PNG output
  * Allow user to provide a file which defines the particle line x-positions, labels, etc.
  * Use verbosity-controlled logging
  * Allow user control over aspect ratio / geometry
  * Use proper distinction between physical, plot-logical, and plot output coords
  * Use scaling to allow the y coordinates to be in units of 100 GeV in TikZ output.
  * Distribute decay arrow start/end positions along mass lines rather than always to/from their centres?
"""

class XEdges(object):
    """Horizontal extent of a particle mass line, in plot-logical x units.

    The line starts at ``left + offset`` and extends ``width`` units to the
    right; ``offset`` is kept so callers can tell which side of the nominal
    column the line was pushed towards.
    """

    def __init__(self, left, offset=0.0, width=2.0):
        self.offset = offset
        self.left = left + offset  # shifted left edge
        self.width = width

    @property
    def right(self):
        """x coordinate of the right-hand end of the line."""
        leftedge, span = self.left, self.width
        return leftedge + span

    @property
    def centre(self):
        """x coordinate of the midpoint of the line."""
        lo, hi = self.left, self.right
        return (lo + hi) / 2.0


class Label(object):
    """A particle label: its (TeX) text plus an optional horizontal offset.

    :param text: label text, rendered verbatim (typically TeX math).
    :param offset: optional horizontal distance between the label and its mass
        line, in plot-logical x units; None means "use the plot's default".
    """

    def __init__(self, text, offset=None):
        self.text = text
        ## Previously this was hard-wired to None, silently discarding the argument
        self.offset = offset

    def __str__(self):
        return self.text


## Details classes for representing decided positions in a way independent of output format


class ParticleDetails(object):
    """Format-independent drawing details for one particle's mass line.

    :param label: a Label giving the particle's display text.
    :param xnom: nominal x position of the particle's column.
    :param xoffset: shift from ``xnom`` for this particle's line (its sign is
        used elsewhere to pick which side the label goes on).
    :param color: line colour, as a TikZ/xcolor colour expression.
    :param labelpos: stored but not read by any code visible here.
    :param mass: particle mass; normally filled in after construction.
    """

    def __init__(self, label, xnom, xoffset, color="black", labelpos="L", mass=None):
        self.label, self.mass = label, mass
        self.xedges = XEdges(xnom, xoffset)
        self.color, self.labelpos = color, labelpos


class DecayDetails(object):
    """Format-independent description of one decay arrow.

    Holds the parent and daughter particle IDs, the (x, mass) endpoints of
    the arrow, the branching ratio and a default arrow colour.
    TODO: line thickness and label attributes may be added here later.
    """

    def __init__(self, pidfrom, xyfrom, pidto, xyto, br, color="gray"):
        self.pidfrom, self.xyfrom = pidfrom, xyfrom
        self.pidto, self.xyto = pidto, xyto
        self.br, self.color = br, color


class LabelDetails(object):
    """Format-independent description of one text label on the plot.

    ``texlabel`` is the TeX source of the label; ``textlabel`` currently holds
    the same text and is the hook for any future non-TeX rendering.
    """

    def __init__(self, xy, texlabel, anchor="l", color="black"):
        self.xy = xy
        self.texlabel = self.textlabel = texlabel
        self.anchor, self.color = anchor, color


import pyslha
import sys, optparse
## Command-line interface: the module docstring doubles as the usage text
parser = optparse.OptionParser(usage=__doc__, version=pyslha.__version__)
parser.add_option("-o", "--out", metavar="FILE",
                  help="write output to FILE (only allowed if only one input file is specified!)",
                  dest="OUTFILE", default=None)
parser.add_option("-f", "--format", metavar="FORMAT",
                  choices=["tikz", "tikzfrag", "tikzpdf"],
                  help="format in which to write output. 'tikz' produces LaTeX source using the "
                  "TikZ graphics package to render the plot, 'tikzfrag' produces the same but "
                  "with the LaTeX preamble and document lines commented out to make it directly "
                  "includeable as a code fragment in LaTeX document source, and 'tikzpdf' produces "
                  "a PDF file created by running pdflatex on the 'tikz' output "
                  "(default: %default)",
                  dest="FORMAT", default="tikzpdf")
parser.add_option("--preamble", metavar="FILE",
                  help="specify a file to be inserted into LaTeX output as a special preamble",
                  dest="PREAMBLE", default=None)
## The default of 1.1 exceeds any possible branching ratio, so no decays are drawn unless requested
parser.add_option("--minbr", "--br", metavar="BR",
                  help="show decay lines for decays with a branching ratio of > BR, as either a "
                  "fraction or percentage (default: show none)",
                  dest="DECAYS_MINBR", default="1.1")
parser.add_option("--decaystyle", choices=["const", "brwidth", "brcolor", "brwidth+brcolor"], metavar="STYLE",
                  help="drawing style of decay arrows, from const/brwidth/brcolor/brwidth+brcolor. The "
                  "'const' style draws all decay lines with the same width and colour, 'brwidth' linearly "
                  "scales the width of the decay arrow according to the decay branching ratio, 'brcolor' "
                  "makes the arrow colour depend on the branching ratio, and 'brwidth+brcolor' does both "
                  "(default: %default)",
                  dest="DECAYS_STYLE", default="brwidth+brcolor")
parser.add_option("--labels", choices=["none", "merge", "shift"], metavar="MODE",
                  help="treatment of labels for particle IDs, from none/merge/shift. 'none' shows "
                  "no labels at all, 'merge' combines would-be-overlapping labels into a single "
                  "comma-separated list, and 'shift' vertically shifts the clashing labels to avoid "
                  "collisions (default: %default)",
                  dest="PARTICLES_LABELS", default="shift")


## Run parser and sort out a few resulting details
opts, args = parser.parse_args()
## Derived boolean flags for the chosen label-handling mode
opts.PARTICLES_LABELS_SHOW = (opts.PARTICLES_LABELS != "none")
opts.PARTICLES_LABELS_MERGE = (opts.PARTICLES_LABELS == "merge")
opts.PARTICLES_LABELS_SHIFT = (opts.PARTICLES_LABELS == "shift")
#
## Convert the min-BR threshold string (a fraction, or a percentage with a
## trailing "%") to a float fraction; report bad values as a usage error
## rather than a raw traceback.
_minbr_str = opts.DECAYS_MINBR.strip()
try:
    if _minbr_str.endswith("%"):
        opts.DECAYS_MINBR = float(_minbr_str[:-1]) / 100.0
    else:
        opts.DECAYS_MINBR = float(_minbr_str)
except ValueError:
    parser.error("invalid --minbr value '%s': expected a fraction (e.g. 0.05) "
                 "or a percentage (e.g. 5%%)" % _minbr_str)


## Check non-optional arguments
INFILES = args
if not INFILES:
    parser.print_help()
    sys.exit(1)
## A single explicit output file cannot hold the plots for several inputs
if len(INFILES) > 1 and opts.OUTFILE is not None:
    sys.stderr.write("Multiple input files specified with --out... not a good plan! Exiting for your own good...\n")
    sys.exit(1)


## Test for external packages
import subprocess


def _run_quietly(cmd):
    """Run ``cmd`` (an argv list), discarding its output.

    Returns the process exit code, or None if the command could not be
    launched at all (e.g. the executable does not exist).
    """
    try:
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError:
        return None
    ## communicate() drains both pipes, so a chatty command cannot block on a full pipe
    proc.communicate()
    return proc.returncode


if opts.FORMAT == "tikzpdf":
    ## Test for pdflatex
    rtn = _run_quietly(["which", "pdflatex"])
    if rtn is None:
        sys.stderr.write("WARNING: 'which' could not be run: check for pdflatex cannot be run\n")
    elif rtn != 0:
        sys.stderr.write("pdflatex could not be found: tikzpdf format mode cannot work\n")
        sys.exit(3)
    ## Test for tikz package (skipped, with a warning, if kpsewhich is unavailable)
    rtn = _run_quietly(["which", "kpsewhich"])
    if rtn != 0:
        sys.stderr.write("WARNING: kpsewhich could not be found: check for tikz package cannot be run\n")
    else:
        rtn = _run_quietly(["kpsewhich", "tikz.sty"])
        if rtn != 0:
            sys.stderr.write("tikz.sty could not be found: tikzpdf format mode cannot work\n")
            sys.exit(3)


## Loop over input spectrum files
## (os is needed both for output naming and for the tikzpdf temp dir, so import it unconditionally)
import os


class MultiLabel(object):
    """A group of particle labels that would collide if drawn individually.

    Each entry is a (text, x, y) tuple. Entries are kept sorted by y so that
    vertically shifted labels keep their original order.
    """

    def __init__(self, label=None, x=None, y=None, anchor=None):
        self.labels = [(label, x, y)]
        self.anchor = anchor or "l"

    def __len__(self):
        return len(self.labels)

    @property
    def joinedlabel(self):
        """All label texts joined into one comma-separated TeX string."""
        return r",\,".join(l[0] for l in self.labels)

    @property
    def avgx(self):
        return sum(l[1] for l in self.labels) / float(len(self))

    ## min/max are plain extrema: they must NOT be divided by the label count
    @property
    def minx(self):
        return min(l[1] for l in self.labels)

    @property
    def maxx(self):
        return max(l[1] for l in self.labels)

    @property
    def avgy(self):
        return sum(l[2] for l in self.labels) / float(len(self))

    @property
    def miny(self):
        return min(l[2] for l in self.labels)

    @property
    def maxy(self):
        return max(l[2] for l in self.labels)

    def add(self, label, x, y):
        """Add a label to the group (keeping y order) and return the group."""
        self.labels.append((label, x, y))
        self.labels.sort(key=lambda l: l[2])
        return self

    def get(self):
        """Iterate over the (text, x, y) entries in increasing-y order."""
        return iter(self.labels)


def writeout(out, outfile):
    """Write the string ``out`` to ``outfile``, or to stdout if outfile is "-"."""
    if outfile == "-":
        sys.stdout.write(out)
    else:
        with open(outfile, "w") as f:
            f.write(out)


def _decay_line_thickness(br):
    """TikZ line width (in pt) for a decay arrow with branching ratio ``br``."""
    ## Width scales linearly with BR only in the "brwidth" styles; otherwise constant
    if "brwidth" in opts.DECAYS_STYLE:
        return 1.0 * br
    return 0.8


def _decay_line_color(br):
    """TikZ colour for a decay arrow, or None to fall back to the decay's own colour."""
    ## xcolor "black!N" tint with N rising with BR (10 at BR=0, 70 at BR=1)
    if "brcolor" in opts.DECAYS_STYLE:
        return "black!" + str(60*br + 10)
    return None


## TikZ node placement keyword for each internal label anchor
ANCHORS_PSTRICKS_TIKZ = {"r": "left", "l": "right"}


for infile in INFILES:

    ## Choose output file
    outfile = opts.OUTFILE
    if outfile is None:
        o = os.path.basename(infile)
        if o == "-":
            o = "out"
        if "." in o:
            o = o[:o.rindex(".")]
        ## Add format-specific suffix
        # TODO: add EPS and PNG support
        suffix = ".pdf" if "pdf" in opts.FORMAT else ".tex"
        outfile = o + suffix

    ## Info for the user
    print("Plotting %s -> %s" % (infile, outfile))

    ## Read spectrum file ("-" means stdin)
    if infile == "-":
        BLOCKS, DECAYS = pyslha.readSLHA(sys.stdin.read())
    elif infile.endswith(".isa"):
        BLOCKS, DECAYS = pyslha.readISAWIGFile(infile)
    else:
        BLOCKS, DECAYS = pyslha.readSLHAFile(infile)

    ## Define particle rendering details (may be adapted based on input file, so it *really*
    ## does need to be redefined in each loop over spectrum files!)
    XHIGGS = 0.0
    XSLEPTON = 5.0
    XGAUGINO = 10.0
    XSUSYQCD = 15.0
    PDETAILS = {
        25 : ParticleDetails(Label(r"$h^0$"), XHIGGS, -0.2, color="blue"),
        35 : ParticleDetails(Label(r"$H^0$"), XHIGGS, -0.2, color="blue"),
        36 : ParticleDetails(Label(r"$A^0$"), XHIGGS, -0.2, color="blue"),
        37 : ParticleDetails(Label(r"$H^\pm$"), XHIGGS, 0.2, color="red"),
        1000011 : ParticleDetails(Label(r"$\tilde{\ell}_\text{L}$"), XSLEPTON, -0.2, color="blue"),
        2000011 : ParticleDetails(Label(r"$\tilde{\ell}_\text{R}$"), XSLEPTON, -0.2, color="blue"),
        1000015 : ParticleDetails(Label(r"$\tilde{\tau}_1$"), XSLEPTON, 0.2, color="red"),
        2000015 : ParticleDetails(Label(r"$\tilde{\tau}_2$"), XSLEPTON, 0.2, color="red"),
        1000012 : ParticleDetails(Label(r"$\tilde{\nu}_\text{L}$"), XSLEPTON, -0.2, color="blue"),
        1000016 : ParticleDetails(Label(r"$\tilde{\nu}_\tau$"), XSLEPTON, 0.2, color="red"),
        1000022 : ParticleDetails(Label(r"$\tilde{\chi}_1^0$"), XGAUGINO, -0.2, color="blue"),
        1000023 : ParticleDetails(Label(r"$\tilde{\chi}_2^0$"), XGAUGINO, -0.2, color="blue"),
        1000025 : ParticleDetails(Label(r"$\tilde{\chi}_3^0$"), XGAUGINO, -0.2, color="blue"),
        1000035 : ParticleDetails(Label(r"$\tilde{\chi}_4^0$"), XGAUGINO, -0.2, color="blue"),
        1000024 : ParticleDetails(Label(r"$\tilde{\chi}_1^\pm$"), XGAUGINO, 0.2, color="red"),
        1000037 : ParticleDetails(Label(r"$\tilde{\chi}_2^\pm$"), XGAUGINO, 0.2, color="red"),
        1000039 : ParticleDetails(Label(r"$\tilde{G}$"), XGAUGINO,  0.15, color="black!50!blue!30!green"),
        1000021 : ParticleDetails(Label(r"$\tilde{g}$"), XSUSYQCD, -0.3, color="black!50!blue!30!green"),
        1000001 : ParticleDetails(Label(r"$\tilde{q}_\text{L}$"), XSUSYQCD, -0.1, color="blue"),
        2000001 : ParticleDetails(Label(r"$\tilde{q}_\text{R}$"), XSUSYQCD, -0.1, color="blue"),
        1000005 : ParticleDetails(Label(r"$\tilde{b}_1$"), XSUSYQCD, 0.2, color="black!50!blue!30!green"),
        2000005 : ParticleDetails(Label(r"$\tilde{b}_2$"), XSUSYQCD, 0.2, color="black!50!blue!30!green"),
        1000006 : ParticleDetails(Label(r"$\tilde{t}_1$"), XSUSYQCD, 0.2, color="red"),
        2000006 : ParticleDetails(Label(r"$\tilde{t}_2$"), XSUSYQCD, 0.2, color="red")
    }

    ## Set mass values in PDETAILS, dropping particles with no MASS entry
    try:
        massentries = BLOCKS["MASS"].entries
    except KeyError:
        massentries = {}
    for pid in list(PDETAILS.keys()):
        if pid in massentries:
            PDETAILS[pid].mass = abs(massentries[pid])
        else:
            del PDETAILS[pid]
    if not PDETAILS:
        sys.stderr.write("No masses found for any plottable particle in %s: skipping\n" % infile)
        continue

    ## Largest mass: sets both the label text height and the y-axis range
    maxmass = max(pd.mass for pd in PDETAILS.values())

    ## Decays: keep only those above the BR threshold whose daughters are also plotted
    DDETAILS = {}
    for pid, detail in sorted(PDETAILS.items()):
        if pid not in DECAYS:
            continue
        xyfrom = (detail.xedges.centre, detail.mass)
        todict = {}
        for d in DECAYS[pid].decays:
            if d.br <= opts.DECAYS_MINBR:
                continue
            for pid2 in d.ids:
                if pid2 in PDETAILS:
                    xyto = (PDETAILS[pid2].xedges.centre, PDETAILS[pid2].mass)
                    todict[pid2] = DecayDetails(pid, xyfrom, pid2, xyto, d.br)
        if todict:
            DDETAILS[pid] = todict

    ## Labels
    PLABELS = []
    if opts.PARTICLES_LABELS_SHOW:
        ## Use max mass to work out the height of a text line in mass units
        text_height_in_mass_units = maxmass/22.0
        ##
        ## Group colliding labels
        reallabels = []
        for pid, pdetail in sorted(PDETAILS.items()):
            offset = pdetail.label.offset or 0.2
            ## Labels go on the side of the line away from the column centre
            if pdetail.xedges.offset <= 0:
                labelx = pdetail.xedges.left - offset
                anchor = "r"
            else:
                labelx = pdetail.xedges.right + offset
                anchor = "l"
            ## Avoid hitting the 0 mass line/border
            labely = max(pdetail.mass, 0.6*text_height_in_mass_units)

            text = pdetail.label.text
            ## Check for collisions with an existing label group on the same side
            group = None
            if opts.PARTICLES_LABELS_SHIFT or opts.PARTICLES_LABELS_MERGE:
                for rl in reallabels:
                    if (anchor == rl.anchor and abs(labelx - rl.avgx) < 0.5 and
                            rl.miny - text_height_in_mass_units < labely < rl.maxy + text_height_in_mass_units):
                        group = rl
                        break
            if group is not None:
                group.add(text, labelx, labely)
            else:
                reallabels.append(MultiLabel(text, labelx, labely, anchor))

        ## Calculate position shifts and fill PLABELS
        for rl in reallabels:
            if len(rl) == 1 or opts.PARTICLES_LABELS_MERGE:
                PLABELS.append(LabelDetails((rl.avgx, rl.avgy), rl.joinedlabel, anchor=rl.anchor))
            else:
                ## Spread the group symmetrically until adjacent labels are >= one text height apart
                num_gaps = len(rl) - 1
                yrange_old = rl.maxy - rl.miny
                yrange_nom = num_gaps * text_height_in_mass_units
                yrange = max(yrange_old, yrange_nom)
                ydiff_per_line = (yrange - yrange_old) / float(num_gaps)
                # TODO: Further improvement using relative or average positions?
                for i, (t, x, y) in enumerate(rl.get()):
                    newy = y + (i - num_gaps/2.0) * ydiff_per_line
                    PLABELS.append(LabelDetails((x, newy), t, anchor=rl.anchor))

    out = ""

    ## TIKZ FORMAT
    if "tikz" in opts.FORMAT:

        ## Comment out the preamble etc. if only the TikZ fragment is wanted
        c = "%" if opts.FORMAT == "tikzfrag" else ""

        ## Write LaTeX header
        out += "%% http://pypi.python.org/pypi/pyslha\n\n"
        out += c + "\\documentclass[11pt]{article}\n"
        out += c + "\\usepackage{amsmath,amssymb}\n"
        out += c + "\\usepackage[margin=0cm,paperwidth=15.2cm,paperheight=9.8cm]{geometry}\n"
        out += c + "\\usepackage{tikz}\n"
        out += c + "\\pagestyle{empty}\n"
        out += c + "\n"
        ## Insert user-specified preamble file, or \input it if it can't be read here
        if opts.PREAMBLE is not None:
            out += c + "%% User-supplied preamble\n"
            try:
                with open(opts.PREAMBLE, "r") as fpre:
                    prelines = fpre.readlines()
            except (IOError, OSError):
                sys.stderr.write("Could not read preamble file %s -- fallback to using \\input\n" % opts.PREAMBLE)
                prename = opts.PREAMBLE
                if prename.endswith(".tex"):
                    prename = prename[:-len(".tex")]
                out += c + "\\input{%s}\n" % prename
            else:
                for line in prelines:
                    ## Ensure a final unterminated line doesn't merge with what follows
                    if not line.endswith("\n"):
                        line += "\n"
                    out += c + line
        else:
            out += c + "%% Default preamble\n"
            out += c + "\\usepackage[osf]{mathpazo}\n"
        #
        out += c + "\n"
        out += c + "\\begin{document}\n"
        out += c + "\\thispagestyle{empty}\n\n"

        ## Get coord space size: horizontal range is fixed by make-plots
        xmin = -3.0
        xmax = 19.0
        if opts.PARTICLES_LABELS_MERGE:
            ## Need more space if labels are to be merged horizontally
            xmin -= 1.0
            xmax += 1.0
        xdiff = xmax - xmin
        XWIDTH = 22.0
        def scalex(x):
            return x * XWIDTH/xdiff

        ASPECTRATIO = 0.7 #0.618
        ydiff = ASPECTRATIO * XWIDTH
        ymin = 0.0
        ymax = ymin + ydiff
        ## Get range of masses needed: 10% headroom, rounded up to a multiple of 100
        maxdisplaymass = maxmass * 1.1
        if maxdisplaymass % 100 != 0:
            maxdisplaymass = ((maxdisplaymass + 100) // 100) * 100
        yscale = (ymax-ymin)/maxdisplaymass

        ## Write TikZ header
        out += "\\centering\n"
        out += "\\begin{tikzpicture}[scale=0.6]\n"

        out += "  %% y-scalefactor (GeV -> coords) = %e\n\n" % yscale

        ## Draw the plot boundary and y-ticks
        out += "  %% Frame\n"
        out += "  \\draw (%f,%f) rectangle (%f,%f);\n" % (scalex(xmin), ymin, scalex(xmax), ymax)
        out += "  %% y-ticks\n"
        for mtick in range(0, int(maxdisplaymass) + 100, 100):
            ytick = mtick * yscale
            out += "  \\draw (%f,%f) node[left] {%d};\n" % (scalex(xmin), ytick, mtick)
            if 0 < mtick < maxdisplaymass:
                ## The 0.3 needs to be in the plot coords
                out += "  \\draw (%f,%f) -- (%f,%f);\n" % (scalex(xmin+0.3), ytick, scalex(xmin), ytick)
        out += "  \\draw (%f,%f) node[left,rotate=90] {Mass / GeV};\n" % (scalex(xmin-2.0), ymax)

        ## Decay arrows
        if DDETAILS:
            out += "\n  %% Decay arrows\n"
            for pidfrom, todict in sorted(DDETAILS.items()):
                for pidto, dd in sorted(todict.items()):
                    out += "  %% decay_%d_%d, BR=%0.1f%%\n" % (dd.pidfrom, dd.pidto, dd.br*100)
                    out += "  \\draw[-stealth,line width=%0.2fpt,dashed,color=%s] (%f,%f) -- (%f,%f);\n" % \
                        (_decay_line_thickness(dd.br), _decay_line_color(dd.br) or dd.color,
                         scalex(dd.xyfrom[0]), yscale*dd.xyfrom[1], scalex(dd.xyto[0]), yscale*dd.xyto[1])

        ## Draw mass lines
        if PDETAILS:
            out += "\n  %% Particle lines\n"
            for pid, pdetail in sorted(PDETAILS.items()):
                y = pdetail.mass*yscale
                out += "  %% pid%s\n" % str(pid)
                out += "  \\draw[color=%s,thick] (%f,%f) -- (%f,%f);\n" % \
                    (pdetail.color, scalex(pdetail.xedges.left), y, scalex(pdetail.xedges.right), y)

        ## Particle labels
        if PLABELS:
            out += "\n  %% Particle labels\n"
            for ld in PLABELS:
                out += "  \\draw (%f,%f) node[%s] {\\small %s};\n" % \
                    (scalex(ld.xy[0]), yscale*ld.xy[1], ANCHORS_PSTRICKS_TIKZ[ld.anchor], ld.texlabel)

        ## Write TikZ footer
        out += "\\end{tikzpicture}\n\n"

        ## Write LaTeX footer
        out += c + "\\end{document}\n"

        ## Write output
        if opts.FORMAT != "tikzpdf":
            writeout(out, outfile)
        else:
            ## Run LaTeX in a scratch directory and copy the resulting PDF out
            import tempfile, shutil, subprocess
            tmpdir = tempfile.mkdtemp()
            ok = True
            try:
                writeout(out, os.path.join(tmpdir, "mytmp.tex"))
                try:
                    p = subprocess.Popen(["pdflatex", "\\scrollmode\\input", "mytmp.tex"],
                                         stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=tmpdir)
                    ## communicate() drains the pipes: wait() with PIPEs can deadlock on verbose LaTeX output
                    p.communicate()
                    ## The exit code is not checked; a missing PDF is what signals failure
                    shutil.copyfile(os.path.join(tmpdir, "mytmp.pdf"), outfile)
                except Exception:
                    sys.stderr.write("pdflatex could not be run: tikzpdf format mode cannot work\n")
                    ok = False
            finally:
                shutil.rmtree(tmpdir, ignore_errors=True)
            if not ok:
                sys.exit(3)

    ## UNRECOGNISED FORMAT!
    else:
        print("Other formats not currently supported! How did we even get here?!")
        sys.exit(2)
