#!/usr/bin/python
# -*- python -*-

import sys, os, hepmc, logging, re

__version__ = "0.4.0"

## Help text for the interactive viewer's mouse and keyboard controls. It is
## logged at start-up, re-printed on the 'h' key and spliced into the
## command-line usage text, so it should list every key the event loop
## handles.
interactivity = """
  The interactive viewer responds to several keystroke and
  mouse actions.

  Mouse:
    * right-button and drag: rotate view
    * both buttons and up-down drag: change zoom level
    * click on a particle: highlight it in white

  Keystrokes:
    * h: show help message
    * n or <space>: show next event
    * r: re-draw current event
    * l: toggle log-scaling of vector lengths
    * b: toggle beam display
    * B: toggle beamline ring display
    * a: toggle axes display
    * s: cycle through stereo display modes
    * p: print a human-readable event dump to the terminal
    * g: write an image file of this event's graph structure
    * q or ESC: quit
"""
## Command-line usage/help text handed to OptionParser in the main block;
## optparse replaces each "%prog" with the program name. The interactive
## controls help text is spliced in as the INTERACTION section.
usage = """%prog - visualise HepMC events in 3D and as a graph

USAGE:
  %prog [options] hepmcfile1 [hepmcfile2] ...

EXAMPLES:
  * %prog events.hepmc
  * %prog --anim-time=0 --no-log-lengths events.hepmc
  * mkfifo hepmc.fifo
    my-generator --hepmc-out=hepmc.fifo &
    %prog hepmc.fifo
  * agile-rungen -o- | %prog -

INTRODUCTION:
  %prog is a viewer for high energy events stored in HepMC 
  plain text files (or streams). The default behaviour is 
  to display events in an interactive 3D viewer, but this
  can also be used to dump out the event graph structure 
  as a HepMC plain text dump or (sometimes more usefully) as
  an image (a PDF by default --- interactivity is on the
  TODO list).

INTERACTION:""" \
    + interactivity + \
"""
TODO:
  * time animation properly
  * more docs about keypress actions
  * clickable particles
  * embed in GTK window
  * integrate interactive graph viewer
  * choose dark or light background
  * shut-down gracefully
  * provide extra "on-screen" stats
  * try to use Python HepPDT/HepPID module to provide 
    nicer particle names in both graph and "click" views

AUTHORS:
  hepmcview and its associated hepmc Python module were
  written by Andy Buckley, 2007-2008.

  Thanks to David Grellscheid for adding the original 
  beampipe and particle colouring code, and the authors
  of vpython and pydot for providing nice modules to do 
  the hard work!"""


## Try to import the visual (VPython) module. This script uses display,
## vector, colour and shape names (display, vector, color, cylinder, arrow,
## ring, ...) without importing them elsewhere, so they presumably come from
## this star import.
try:
    from visual import *
except ImportError:
    ## Only a genuinely missing module gets this message; other failures
    ## while importing visual propagate with their own traceback.
    print("visual module is required and was not found... exiting")
    sys.exit(1)


