#!/usr/bin/env python
# Copyright 2012 the rootpy developers
# distributed under the terms of the GNU General Public License

from __future__ import division
import rootpy
rootpy.log.basic_config_colorized()
from rootpy.extern import argparse
from rootpy.util.extras import humanize_bytes, print_table
import rootpy
import ROOT
import fnmatch
import os, sys


def find_files(dirs, pattern=None):
    """Yield the paths of all files named by, or found beneath, ``dirs``.

    Every entry of ``dirs`` that is an existing regular file is yielded
    unchanged. Every entry that is a directory is walked recursively with
    ``os.walk`` and each file below it is yielded as a path joined onto the
    directory it was found in. Entries that are neither are skipped.

    ``pattern`` is an optional ``fnmatch``-style glob applied to the base
    names of files discovered inside directories; files listed explicitly in
    ``dirs`` are never filtered. With ``pattern=None`` every file found in
    the walked directories is yielded.
    """
    for arg in dirs:
        if os.path.isfile(arg):
            yield arg
        elif os.path.isdir(arg):
            for root, _dirnames, filenames in os.walk(arg):
                if pattern is not None:
                    filenames = fnmatch.filter(filenames, pattern)
                for filename in filenames:
                    # Always join onto the containing directory: a bare
                    # base name is only meaningful relative to ``root``,
                    # not to the caller's working directory.
                    yield os.path.join(root, filename)

def make_chain(args, **kwargs):
    """Build a chain of the tree ``args.tree`` over the requested files.

    The input files are those produced by
    ``find_files(args.files, args.pattern)``.

    If ``args`` carries a true ``staged`` attribute, a rootpy ``TreeChain``
    is constructed and ``kwargs`` are forwarded to it together with
    ``verbose=args.verbose``. Otherwise a plain ``ROOT.TChain`` is filled
    file by file (printing each file name when ``args.verbose`` is set) and
    ``kwargs`` are not used.
    """
    if getattr(args, 'staged', False):
        from rootpy.tree import TreeChain
        paths = list(find_files(args.files, args.pattern))
        return TreeChain(
            args.tree,
            paths,
            verbose=args.verbose,
            **kwargs)

    chain = ROOT.TChain(args.tree, '')
    for path in find_files(args.files, args.pattern):
        if args.verbose:
            print(path)
        chain.Add(path)
    return chain


class formatter_class(
        argparse.ArgumentDefaultsHelpFormatter,
        argparse.RawTextHelpFormatter):
    """Help formatter inheriting from both the defaults and raw-text formatters."""


# Top-level command-line parser. The options defined here belong to the main
# parser; the individual operations are registered as subcommands below.
parser = argparse.ArgumentParser(formatter_class=formatter_class)
parser.add_argument(
    '-v', '--verbose',
    action='store_true',
    default=False)
parser.add_argument(
    '-V', '--version',
    action='version',
    version=rootpy.__version__,
    help="show the version number and exit")
parser.add_argument(
    '-d', '--debug',
    action='store_true',
    default=False,
    help="show a stack trace")
parser.add_argument(
    '-p', '--pattern',
    default='*.root*',
    help="files must match this pattern when searching in directories")

# Every subcommand attaches its handler function via set_defaults(op=...).
subparsers = parser.add_subparsers()


def entries(args):
    """Print the number of entries in the chained tree.

    When ``args.selection`` is given, only entries passing that cut are
    counted, and the string form of the rootpy ``Cut`` built from it is
    echoed next to the count.
    """
    chain = make_chain(args)
    if args.selection is not None:
        from rootpy.tree import Cut
        cut = str(Cut(args.selection))
        print("%i entries after selection %s" % (
            chain.GetEntries(cut), cut))
        return
    print("%i total entries" % chain.GetEntries())


# 'entries' subcommand: handled by entries().
parser_entries = subparsers.add_parser('entries')
parser_entries.add_argument(
    '-s', '--selection',
    default=None,
    help="only entries satisfying this cut will be included in total")
