#! /usr/bin/env python
# -*- python -*-

__version__ = "0.4.3"
usage = """%prog - visualise HepMC events as a graph

USAGE:
  %prog [options] hepmcfile1 [hepmcfile2] ...

EXAMPLES:
  * %prog events.hepmc

INTRODUCTION:
  %prog is a viewer for high energy events stored in HepMC
  plain text files (or streams). This is a reduced version
  of the mcview 3D viewer, which just produces handy graphs
  of event structure based on the HepMC record.

AUTHORS:
  Andy Buckley <andy.buckley@cern.ch>

TODO:
  * Allow _multiple_ event numbers to be specified for printing
  * Mark hard interaction vertex with purple colouring
  * Exit as soon as the last requested event has been printed
  * Allow disabling/enabling of the PDF and dot outputs separately
"""

import sys, os

## Try to use particle data tables for better particle labelling. This is
## best-effort: if pypdt is missing or its tables can't be loaded, pdt is None
## and particles are labelled by their numeric PDG ID instead.
try:
    import pypdt
    pdt = pypdt.PDT()
except Exception:
    # Exception rather than a bare except, so Ctrl-C / sys.exit still work
    pdt = None


## Try to import pydot, which is needed to build and render the event graphs.
## If it is missing, WITH_GRAPHVIZ is False and writeGraph() only logs an error.
try:
    import pydot
    WITH_GRAPHVIZ = True
except ImportError:
    pydot = None
    WITH_GRAPHVIZ = False


def writeGraph(evt):
    """Draw HepMC event *evt* as a graph and write it in the requested formats.

    Each vertex becomes a filled grey node named "V<|barcode|>". Each particle
    becomes an edge labelled "P<barcode> (<name>)", falling back to the numeric
    PDG ID when no name can be looked up. Edges are coloured by status code:
    1 red, 2 orange, 3 purple, 4 blue, anything else the default colour.
    A particle with no production vertex starts at a new blue "IN_V<n>" node.
    A particle with no end vertex ends at a new "OUT_V<n>" node, red for
    status 1 and black otherwise.

    One file "<base>-<NNNN>.<fmt>" is written for each of "dot", "pdf", "png"
    and "svg" that appears in opts.FORMAT. <base> is the input file name
    without its directory or any ".hepmc..." suffix, or "stdin" for "-".
    NNNN is the event counter. A failed render is logged and the remaining
    formats are still attempted.

    Reads the module-level globals EVTNUM, EVTFILE and opts, which the main
    section sets up before calling this. Returns None.
    """
    import re, logging

    if not WITH_GRAPHVIZ:
        logging.error("Can't draw event graph: the pydot module could not be imported")
        return

    ## Edge colours by particle status code; other statuses keep the default
    statusColours = {1: "red", 2: "orange", 3: "purple", 4: "blue"}

    g = pydot.Dot()

    ## One filled grey node per vertex, named by its absolute barcode
    for v in evt.vertices():
        nodeId = "V%d" % abs(v.barcode())
        # Add the node only if it isn't there yet. For a missing node,
        # get_node() returns None (older pydot) or an empty list (newer pydot),
        # so test truthiness rather than comparing with None.
        if not g.get_node(nodeId):
            n = pydot.Node(nodeId)
            n.set_color("grey")
            n.set_style("filled")
            g.add_node(n)

    ## Extra source/sink nodes for dangling particles live in their own subgraphs
    inGroup = pydot.Subgraph("IN")
    outGroup = pydot.Subgraph("OUT")
    g.add_subgraph(inGroup)
    g.add_subgraph(outGroup)
    numIn = 0
    numOut = 0

    for p in evt.particles():
        ## Start node: the production vertex, or a fresh blue incoming node
        vstart = p.production_vertex()
        if vstart:
            startNodeId = "V%d" % abs(vstart.barcode())
        else:
            startNodeId = "IN_V%d" % numIn
            numIn += 1
            startNode = pydot.Node(startNodeId)
            startNode.set_color("blue")
            startNode.set_fontcolor("white")
            startNode.set_style("filled")
            inGroup.add_node(startNode)

        ## End node: the end vertex, or a fresh outgoing node (red if status 1)
        vend = p.end_vertex()
        if vend:
            endNodeId = "V%d" % abs(vend.barcode())
        else:
            endNodeId = "OUT_V%d" % numOut
            numOut += 1
            endNode = pydot.Node(endNodeId)
            endNode.set_color("red" if p.status() == 1 else "black")
            endNode.set_fontcolor("white")
            endNode.set_style("filled")
            outGroup.add_node(endNode)

        ## The particle itself is the edge between the two nodes
        e = pydot.Edge(startNodeId, endNodeId)
        try:
            pname = pdt[p.pdg_id()].name.replace("^", "")
            e.set_label("P%d (%s)" % (p.barcode(), pname))
        except Exception:
            # No particle data table (pdt is None), or the lookup failed:
            # fall back to the numeric PDG ID
            e.set_label("P%d (%d)" % (p.barcode(), p.pdg_id()))
        colour = statusColours.get(p.status())
        if colour is not None:
            e.set_color(colour)
        g.add_edge(e)

    ## Output file base name. Strip the directory first so that a ".hepmc" in
    ## a directory name can't truncate the result.
    fname = re.sub(r"\.hepmc.*$", "", os.path.basename(EVTFILE))
    if fname == "-":
        fname = "stdin"

    for f in ("dot", "pdf", "png", "svg"):
        # opts.FORMAT is matched as a substring, so e.g. "pdf,svg" selects both
        if f not in opts.FORMAT:
            continue
        outfile = "{base}-{n:04d}.{fmt}".format(base=fname, n=EVTNUM, fmt=f)
        logging.info("Writing %s event output to '%s'" % (f.upper(), outfile))
        try:
            g.write(outfile, format=f, prog="dot")
        except Exception as err:
            # Report and carry on with the other formats rather than abort the run
            logging.error("Problem while rendering event #%d: %s" % (EVTNUM, err))