## Try to import pydot: Graphviz output of the event graph ('g' key) is only
## available when it imports cleanly.
WITH_GRAPHVIZ = False
try:
    import pydot
    WITH_GRAPHVIZ = True

    def writeGraph(evt):
        """Write the vertex/particle graph of *evt* to Graphviz output files.

        Every vertex becomes a grey, filled node named by its barcode, and
        every particle becomes an edge labelled with its PDG ID (drawn red when
        its status is 1). A particle with no production vertex starts at a
        blue "IN_<n>" node, and one with no end vertex finishes at a red
        "OUT_<n>" node. These synthetic nodes are collected in the "IN" and
        "OUT" subgraphs.

        Two files are written to the current directory, named from the base
        name of the module-level EVTFILE (minus any ".hepmc..." suffix) and
        EVTNUM: a raw Graphviz ".dot" file and a PDF image.
        """
        g = pydot.Dot()

        def terminalNode(nodeId, colour):
            ## Filled node with white text marking a dangling particle end.
            node = pydot.Node(nodeId)
            node.set_color(colour)
            node.set_fontcolor("white")
            node.set_style("filled")
            return node

        for v in evt.vertices():
            nodeId = str(v.barcode())
            ## Add each vertex node only once. pydot's "not found" result from
            ## get_node differs between versions (None or an empty list); both
            ## are falsy.
            if not g.get_node(nodeId):
                n = pydot.Node(nodeId)
                n.set_color("grey")
                n.set_style("filled")
                g.add_node(n)

        numIn = 0
        numOut = 0
        inGroup = pydot.Subgraph("IN")
        outGroup = pydot.Subgraph("OUT")
        g.add_subgraph(inGroup)
        g.add_subgraph(outGroup)
        for p in evt.particles():
            vstart = p.production_vertex()
            if vstart:
                startNodeId = str(vstart.barcode())
            else:
                startNodeId = "IN_" + str(numIn)
                numIn += 1
                inGroup.add_node(terminalNode(startNodeId, "blue"))

            vend = p.end_vertex()
            if vend:
                endNodeId = str(vend.barcode())
            else:
                endNodeId = "OUT_" + str(numOut)
                numOut += 1
                outGroup.add_node(terminalNode(endNodeId, "red"))

            e = pydot.Edge(startNodeId, endNodeId)
            e.set_label(str(p.pdg_id()))
            if p.status() == 1:
                e.set_color("red")
            g.add_edge(e)

        GVPROG = "dot"
        IMGFMT = "pdf"
        ## Use only the base name. A path with directories would otherwise
        ## end up inside the output name and point into a directory
        ## ("hepmc-event-<dir>/...") that need not exist.
        fname = re.sub(r"\.hepmc.*$", "", os.path.basename(EVTFILE))
        subs = {"f":fname, "n":EVTNUM, "p":GVPROG, "i":IMGFMT}
        rawfile = "hepmc-event-%(f)s-%(n)d.%(p)s" % subs
        logging.info("Writing %s event output to '%s'" % (GVPROG, rawfile))
        g.write(rawfile, prog=GVPROG)
        imgfile = "hepmc-event-%(f)s-%(n)d.%(i)s" % subs
        logging.info("Writing %s event output to '%s'" % (IMGFMT.upper(), imgfile))
        g.write(imgfile, format=IMGFMT, prog=GVPROG)

except Exception:
    ## Best effort: without a working pydot the viewer still runs, and the
    ## 'g' key reports that Graphviz output is unavailable.
    pass


def sign(x):
    """Return float(x) / |x|: +1.0 or -1.0 for finite non-zero x.

    Zero is treated as positive and gives +1.0.
    """
    if x != 0:
        return float(x) / float(abs(x))
    return 1.0


def logify(v):
    """Return a copy of vector *v* rescaled to length log(|v| + 1).

    The direction is kept. When the log-length is exactly zero (a zero
    vector), an unscaled copy is returned instead, which avoids dividing by
    a zero length.
    """
    length = sqrt(v.x**2 + v.y**2 + v.z**2)
    loglength = log(length + 1)
    if loglength != 0.0:
        return loglength/length * vector(v)
    return vector(v)


def getColor(p):
    """Choose a display colour for particle *p* from the magnitude of its PDG ID.

    Electrons are green, muons cyan, neutrinos white, photons yellow; every
    other species is red.
    """
    pdgid = abs(p.pdg_id())
    palette = {
        11: color.green,   # electron
        13: color.cyan,    # muon
        12: color.white,   # electron neutrino (invisible)
        14: color.white,   # muon neutrino (invisible)
        16: color.white,   # tau neutrino (invisible)
        22: color.yellow,  # photon
    }
    return palette.get(pdgid, color.red)


