#!/usr/bin/env python

import argparse
import json
import logging
import re
import select
import socket
import sys

from colorama import Fore, Style


def hilite(regex, msg):
    """Return *msg* with every match of *regex* wrapped in colour codes.

    The pattern is wrapped in one capturing group so the whole match can be
    re-inserted via ``\\1`` between the colorama "bright green" prefix and
    the reset suffix.
    """
    pattern = r'(%s)' % regex
    prefix = Style.BRIGHT + Fore.GREEN
    suffix = Fore.RESET + Style.RESET_ALL
    replacement = prefix + r'\1' + suffix
    return re.sub(pattern, replacement, msg)


def matched(filters, fields, andd, colored):
    """Apply ``FIELD=REGEX`` filters to a flattened event.

    Args:
        filters: iterable of ``"FIELD=REGEX"`` strings. ``FIELD`` must equal
            a key of *fields* exactly; ``REGEX`` is searched (not anchored)
            in that key's value.
        fields: dict mapping flattened field paths to string values.
        andd: when true every filter must match; otherwise one is enough.
        colored: when true, the matched portions of the values are
            highlighted in place (``fields`` is modified) via ``hilite``.

    Returns:
        ``fields`` when the event passes the filters, ``None`` otherwise.

    Raises:
        ValueError: if a filter does not contain ``=``.
    """
    hits = {}  # field name -> patterns that matched its (unmodified) value
    hit_count = 0
    for flt in filters:
        name, sep, pattern = flt.partition('=')
        if not sep:
            raise ValueError(
                "Invalid filter %r: expected FIELD=REGEX" % (flt,))
        # Direct lookup instead of scanning every key for an equal name.
        if name in fields and re.search(pattern, fields[name]):
            hit_count += 1
            hits.setdefault(name, []).append(pattern)

    if andd:
        passed = hit_count == len(filters)
    else:
        passed = hit_count > 0
    if not passed:
        return None

    if colored:
        # Highlight only after all matching is done, and in a single pass per
        # field, so no pattern is ever applied to text that already contains
        # escape sequences inserted for another pattern.
        for name, patterns in hits.items():
            combined = "|".join("(?:%s)" % p for p in patterns)
            fields[name] = hilite(combined, fields[name])
    return fields


# Python 2 has two text types (``str`` and ``unicode``) sharing the
# ``basestring`` base class; Python 3 only has ``str``.
try:
    _STRING_TYPES = (basestring,)  # noqa: F821 - Python 2 only
except NameError:
    _STRING_TYPES = (str,)


def stringify(obj):
    """Return *obj* unchanged if it is text, otherwise its ``repr()``.

    Checking against every text type matters on Python 2, where
    ``json.loads`` produces ``unicode`` objects: testing for ``str`` alone
    would send them through ``repr()`` and render them as ``u'value'``.
    """
    if isinstance(obj, _STRING_TYPES):
        return obj
    return repr(obj)