parser_entries.add_argument(
    'tree',
    help="name of tree (including path) in each file")
parser_entries.add_argument(
    'files',
    nargs='+')
parser_entries.set_defaults(op=entries)


def draw(args):
    """Draw ``args.expression`` from the chained tree onto a new canvas.

    ``args.selection``, when given, is converted through rootpy's ``Cut``
    and used as the selection string; otherwise an empty selection is used.
    ``args.draw`` is passed through as the draw option string.
    """
    from rootpy.interactive import wait
    # NOTE(review): wait(True) is called before anything is drawn, exactly
    # as in the original; its precise effect depends on rootpy -- confirm.
    wait(True)
    from rootpy.plotting import Canvas
    canvas = Canvas()

    def refresh_canvas(*_args, **_kwargs):
        # Callback handed to make_chain as ``onfilechange``; it marks the
        # canvas as modified and repaints it.
        canvas.Modified()
        canvas.Update()

    chain = make_chain(args, onfilechange=[(refresh_canvas, ())])

    selection = ''
    if args.selection is not None:
        from rootpy.tree import Cut
        selection = str(Cut(args.selection))
    chain.Draw(args.expression, selection, args.draw)


# 'draw' subcommand: handled by draw().
parser_draw = subparsers.add_parser('draw')
parser_draw.add_argument(
    '-e', '--expression',
    required=True,
    help="expression to be drawn")
parser_draw.add_argument(
    '-s', '--selection',
    default=None,
    help="only entries satisfying this cut will be drawn")
parser_draw.add_argument(
    '-d', '--draw',
    default='',
    help="draw options")
parser_draw.add_argument(
    '--staged',
    action='store_true',
    default=False,
    help="update the canvas after each file is drawn")
parser_draw.add_argument(
    'tree',
    help="name of tree (including path) in each file")
parser_draw.add_argument(
    'files',
    nargs='+')
parser_draw.set_defaults(op=draw)


def sum(args):
    """Add up the histogram ``args.hist`` over ``args.files`` and draw it.

    Each file is opened in turn and the object named ``args.hist`` is
    fetched from it. The first one is cloned and detached from its file
    (``SetDirectory(0)``) to serve as the running total; every later one is
    added to that total in place. The total is then drawn with the option
    string ``args.draw`` and ``wait(True)`` is called.

    Each file is closed even when fetching or adding the histogram raises,
    so a failure part-way through does not leave a file open.

    Note: this function intentionally keeps the name ``sum`` (it is the
    handler registered for the ``sum`` subcommand), which shadows the
    builtin at module level.
    """
    from rootpy.interactive import wait
    from rootpy.plotting import Canvas
    from rootpy.io import open as ropen

    # The canvas is not referenced again, but it must exist before Draw().
    canvas = Canvas()
    total = None
    for filename in args.files:
        f = ropen(filename)
        try:
            h = f.get(args.hist)
            if total is None:
                total = h.Clone()
                # Detach the clone so it is not tied to the file being closed.
                total.SetDirectory(0)
            else:
                total += h
        finally:
            f.close()
    if total is not None:
        total.Draw(args.draw)
    wait(True)


# 'sum' subcommand: handled by sum().
parser_sum = subparsers.add_parser('sum')
parser_sum.add_argument(
    '-d', '--draw',
    default='',
    help="draw options")
parser_sum.add_argument(
    'hist',
    help="name of histogram (including path) in each file")
parser_sum.add_argument(
    'files',
    nargs='+')
parser_sum.set_defaults(op=sum)


