#!/usr/bin/env python

# Copyright (c) 2011-2013 Leif Johnson <leif@leifjohnson.net>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

'''A command-line script for plotting data from text files.'''

import argparse
import bz2
import datetime
import glob
import gzip
import itertools
import logging
import numpy as np
import os
import re
import sys

from matplotlib import pyplot as plt
from matplotlib import dates

# Two-letter legend placements accepted by --legend, mapped to the integer
# location codes handed to ax.legend(loc=...). The first letter is the
# vertical position (u/t, c, l/b) and the second the horizontal one (l, c, r).
LEGEND = dict(
    ul=2, tl=2,  # upper / top left
    uc=9, tc=9,  # upper / top center
    ur=1, tr=1,  # upper / top right
    cl=6,        # center left
    cc=10,       # center
    cr=7,        # center right
    ll=3, bl=3,  # lower / bottom left
    lc=8, bc=8,  # lower / bottom center
    lr=4, br=4,  # lower / bottom right
)

class ArgParser(argparse.ArgumentParser):
    '''An ArgumentParser preconfigured for this script.

    Unless the caller says otherwise, arguments may be read from files named
    with an ``@`` prefix and help output shows each option's default value.
    '''

    SANE_DEFAULTS = dict(
        fromfile_prefix_chars='@',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    def __init__(self, *args, **kwargs):
        '''Create the parser, filling in SANE_DEFAULTS for missing keywords.'''
        # setdefault (not update) so that a keyword passed explicitly by the
        # caller wins over the class-level default.
        for key, value in ArgParser.SANE_DEFAULTS.items():
            kwargs.setdefault(key, value)
        super(ArgParser, self).__init__(*args, **kwargs)

    def convert_arg_line_to_args(self, line):
        '''Yield the arguments found on one line of an ``@`` argument file.

        Everything from the first ``#`` onward is treated as a comment, and
        blank lines yield nothing. A line that starts with ``-`` and contains
        a space is split on whitespace (an option followed by its values);
        any other line is yielded whole, so it may contain spaces.
        '''
        # Strip both ends: an indented option line must still be recognized
        # as an option rather than passed along with its leading blanks.
        line = line.split('#')[0].strip()
        if not line:
            return
        if line.startswith('-') and ' ' in line:
            for part in line.split():
                yield part
        else:
            yield line

# Module-level command-line parser. All options are registered below at import
# time; the arguments themselves are parsed in the __main__ guard.
FLAGS = ArgParser()

# Two alternative ways of locating values on an input line; argparse rejects a
# command line that passes both. Note that --match has a non-empty default, so
# args.match is populated even when only --columns is given.
g = FLAGS.add_mutually_exclusive_group()
g.add_argument('-k', '--columns', nargs='+', type=int, metavar='K',
               help='extract data from the Kth space-separated column')
g.add_argument('-m', '--match', nargs='+', default=(r'([-+eE.\d]+)', ), metavar='RE',
               help='extract data points from inputs using RE')

# Options controlling how the finished figure is laid out and written.
g = FLAGS.add_argument_group('output')
g.add_argument('-A', '--auto', action='store_true',
               help='layout plot automatically')
g.add_argument('-D', '--dpi', type=int, metavar='N',
               help='save figure with dpi N')
g.add_argument('-S', '--figsize', metavar='W,H',
               help='save figure of size W x H inches')
g.add_argument('-o', '--output', metavar='FILE',
               help='save to FILE instead of displaying on screen')

# Options that transform the extracted data before it is drawn.
g = FLAGS.add_argument_group('data')
g.add_argument('-b', '--batch', type=int, metavar='N',
               help='batch data into groups of N points and plot mean + std')
g.add_argument('-e', '--every', type=int, metavar='N',
               help='restrict plot to show only every Nth data point')
g.add_argument('-f', '--fill-error', action='store_true',
               help='display vertical error regions as a filled polygon')
g.add_argument('-j', '--jitter', type=float, metavar='R',
               help='add N(0, R) jitter to x values')
g.add_argument('-s', '--smooth', type=int, metavar='N',
               help='smooth data using N-sample rectangular window')
g.add_argument('--hline', type=float, default=(), nargs='+', metavar='Y',
               help='draw horizontal lines at Y')
g.add_argument('--vline', type=float, default=(), nargs='+', metavar='X',
               help='draw vertical lines at X')

# Per-series appearance. Colors and point styles are lists that get cycled.
g = FLAGS.add_argument_group('series')
g.add_argument('-a', '--alpha', type=float, default=0.9, metavar='N',
               help='plot series with alpha N')
g.add_argument('-c', '--colors', nargs='+', default=tuple('krcbmgy'), metavar='C',
               help='cycle through the given colors')
g.add_argument('-p', '--points', nargs='+', default=('o-', ), metavar='S',
               help='cycle through the given line/point styles')
g.add_argument('-n', '--names', nargs='+', default=(), metavar='L',
               help='use these names in the legend')

# Axis decoration: grid, legend, scales, labels and ranges. The --legend
# choices are exactly the keys of the LEGEND table.
g = FLAGS.add_argument_group('axes')
g.add_argument('-g', '--grid', action='store_true',
               help='include a grid')
g.add_argument('-d', '--dates', metavar='FMT',
               help='parse dates from x data using FMT')
g.add_argument('-L', '--legend', choices=tuple(sorted(LEGEND.keys())),
               help='include a legend (None)')
g.add_argument('-l', '--log', choices=('x', 'y', 'xy'),
               help='use a log scale on the specified axes')
g.add_argument('-t', '--title', metavar='S',
               help='use S as the plot title')
g.add_argument('-x', '--xlabel', metavar='S',
               help='use S as the label for the x-axis')
g.add_argument('-y', '--ylabel', metavar='S',
               help='use S as the label for the y-axis')
g.add_argument('-X', '--xlim', metavar='A,B',
               help='use (A,B) as the range for the x-axis')
g.add_argument('-Y', '--ylim', metavar='A,B',
               help='use (A,B) as the range for the y-axis')

# Positional arguments: argparse.REMAINDER gathers everything left on the
# command line as input patterns.
FLAGS.add_argument('input', metavar='PATTERN', nargs=argparse.REMAINDER,
                   help='extract data from files matching PATTERN')


def extract_columns(data, columns, x, y, ex, ey):
    '''Copy selected fields of one split input line into the series lists.

    ``columns`` holds indexes into ``data``, and its length decides which
    series receive values, in this order: one index feeds ``y``; two feed
    ``x, y``; three feed ``x, y, ey``; four feed ``x, y, ey, ex``. Any other
    number of indexes leaves every series untouched. Each field is converted
    with ``float`` before being appended.
    '''
    layouts = {
        1: (y, ),
        2: (x, y),
        3: (x, y, ey),
        4: (x, y, ey, ex),
    }
    for index, target in zip(columns, layouts.get(len(columns), ())):
        target.append(float(data[index]))


def extract_groupdict(g, x, y, ex, ey):
    '''Append the values captured by named regex groups to the series.

    ``g`` is a match's group dict. The ``y`` group is required; ``x``, ``ey``
    and ``ex`` are used when present. Before a value is appended to an
    optional series, that series is padded with ``None`` so the new value
    lines up with the point it belongs to.
    '''
    logging.debug('group dict: %r', g)

    def pad(series, length):
        '''Right-fill series with None until it holds at least length items.'''
        series.extend([None] * (length - len(series)))

    if 'x' in g:
        pad(x, len(y))
        x.append(float(g['x']))

    y.append(float(g['y']))

    # Each error series is aligned against the series it describes, whose
    # newest value has already been appended at this point.
    for key, errors, values in (('ey', ey, y), ('ex', ex, x)):
        if key in g:
            pad(errors, len(values) - 1)
            errors.append(float(g[key]))


def extract_groups(g, x, y, ex, ey):
    '''Append the values captured by unnamed regex groups to the series.

    ``g`` is the tuple of captured strings. One value is a ``y``; two are
    ``x, y``; three are ``x, y, ey``. More than three is reported through
    ``FLAGS.error``, and an empty tuple is ignored. ``ex`` is never written.
    '''
    logging.debug('group matches: %r', g)
    count = len(g)

    if count > 3:
        FLAGS.error('unnamed --match cannot match more than 3 values')
        return

    if count == 1:
        y.append(float(g[0]))
        return

    if count < 2:
        return

    # Two or three captures: first bring x (and ey alongside it) up to the
    # current length of y, filling the gap with None.
    missing = len(y) - len(x)
    if missing > 0:
        x.extend([None] * missing)
        ey.extend([None] * missing)

    x.append(float(g[0]))
    y.append(float(g[1]))
    if count == 3:
        ey.append(float(g[2]))


def search_line(line, regex, columns, *series):
    '''Pull data points out of one input line and append them to series.

    With a falsy ``regex`` the line is split on whitespace and the fields
    named by ``columns`` are extracted. Otherwise the line is searched with
    ``regex``: a line that does not match is skipped, a match with named
    groups is handled by name, and any other match by group position.

    ``series`` is the mutable ``x, y, ex, ey`` sequences, updated in place.
    '''
    if regex:
        hit = regex.search(line)
        if hit:
            named = hit.groupdict()
            if named:
                return extract_groupdict(named, *series)
            extract_groups(hit.groups(), *series)
        return

    return extract_columns(line.split(), columns, *series)


def _open_text(path):
    '''Open path for reading lines of text, decompressing .gz and .bz2 files.'''
    import io  # local import: io is not among this module's top-level imports

    if path.endswith('.gz'):
        raw = gzip.open(path)
    elif path.endswith('.bz2'):
        raw = bz2.BZ2File(path)
    else:
        return open(path)

    if isinstance(b'', str):
        # Python 2: byte strings are the native str type, usable as they are.
        return raw
    # Python 3: the decompressors produce bytes, which the str regular
    # expressions used for extraction cannot search; decode to text.
    return io.TextIOWrapper(raw)


def open_inputs(inputs):
    '''Given input pattern arguments, open up matching files.

    Yields ``(name, handle)`` pairs. The pattern ``-`` (also used when
    ``inputs`` is empty) yields ``('-', sys.stdin)``. Every other pattern has
    ``~`` expanded and is then globbed; each matching file is opened for
    reading text, with ``.gz`` and ``.bz2`` files decompressed on the fly.
    A pattern that matches no file yields nothing.
    '''
    for pattern in inputs or ['-']:
        if pattern == '-':
            yield '-', sys.stdin
            continue
        # Expand ~ before globbing: glob does not understand it, so expanding
        # only the results would make home-relative patterns match nothing.
        for filename in glob.glob(os.path.expanduser(pattern)):
            yield filename, _open_text(filename)


def compile_regex(regex):
    '''Compile a regular expression pattern.

    Returns the compiled pattern object. If the pattern cannot be compiled,
    the failure is logged at CRITICAL level and the process exits with
    status -2.
    '''
    logging.info('compiling regular expression %r', regex)
    try:
        return re.compile(regex)
    except Exception:
        # Exception rather than a bare except, so that KeyboardInterrupt and
        # SystemExit are not mistaken for a bad pattern.
        logging.critical('cannot compile regular expression %r', regex)
        sys.exit(-2)


def make_axes(args):
    '''Create an axes object to hold our plots.

    The left and bottom spines are moved 6 points outward, the right and top
    spines are hidden, and ticks are drawn on the bottom and left only.
    ``args.log`` (``None`` or a string containing ``x`` and/or ``y``) selects
    which axes use a log scale. Returns the axes.
    '''
    ax = plt.subplot(111)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()
    # items() instead of iteritems(): the latter exists only on Python 2 dicts.
    for loc, spine in ax.spines.items():
        # Exact-name membership; a substring test against 'left bottom'
        # would also accept fragments of those names.
        if loc in ('left', 'bottom'):
            spine.set_position(('outward', 6))
        elif loc in ('right', 'top'):
            spine.set_color('none')
    log_axes = args.log or ''
    if 'x' in log_axes:
        ax.set_xscale('log')
    if 'y' in log_axes:
        ax.set_yscale('log')
    return ax


def format_axes(ax, args):
    '''Decorate the plotting axes according to the parsed script options.

    Applies, in order: the legend (when ``args.legend`` names an entry of
    LEGEND), the grid, the title, the x label and limits, the y label and
    limits, and automatic date ticks on the x axis when ``args.dates`` is
    set. Options that are unset (falsy) are skipped.
    '''
    logging.debug('using legend: %s', args.legend)
    legend_loc = LEGEND.get(args.legend)
    if legend_loc is not None:
        ax.legend(loc=legend_loc)

    logging.debug('using grid: %s', args.grid)
    ax.grid(args.grid)

    if args.title:
        logging.debug('using title: %r', args.title)
        ax.set_title(args.title)

    # NOTE(review): the axis limits are evaluated as Python expressions, so
    # option text (including text read from an @ argument file) is executed
    # as code. Consider parsing "A,B" explicitly if inputs may be untrusted.
    if args.xlabel:
        logging.debug('using x label: %r', args.xlabel)
        ax.set_xlabel(args.xlabel)
    if args.xlim:
        logging.debug('using x limit: %r', args.xlim)
        x_range = eval(args.xlim)
        ax.set_xlim(x_range)

    if args.ylabel:
        logging.debug('using y label: %r', args.ylabel)
        ax.set_ylabel(args.ylabel)
    if args.ylim:
        logging.debug('using y limit: %r', args.ylim)
        y_range = eval(args.ylim)
        ax.set_ylim(y_range)

    if not args.dates:
        return
    logging.debug('using dates from %r on x axis', args.dates)
    locator = dates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(dates.AutoDateFormatter(locator))


def main(args):
    '''Plot every input series described by args, then save or show the figure.

    ``args`` is the namespace returned by ``FLAGS.parse_args()``. Each input
    (a file, or stdin) becomes one series. When a single input is combined
    with several --match patterns, that input is opened once per pattern so
    each pattern yields its own series. With --columns, values are taken from
    whitespace-separated columns instead of regular-expression matches.

    If --output is set the figure is written there; otherwise an interactive
    window is shown, and closing it or pressing escape exits the process.
    '''
    ax = make_axes(args)

    colors = itertools.cycle(args.colors)
    points = itertools.cycle(args.points)
    # Extent of the data plotted so far. Entries stay None until a non-empty
    # series is seen, so the extent is not forced to include 0.
    limits = dict(xmin=None, xmax=None, ymin=None, ymax=None)

    def track(axis, values):
        '''Widen the recorded extent of axis ('x' or 'y') to cover values.'''
        # len() rather than truthiness: values may be a numpy array, whose
        # truth value is ambiguous.
        if not len(values):
            return
        lo, hi = min(values), max(values)
        if limits[axis + 'min'] is None:
            limits[axis + 'min'] = lo
            limits[axis + 'max'] = hi
        else:
            limits[axis + 'min'] = min(limits[axis + 'min'], lo)
            limits[axis + 'max'] = max(limits[axis + 'max'], hi)

    def plot(label, x, y, ex, ey):
        '''Apply the data options to one series and draw it on ax.'''
        if args.dates:
            conv = lambda z: datetime.datetime.strptime(str(z), args.dates)
            if args.dates == '%s':
                conv = datetime.datetime.fromtimestamp
            x = [dates.date2num(conv(z)) for z in x]

        # np.convolve rejects empty input, so skip smoothing when no data
        # was extracted.
        if args.smooth and len(y):
            y = np.convolve(y, [1. / args.smooth] * args.smooth, 'same')

        if args.batch:
            n = args.batch
            count = int(np.ceil(float(len(y)) / n))
            batches = [y[i * n:(i + 1) * n] for i in range(count)]
            means = [np.array(b).mean() for b in batches]
            stds = [np.array(b).std() for b in batches]
            y, ey = means, stds

        # Without explicit x values, plot against the sample index.
        x = x or list(range(len(y)))

        if args.every:
            x = x[::args.every]
            y = y[::args.every]
            ex = ex[::args.every]
            ey = ey[::args.every]

        if args.jitter:
            # Work on an array: "+=" on a list would append the jitter values
            # to x instead of adding them element by element.
            x = np.asarray(x, dtype=float)
            x = x + args.jitter * np.random.randn(len(x)) * x

        color = next(colors)

        track('x', x)
        track('y', y)

        if len(ex):
            ex = np.asarray(ex)
            ey = np.asarray(ey)
            ax.errorbar(x, y, xerr=ex, yerr=ey, color=color, aa=True)
            return

        if len(ey):
            ey = np.asarray(ey)
            if args.fill_error:
                ax.fill_between(x, y + ey, y - ey, color=color, alpha=0.3, linewidth=0, antialiased=True)
            else:
                ax.errorbar(x, y, yerr=ey, color=color, aa=True)

        ax.plot(x, y, next(points), c=color, markeredgecolor=color, label=label, alpha=args.alpha, aa=True)

    if args.columns:
        # --match always carries a default pattern, so column mode has to be
        # selected explicitly: a None regex makes search_line split columns.
        regexs = itertools.repeat(None)
    else:
        regexs = itertools.cycle(compile_regex(r) for r in args.match)

    inputs = list(open_inputs(args.input))
    if len(inputs) == 1 and 1 < len(args.match):
        # One input, several patterns: one handle per pattern, reusing the
        # handle that is already open for the first pattern.
        inputs += [next(open_inputs(args.input)) for _ in args.match[1:]]

    for i, (path, handle) in enumerate(inputs):
        logging.info('reading data from %s', path)
        regex = next(regexs)
        series = [], [], [], []
        for lineno, line in enumerate(handle):
            try:
                search_line(line, regex, args.columns, *series)
            except Exception:
                # Exception, not a bare except: the SystemExit raised by
                # FLAGS.error and Ctrl-C must end the run, not be logged
                # once per input line.
                logging.exception('error extracting data from %s:%d %r', path, lineno, line.rstrip())
        if handle is not sys.stdin:
            handle.close()
        label = os.path.splitext(os.path.basename(path))[0]
        if len(args.names) > i:
            label = args.names[i]
        elif regex is not None and len(args.match) > i:
            label = regex.pattern
        plot(label, *series)

    # Reference lines span the extent of the plotted data; when nothing was
    # plotted, fall back to the current axis limits.
    xmin, xmax = limits['xmin'], limits['xmax']
    if xmin is None:
        xmin, xmax = ax.get_xlim()
    ymin, ymax = limits['ymin'], limits['ymax']
    if ymin is None:
        ymin, ymax = ax.get_ylim()

    for y, c in zip(args.hline, itertools.cycle(args.colors)):
        ax.hlines(y, xmin, xmax, colors=c, linestyles='dashed')
    for x, c in zip(args.vline, itertools.cycle(args.colors)):
        ax.vlines(x, ymin, ymax, colors=c, linestyles='dashed')

    format_axes(ax, args)

    if args.figsize:
        # NOTE(review): --figsize is evaluated as a Python expression, so the
        # option text is executed as code; parse "W,H" explicitly if option
        # values may come from an untrusted source.
        plt.gcf().set_size_inches(*eval(args.figsize))

    if args.auto:
        plt.tight_layout()

    if args.output:
        logging.info('%s: saving plot', args.output)
        return plt.savefig(os.path.expanduser(args.output), dpi=args.dpi)

    def on_close(event=None):
        sys.exit()

    def on_key(event):
        if event.key == 'escape':
            sys.exit()

    try:
        plt.connect('close_event', on_close)
        plt.gcf().canvas.mpl_connect('key_press_event', on_key)
    except ValueError:
        pass

    try:
        plt.show()
    except KeyboardInterrupt:
        # sys.exit rather than quit(): quit is only installed by the site
        # module and is meant for interactive sessions.
        sys.exit()


if __name__ == '__main__':
    # Send log records to stdout, each prefixed with the first letter of its
    # level and a timestamp, then run the script on the parsed command line.
    log_settings = dict(
        format='%(levelname).1s %(asctime)s %(message)s',
        level=logging.INFO,
        stream=sys.stdout)
    logging.basicConfig(**log_settings)
    main(FLAGS.parse_args())