def rapidity(p):
    """Return the rapidity 0.5 * (ln(E + pz) - ln(E - pz)) of particle *p*.

    A tiny 1e-20 offset in the E - pz term avoids ln(0) for a particle moving
    exactly along +z. If evaluating either logarithm still raises a domain or
    arithmetic error (e.g. E + pz == 0 for a particle along -z), a sentinel of
    magnitude 10000000 is returned, carrying the sign of pz so that the
    direction along the beam axis is preserved.
    """
    z = p.momentum().z()
    e = p.momentum().e()
    try:
        return 0.5 * ( log(e+z) - log(e-z + 1e-20) )
    except (ValueError, ArithmeticError):
        return 10000000 if z >= 0 else -10000000


def clearEvent():
    """Hide every particle shape drawn for the current event.

    The module-level VISPARTICLES list is then rebound to a fresh empty list.
    """
    global VISPARTICLES
    for shape in VISPARTICLES:
        shape.visible = 0
    VISPARTICLES = list()


def anim_factor(frac):
    """Map animation progress *frac* (0 at the first frame, 1 at the last)
    to a length scale factor; the fourth power makes vectors grow slowly at
    first and quickly at the end."""
    return pow(frac, 4)


def drawEvent(evt, inittime=0.0):
    """Draw the final-state particles of *evt*, optionally animating them.

    Previously drawn particles are removed (clearEvent), and the display
    title is set to "<EVTFILE> - <EVTNUM>". Each particle from
    evt.fsParticles() becomes a cylinder along 10 times its 3-momentum (after
    log-scaling when opts.LOGLENGTHS is set), coloured by getColor(). Unless
    opts.SHOWBEAMS is set, particles with |rapidity| >= 2.5 are made
    invisible. The new cylinders are appended to VISPARTICLES.

    If *inittime* is positive, the cylinders grow from zero to full length
    over a number of frames derived from it (eased by anim_factor).
    Otherwise they are set to full length in a single step.
    """
    global opts, EVTNUM, EVTFILE, VISPARTICLES, DISPLAY
    DISPLAY.title = "%s - %d" % (EVTFILE, EVTNUM)

    clearEvent()
    for p in evt.fsParticles():
        mom = p.momentum()
        vec = vector(mom.x(), mom.y(), mom.z())
        if opts.LOGLENGTHS:
            vec = logify(vec)
        targetvec = 10 * vec
        ## Start at zero length; the animation below grows it to targetaxis.
        vp = cylinder(axis=0.0*targetvec, radius=0.2, color=getColor(p))
        vp.targetaxis = targetvec
        if not (opts.SHOWBEAMS or abs(rapidity(p)) < 2.5):
            vp.visible = 0
        VISPARTICLES.append(vp)

    ## Animate the collision, unless the anim time is zero. A negative time is
    ## treated as zero, since it could otherwise make nframes zero and divide
    ## by zero below. Without animation the zero-length frame 0 is skipped.
    if inittime <= 0:
        inittime = 0.0
        startframe = 1
    else:
        startframe = 0
    time_per_frame = (len(VISPARTICLES)+1) * 1.25e-5
    nframes = int(inittime/time_per_frame) + 1
    ## TODO: do timing properly
    for frame in range(startframe, nframes+1):
        ## The scale factor is the same for every particle in a given frame.
        scale = anim_factor(frame/float(nframes))
        for vp in VISPARTICLES:
            vp.axis = scale * vp.targetaxis


def clearStage():
    """Hide all static scenery objects (axes, beamline rings).

    The module-level STAGE list is then rebound to a fresh empty list.
    """
    global STAGE
    for prop in STAGE:
        prop.visible = 0
    STAGE = list()


def drawStage():
    """Create the static scenery and record each object in STAGE.

    When opts.SHOW_AXES is set, this adds blue arrows of length 20 along
    x, y and z. When opts.SHOW_BEAMLINE is set, it adds thin blue rings of
    radius 2 around the z axis at z = -50, -45, ..., 50.
    """
    global STAGE, opts
    if opts.SHOW_AXES:
        for unit in [(1,0,0), (0,1,0), (0,0,1)]:
            STAGE.append(arrow(axis=20*vector(unit), shaftwidth=0.02,
                               color=color.blue))

    if opts.SHOW_BEAMLINE:
        for z in range(-50, 55, 5):
            STAGE.append(ring(pos=(0, 0, z), axis=(0, 0, 1), radius=2.0,
                              thickness=0.02, color=color.blue))