def merge(args):
    """Merge the chained tree from all input files into ``args.output``.

    Exits with an error message if ``args.output`` already exists; this is
    checked before the chain is built so no directory scanning is done for
    a run that cannot proceed. The merge uses ``TChain.Merge`` with the
    ``fast`` option and the basket ordering chosen by ``args.sort_by``
    (``offset``, ``branch`` or ``entry``).
    """
    if os.path.exists(args.output):
        sys.exit("Output destination already exists.")
    chain = make_chain(args)
    # Report the number of files actually added to the chain; the number of
    # command-line arguments would be wrong whenever one is a directory.
    print("Merging tree %s in %d files into %s ..." % (
        args.tree, chain.GetNtrees(), args.output))
    chain.Merge(
        args.output,
        'fast SortBasketsBy%s' % args.sort_by.capitalize())


# Help text for --sort-by, kept as a named constant so the argument
# definitions below stay readable.
_SORT_BY_HELP = """\
When using 'offset' the baskets are written in
the output file in the same order as in the original file
(i.e. the basket are sorted on their offset in the original
file; Usually this also means that the baskets are sorted
on the index/number of the *last* entry they contain)

When using 'branch' all the baskets of each
individual branches are stored contiguously. This tends to
optimize reading speed when reading a small number (1->5) of
branches, since all their baskets will be clustered together
instead of being spread across the file. However it might
decrease the performance when reading more branches (or the full
entry).

When using 'entry' the baskets with the lowest
starting entry are written first. (i.e. the baskets are
sorted on the index/number of the first entry they contain).
This means that on the file the baskets will be in the order
in which they will be needed when reading the whole tree
sequentially."""

# 'merge' subcommand: handled by merge().
parser_merge = subparsers.add_parser('merge')
parser_merge.add_argument(
    '-o', '--output',
    required=True,
    help="output file name")
parser_merge.add_argument(
    'tree',
    help="name of tree (including path) in each file")
parser_merge.add_argument(
    '--sort-by',
    choices=('offset', 'branch', 'entry'),
    default='offset',
    help=_SORT_BY_HELP)
parser_merge.add_argument(
    'files',
    nargs='+')
parser_merge.set_defaults(op=merge)


def _print_object_properties(thing, prefix):
    """Print whichever of entries/weight/integral/friends ``thing`` offers."""
    numeric_properties = (
        ("GetEntries", " Entries: %i"),
        ("GetWeight", " Weight: %e"),
        ("Integral", " Integral: %e"),
    )
    for method, template in numeric_properties:
        if hasattr(thing, method):
            print(prefix + template % getattr(thing, method)())
    if hasattr(thing, "GetListOfFriends"):
        friends = thing.GetListOfFriends()
        if friends:
            print(prefix + " Friends:")
            for friend in friends:
                print(prefix + " {0} in {1}".format(
                    friend.GetName(), friend.GetTitle()))


def ls(args):
    """List the objects stored in each of ``args.files``.

    Exits with an error message as soon as a listed file does not exist.
    When more than one file is given, each listing is headed by the file
    name and separated from the previous one by a blank line. Objects are
    printed as "<class name> <object name>", indented by three spaces per
    directory level, and with ``args.showinfo`` set their properties are
    printed as well.
    """
    from rootpy.io import open
    several_files = len(args.files) > 1
    for index, filename in enumerate(args.files):
        if not os.path.isfile(filename):
            sys.exit("file %s does not exist" % filename)
        with open(filename) as f:
            if several_files:
                if index > 0:
                    print("")
                print("{0}:".format(filename))
            for dirpath, _subdirs, objects in f.walk():
                if dirpath:
                    depth = dirpath.count('/') + 1
                    print(dirpath)
                else:
                    # Objects at the top level of the file are not indented.
                    depth = 0
                prefix = '   ' * depth
                for name in objects:
                    thing = f.Get(os.path.join(dirpath, name))
                    print("{0} {1}".format(
                        prefix + thing.__class__.__name__,
                        thing.GetName()))
                    if args.showinfo:
                        _print_object_properties(thing, prefix)

# 'ls' subcommand: handled by ls().
parser_ls = subparsers.add_parser('ls')
parser_ls.add_argument(
    '-l', '--showinfo',
    action='store_true',
    default=False,
    help="display object properties")