def to_path(data, prefix=""):
    """Flatten nested JSON-like data into ``(dotted.path, text)`` pairs.

    Dicts are walked recursively, joining keys with ``.``; a list becomes a
    single comma-separated value built with ``stringify``; anything else is
    converted with ``str()``. This is a generator.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            # Only prepend the parent path when there is one, so top-level
            # keys do not start with a dot.
            child = ".".join((prefix, key)) if prefix else key
            for pair in to_path(value, child):
                yield pair
        return

    if isinstance(data, list):
        joined = ",".join(stringify(item) for item in data)
        yield (prefix, joined)
        return

    yield (prefix, str(data))


def connect(host, port):
    """Open a TCP connection to ``host:port``.

    ``socket.create_connection`` resolves *host* and tries each returned
    address (IPv4 or IPv6) in turn, and cleans up after a failed attempt, so
    no half-created socket is left open.

    Args:
        host: host name or address to connect to.
        port: TCP port number (int).

    Returns:
        The connected socket, or ``None`` if the connection failed; the
        failure and its reason are logged rather than raised.
    """
    logging.warning("Connecting to %s:%d", host, port)
    try:
        return socket.create_connection((host, port))
    except socket.error as err:
        logging.error("ERROR: Could not connect to %s:%d (%s)",
                      host, port, err)
        return None


# Status messages are logged as the bare message text only.
logging.basicConfig(format="%(message)s")

# Command-line interface. Every option ends up as an attribute of ``args``:
# hosts, port, andd, color, filters and fmt.
parser = argparse.ArgumentParser(description="Tail logstash tcp output")
parser.add_argument("-H", "--host",
                    metavar="HOST",
                    dest="hosts",
                    default=[],
                    action="append",
                    help="Logstash host(s) (multiple accepted)")
parser.add_argument("-p", "--port",
                    required=True,
                    type=int,
                    help="Logstash TCP output port")
# "and" is a Python keyword, hence the "andd" destination name.
parser.add_argument("--and",
                    dest="andd",
                    action="store_true",
                    default=False,
                    help="AND multiple filters")
parser.add_argument("--color",
                    action="store_true",
                    default=False,
                    help="Color highlight filtered output [default: %(default)s]")
parser.add_argument("--filter",
                    metavar="FILTER",
                    dest="filters",
                    action="append",
                    help="Define some filters as FIELD=REGEX "
                         "(multiple accepted; default is to `OR' them)")
# The format is applied with the % operator to the dict of flattened event
# fields, so placeholders are written as %(field.path)s.
parser.add_argument("--format",
                    dest="fmt",
                    default="%(@timestamp)s %(@source_host)s: %(@message)s",
                    help="Output format (see README for default)")
args = parser.parse_args()

# Fall back to the local machine when no -H/--host option was given.
if not args.hosts:
    args.hosts.append("localhost")

# Keep only the hosts that could be reached. This must be a real list (not a
# lazy filter object) because entries are removed as connections drop and the
# main loop runs until it is empty.
clients = [sock
           for sock in (connect(host, args.port) for host in args.hosts)
           if sock is not None]

# Bytes received from each connection that do not yet form a complete line.
buffers = dict((sock, b"") for sock in clients)


def _drop(sock):
    """Close *sock* and stop polling it."""
    clients.remove(sock)
    buffers.pop(sock, None)
    try:
        sock.close()
    except socket.error:
        pass  # already broken; nothing more to release


def _emit(raw):
    """Parse one JSON event line (bytes) and print it if it passes the filters."""
    try:
        # UnicodeDecodeError is a ValueError too, so undecodable input is
        # skipped the same way as malformed JSON.
        parsed = json.loads(raw.decode("utf-8"))
    except ValueError:
        return

    fields = dict(to_path(parsed))
    if args.filters:
        fields = matched(args.filters, fields, args.andd, args.color)
        if not fields:
            return

    try:
        print(args.fmt % fields)
    except KeyError as err:
        # One event lacking a field used by --format should not end the tail.
        logging.warning("Skipping event without field %s used by --format",
                        err)


try:
    while clients:
        readable, _, _ = select.select(clients, [], [], 30)
        for cxn in readable:
            try:
                chunk = cxn.recv(4096)
            except socket.error:
                chunk = b""
            if not chunk:
                # An empty read means the peer closed the connection (or the
                # read failed); without this check the socket would be
                # reported readable forever.
                _drop(cxn)
                continue

            # Events are newline-terminated: emit each complete line once and
            # keep the unterminated tail for the next read.
            pending = (buffers[cxn] + chunk).split(b"\n")
            buffers[cxn] = pending.pop()
            for raw in pending:
                _emit(raw)
except KeyboardInterrupt:
    logging.warning("Disconnecting...")
    sys.exit(0)
finally:
    # Iterate over a copy: _drop() removes entries from ``clients``.
    for cxn in list(clients):
        _drop(cxn)