def getNextEvent():
    """Read and return the next event from the input files.

    A new input file is taken from remaining_hepmc_files when there is no
    reader yet or the current reader's stream state is bad
    (rdstate() != 0). Then EVTNUM is incremented and one event is read. If no
    event can be obtained (no files left, a read error, or the reader gives
    back None), an empty hepmc.GenEvent is returned instead, so callers can
    always use the result directly. Where supported, the event is converted
    to GeV/mm units.
    """
    global reader, remaining_hepmc_files, EVTNUM, EVTFILE
    evt = None
    try:
        if reader is None or reader.rdstate() != 0:
            EVTFILE = remaining_hepmc_files.pop(0)
            logging.info("Reading events from '%s'" % EVTFILE)
            reader = hepmc.IO_GenEvent(EVTFILE, "r")
        EVTNUM += 1
        logging.info("Reading next event: %d" % EVTNUM)
        evt = reader.get_next_event()
    except IndexError:
        ## pop(0) on an exhausted file list.
        logging.warning("No more HepMC input files: showing an empty event")
    except Exception as err:
        logging.warning("Failed to read next event (%s): showing an empty event" % err)
    ## Callers call methods on the result straight away (e.g. summary()), so
    ## never hand back None.
    if evt is None:
        evt = hepmc.GenEvent()
    ## Convert to standard units if supported. Assumes the hepmc wrapper
    ## exposes the HepMC units enums as hepmc.GEV / hepmc.MM -- TODO confirm;
    ## builds without units support raise here and the conversion is skipped.
    try:
        evt.use_units(hepmc.GEV, hepmc.MM)
    except Exception:
        pass
    return evt