parser_ls.add_argument(
    'files',
    nargs='+')
parser_ls.set_defaults(op=ls)


def tree(args):
    """Print the branches of the tree ``args.tree`` in the file ``args.file``.

    Branches can be filtered by ``args.regex`` (applied with ``re.match``,
    i.e. anchored at the start of the branch name) and/or ``args.glob`` (a
    shell-style pattern that must match the whole branch name). A branch is
    shown only if it passes every filter that was given.

    Without ``args.showtypes`` only the names of the selected branches are
    printed. With it, a table of size, type and name is printed, largest
    branch first, followed by the summed size of the selected branches and,
    if some branches were filtered out, the percentage of the full tree size
    that they represent. Sizes are the compressed sizes unless
    ``args.uncompressed`` is set.

    The file is closed on the way out even if an error occurs.
    """
    from rootpy.io import open
    import re
    import fnmatch
    from operator import itemgetter

    def selected(name):
        # A branch passes only if it satisfies every filter supplied.
        if args.regex is not None and not re.match(args.regex, name):
            return False
        # fnmatchcase matches the complete name. Searching for the
        # translated pattern would only anchor the end, so a glob such as
        # "foo*" would wrongly accept "xfoo".
        if args.glob is not None and not fnmatch.fnmatchcase(name, args.glob):
            return False
        return True

    rfile = open(args.file)
    try:
        branches = rfile.Get(args.tree).GetListOfBranches()

        if not args.showtypes:
            for branch in branches:
                if selected(branch.GetName()):
                    print(branch.GetName())
            return

        totalsize = 0
        totalmatchedsize = 0
        table = []
        for branch in branches:
            if args.uncompressed:
                branchsize = branch.GetTotBytes('*')
            else:
                branchsize = branch.GetZipBytes('*')
            # The full-tree total includes branches that are filtered out.
            totalsize += branchsize
            name = branch.GetName()
            if not selected(name):
                continue
            typename = branch.GetClassName()
            if not typename:
                # No class name: fall back to the type of the first leaf.
                typename = branch.GetListOfLeaves()[0].GetTypeName()
            table.append((
                branchsize,
                (humanize_bytes(branchsize), typename, name)))
            totalmatchedsize += branchsize

        # Sort on the numeric size, then keep only the printable columns.
        table.sort(key=itemgetter(0), reverse=True)
        print_table([row[1] for row in table])
        print("total size %s" % humanize_bytes(totalmatchedsize))
        if totalmatchedsize != totalsize:
            print("%.3g%% of full tree size" % (
                100. * totalmatchedsize / totalsize,))
    finally:
        rfile.Close()

# 'tree' subcommand: handled by tree().
parser_tree = subparsers.add_parser('tree')
parser_tree.add_argument(
    '-t', '--tree',
    required=True,
    help="Tree name")
parser_tree.add_argument(
    '-l', '--showtypes',
    action="store_true",
    default=False,
    help="show branch types/classnames and sizes")
parser_tree.add_argument(
    '-e', '--regex',
    default=None,
    help="only show branches matching this regex")
parser_tree.add_argument(
    '-g', '--glob',
    default=None,
    help="only show branches matching this glob")
parser_tree.add_argument(
    '-z', '--uncompressed',
    action="store_true",
    default=False,
    help="show uncompressed branch sizes")
parser_tree.add_argument('file')
parser_tree.set_defaults(op=tree)

# Parse the command line and dispatch to the handler that the chosen
# subcommand stored in ``args.op``.
args = parser.parse_args()
try:
    args.op(args)
except Exception as e:
    if args.debug:
        # If in debug mode show full stack trace
        import traceback
        traceback.print_exc()
        # Still report failure to the calling shell, as the non-debug
        # branch does; otherwise a failed run would exit with status 0.
        sys.exit(1)
    sys.exit("%s: %s" % (e.__class__.__name__, e))