def getNextEvent():
    """Return the next event from the queued HepMC files, or None when done.

    Input files are taken in order from the front of the module-level list
    ``remaining_hepmc_files``. When the current reader is in a bad state or
    returns no event, the next file is opened, so events are read across all
    input files in sequence.

    Reads and updates the module-level globals set up by the main section:
      reader  -- the current hepmc.IO_GenEvent, or None if no file is open
      EVTFILE -- name of the file most recently opened
      EVTNUM  -- running count of events successfully read, across all files

    Returns None once every file has been used up. Also returns None, after
    logging an error, if the hepmc module can't be imported. A file that
    can't be opened is logged and skipped.
    """
    global reader, remaining_hepmc_files, EVTNUM, EVTFILE
    import logging
    try:
        import hepmc
    except ImportError:
        logging.error("Can't read HepMC events: the hepmc Python module could not be imported")
        return None

    while True:
        ## Open the next file if there is no usable reader
        if reader is None or reader.rdstate() != 0:
            if not remaining_hepmc_files:
                return None
            EVTFILE = remaining_hepmc_files.pop(0)
            logging.info("Reading events from '%s'" % EVTFILE)
            reader = None
            try:
                reader = hepmc.IO_GenEvent(EVTFILE, "r")
            except Exception as err:
                logging.error("Could not open HepMC file '%s': %s" % (EVTFILE, err))
                continue
        logging.info("Reading next event: %d" % (EVTNUM + 1))
        evt = reader.get_next_event()
        if evt is not None:
            break
        ## No event (presumably end of file or a read error): move to the next file
        reader = None

    # Only count events that were actually read
    EVTNUM += 1

    ## Convert to standard units if supported. This is best-effort: on failure
    ## the event keeps its native units.
    try:
        # NOTE(review): assumes the bindings export the HepMC::Units enum values
        # at module level as hepmc.GEV / hepmc.MM -- confirm
        evt.use_units(hepmc.GEV, hepmc.MM)
    except Exception as err:
        logging.debug("Could not convert event to GeV/mm units: %s" % err)
    return evt




if __name__ == "__main__":

    ## Parse options
    import logging
    from optparse import OptionParser
    # optparse substitutes the program name for "%prog" in usage and version strings
    parser = OptionParser(usage=usage, version="%prog " + __version__)
    parser.add_option("-n", "--num", metavar="NUM", help="only write out the specified event",
                      dest="EVENTNUM", type=int, default=1)
    parser.add_option("-f", "--format", metavar="FMT", help="write the event graph in the specified format(s)",
                      dest="FORMAT", default="pdf")
    # -q and -v share the LOGLEVEL destination: whichever is given last wins
    parser.add_option("-q", "--quiet", help="suppress normal messages", dest="LOGLEVEL",
                      action="store_const", default=logging.INFO, const=logging.WARNING)
    parser.add_option("-v", "--verbose", help="add extra debug messages", dest="LOGLEVEL",
                      action="store_const", default=logging.INFO, const=logging.DEBUG)
    opts, args = parser.parse_args()
    logging.basicConfig(level=opts.LOGLEVEL, format="%(message)s")

    ## Without pydot no graphs can be drawn, so fail fast rather than read events
    if not WITH_GRAPHVIZ:
        logging.error("The pydot module is required to draw event graphs... exiting")
        sys.exit(1)

    ## Check the input files
    HEPMCFILES = args
    if len(HEPMCFILES) < 1:
        logging.error("You must specify at least one HepMC event file... exiting")
        sys.exit(1)
    for evtfile in HEPMCFILES:
        # "-" can't be checked with os.access, so it is let through unchecked
        if evtfile != "-" and not os.access(evtfile, os.R_OK):
            logging.error("Can't read HepMC event file %s... exiting" % evtfile)
            sys.exit(1)

    ## Module-level state shared with getNextEvent() and writeGraph()
    remaining_hepmc_files = list(HEPMCFILES)  # a copy: getNextEvent() pops from it
    EVTFILE = None
    EVTNUM = 0
    reader = None

    ## Get and draw events
    while True:
        evt = getNextEvent()
        if evt is None:
            break
        # TODO: Compare to evt.event_number() rather than local EVTNUM counter?
        if opts.EVENTNUM is None or EVTNUM == opts.EVENTNUM:
            logging.info(evt.summary())
            writeGraph(evt)