if __name__ == "__main__":

    ## Parse options
    from optparse import OptionParser
    ## optparse expands "%prog" (not "$prog") in the version string.
    parser = OptionParser(usage=usage, version="%prog " + __version__)
    parser.add_option("-l", "--log-lengths", dest="LOGLENGTHS", action="store_true", default=True,
                      help="Show momentum vectors with aesthetically-appealing log-scaled lengths [default]")
    parser.add_option("-L", "--no-log-lengths", dest="LOGLENGTHS", action="store_false", default=True,
                      help="Show momentum vectors with 'true' lengths, which makes for *huge* momentum vectors along z")
    parser.add_option("-b", "--show-beams", dest="SHOWBEAMS", action="store_true", default=True,
                      help="Show momentum vectors in the beam-pipe [default]")
    parser.add_option("-B", "--no-show-beams", dest="SHOWBEAMS", action="store_false", default=True,
                      help="Don't show momentum vectors in the beam-pipe")
    parser.add_option("--anim-time", dest="INIT_TIME", metavar="T", type=float, default=0.5,
                      help="Time taken to show each event in seconds [default=0.5sec]")
    parser.add_option("-Q", "--quiet", help="Suppress normal messages", dest="LOGLEVEL",
                      action="store_const", default=logging.INFO, const=logging.WARNING)
    parser.add_option("-V", "--verbose", help="Add extra debug messages", dest="LOGLEVEL",
                      action="store_const", default=logging.INFO, const=logging.DEBUG)
    opts, args = parser.parse_args()
    logging.basicConfig(level=opts.LOGLEVEL, format="%(message)s")


    ## Read and process events
    HEPMCFILES = args
    if not HEPMCFILES:
        logging.error("You must specify at least one HepMC event file... exiting")
        sys.exit(1)
    for evtfile in HEPMCFILES:
        if evtfile != "-" and not os.access(evtfile, os.R_OK):
            logging.error("Can't read HepMC event file %s... exiting" % evtfile)
            sys.exit(1)
    remaining_hepmc_files = HEPMCFILES
    EVTFILE = None
    EVTNUM = 0
    reader = None

    opts.SHOW_BEAMLINE = False
    opts.SHOW_AXES = True

    ## Global variables
    VISPARTICLES = []

    ## Draw stage
    STAGE = []
    DISPLAY = display(width=600, height=600)
    DISPLAY.forward = norm(vector(1,-3,-20))
    drawStage()
    logging.info("hepmcview " + __version__)
    logging.info(interactivity)

    ## Get and draw the first event. It is drawn once, then autoscaling is
    ## frozen and the display scale adjusted (presumably so the view is fitted
    ## to this event) before it is drawn again with animation.
    EVT = getNextEvent()
    logging.info(EVT.summary())
    drawEvent(EVT)
    DISPLAY.autoscale = 0
    DISPLAY.scale *= 1.8
    drawEvent(EVT, opts.INIT_TIME)

    ## Stereo modes cycled through by the 's' key
    STEREO_MODES = ['nostereo', 'redcyan', 'active', 'passive']

    ## Remember what the real colour of the highlighted particle is
    storedp, storedcolor = None, None

    ## User input event loop runs continuously
    while True:
        ## Cap the polling rate instead of spinning a CPU core flat out.
        rate(100)

        ## Read keyboard inputs for drawing next event
        if DISPLAY.kb.keys:
            k = DISPLAY.kb.getkey()
            if k == 'h':
                ## critical level so the help shows even with --quiet
                logging.critical(interactivity)
            elif k == 'n' or k == ' ':
                logging.debug("Drawing next event")
                EVT = getNextEvent()
                logging.info(EVT.summary())
                drawEvent(EVT, opts.INIT_TIME)
            elif k == 'r':
                logging.debug("Re-drawing event")
                drawEvent(EVT, opts.INIT_TIME)
            elif k == 'l':
                opts.LOGLENGTHS = not opts.LOGLENGTHS
                logging.debug("Re-drawing event with log lengths = %s" % opts.LOGLENGTHS)
                ## The stage does not depend on the vector scaling (redrawing
                ## it without clearStage() would only duplicate it).
                drawEvent(EVT)
            elif k == 'a':
                opts.SHOW_AXES = not opts.SHOW_AXES
                logging.debug("Re-drawing event with axes display = %s" % opts.SHOW_AXES)
                clearStage()
                drawStage()
                drawEvent(EVT)
            elif k == 'b':
                opts.SHOWBEAMS = not opts.SHOWBEAMS
                logging.debug("Re-drawing event with beam display = %s" % opts.SHOWBEAMS)
                drawEvent(EVT)
            elif k == 'B':
                opts.SHOW_BEAMLINE = not opts.SHOW_BEAMLINE
                logging.debug("Re-drawing event with beamline display = %s" % opts.SHOW_BEAMLINE)
                clearStage()
                drawStage()
                drawEvent(EVT)
            elif k == 'p':
                print(EVT.as_str())
            elif k == 'g':
                if WITH_GRAPHVIZ:
                    writeGraph(EVT)
                else:
                    logging.warning("Couldn't import pydot: Graphviz output is not available")
            elif k == 'q':
                DISPLAY.visible = 0
                sys.exit(0)
            elif k == 's':
                try:
                    nextindex = (STEREO_MODES.index(DISPLAY.stereo) + 1) % len(STEREO_MODES)
                except ValueError:
                    ## Current mode is not one we cycle through: start over.
                    nextindex = 0
                DISPLAY.stereo = STEREO_MODES[nextindex]

        ## Allow user to click on particles and highlight them in white
        ## (eventually to be used for getting an info box about that particle)
        if DISPLAY.mouse.clicked:
            ## Consume the pending click; the picked object is read below.
            DISPLAY.mouse.getclick()
            if storedp:
                storedp.color = storedcolor
            storedp, storedcolor = DISPLAY.mouse.pick, None
            if storedp:
                storedcolor = storedp.color
                storedp.color = color.white
