#!/usr/bin/env python

from __future__ import print_function

import argparse
import boto.sns
import boto.sqs
import json
import os
import Queue
import re
import signal
import socket
import subprocess
import sys
import threading


def print_error(*objs):
    """Write an ``ERROR:``-prefixed message to stderr and exit with status 1.

    Args:
        *objs: Values to print after the prefix, separated by spaces.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    parts = ("ERROR:",) + objs
    print(*parts, file=sys.stderr)
    raise SystemExit(1)


def get_sns_connection(args):
    """Connect to SNS in the region and with the credentials held by *args*.

    Args:
        args: Parsed options providing ``ec2_region``, ``aws_access_key_id``
            and ``aws_secret_access_key``.

    Returns:
        Whatever ``boto.sns.connect_to_region`` returns for that region.
    """
    credentials = {
        'aws_access_key_id': args.aws_access_key_id,
        'aws_secret_access_key': args.aws_secret_access_key,
    }
    region = args.ec2_region
    return boto.sns.connect_to_region(region, **credentials)


def get_sqs_connection(args):
    """Connect to SQS in the region and with the credentials held by *args*.

    Args:
        args: Parsed options providing ``ec2_region``, ``aws_access_key_id``
            and ``aws_secret_access_key``.

    Returns:
        Whatever ``boto.sqs.connect_to_region`` returns for that region.
    """
    credentials = {
        'aws_access_key_id': args.aws_access_key_id,
        'aws_secret_access_key': args.aws_secret_access_key,
    }
    region = args.ec2_region
    return boto.sqs.connect_to_region(region, **credentials)


def get_args():
    """Parse command-line options, falling back to environment variables.

    Every option defaults to an environment variable of the matching name
    (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, COMMAND_FILE, EC2_REGION,
    SNS_TOPIC, SNS_WAIT_TIME_SECONDS, SQS_QUEUE). When unset, the topic
    defaults to ``command-notifier``, the wait time to 20 and the queue name
    to ``command-notifier-<fqdn of this host>``.

    Returns:
        argparse.Namespace with the attributes ``aws_access_key_id``,
        ``aws_secret_access_key``, ``command_file``, ``ec2_region``,
        ``sns_topic``, ``sns_wait_time_seconds`` (always an int) and
        ``sqs_queue``.

    Raises:
        SystemExit: Raised by argparse on invalid options, including a
            non-integer wait time supplied on the command line or through
            the environment.
    """
    default_sqs_queue = 'command-notifier-{0}'.format(socket.getfqdn())
    defaults = {
        'aws-access-key-id': os.environ.get('AWS_ACCESS_KEY_ID'),
        'aws-secret-access-key': os.environ.get('AWS_SECRET_ACCESS_KEY'),
        'command-file': os.environ.get('COMMAND_FILE'),
        'ec2-region': os.environ.get('EC2_REGION'),
        'sns-topic': os.environ.get('SNS_TOPIC', 'command-notifier'),
        'sns-wait-time-seconds': os.environ.get('SNS_WAIT_TIME_SECONDS', 20),
        'sqs-queue': os.environ.get('SQS_QUEUE', default_sqs_queue),
    }
    description = 'cn-subscribe, a tool to run commands published from sqs'
    formatter_class = argparse.RawDescriptionHelpFormatter
    parser = argparse.ArgumentParser(description=description,
                                     epilog='cn-subscribe is pwnage',
                                     formatter_class=formatter_class)
    parser.add_argument('-a', '--aws-access-key-id',
                        type=str,
                        help='AWS Access Key ID',
                        default=defaults['aws-access-key-id'],
                        dest='aws_access_key_id')
    parser.add_argument('-s', '--aws-secret-access-key',
                        type=str,
                        help='AWS Secret Access Key',
                        default=defaults['aws-secret-access-key'],
                        dest='aws_secret_access_key')
    parser.add_argument('-c', '--command-file',
                        type=str,
                        help='Path to json file with whitelisted commands',
                        default=defaults['command-file'],
                        dest='command_file')
    parser.add_argument('-r', '--ec2-region',
                        type=str,
                        help='EC2 Region',
                        default=defaults['ec2-region'],
                        dest='ec2_region')
    parser.add_argument('-t', '--sns-topic',
                        type=str,
                        help='SNS Topic to subscribe to',
                        default=defaults['sns-topic'],
                        dest='sns_topic')
    # type=int makes argparse validate the value up front; it also converts
    # a string default taken from the environment, so the attribute is an
    # int no matter where the value came from.
    parser.add_argument('-w', '--sns-wait-time-seconds',
                        type=int,
                        help='SNS Wait Time in Seconds',
                        default=defaults['sns-wait-time-seconds'],
                        dest='sns_wait_time_seconds')
    parser.add_argument('-q', '--sqs-queue',
                        type=str,
                        help='SQS Queue to utilize',
                        default=defaults['sqs-queue'],
                        dest='sqs_queue')
    return parser.parse_args()


def parse_json(filename):
    """Load a JSON file that may contain ``//`` and ``/* */`` comments.

    Comments are removed before the text is handed to ``json.loads``.
    Double-quoted strings are matched first and kept verbatim, so comment
    markers inside a value (for example the ``//`` of a URL in a command
    line) are not mistaken for comments.

    Args:
        filename: Path of the file to read.

    Returns:
        The decoded JSON value.

    Raises:
        IOError: If the file cannot be opened or read.
        ValueError: If what remains after stripping comments is not valid
            JSON (an unterminated block comment is left in place and ends
            up here).
    """
    # Alternatives are tried at the leftmost position first, so whichever of
    # "string" or "comment" starts earlier wins and swallows the other's
    # delimiters.
    token_re = re.compile(
        r'"(?:\\.|[^"\\])*"'    # a complete double-quoted string
        r'|/\*.*?\*/'           # a block comment
        r'|//[^\n]*',           # a line comment
        re.DOTALL
    )

    def strip_comment(match):
        token = match.group(0)
        return token if token.startswith('"') else ''

    with open(filename) as f:
        content = f.read()

    return json.loads(token_re.sub(strip_comment, content))


def get_whitelisted_commands(args):
    """Load the command whitelist named by ``args.command_file``.

    When a command file is configured, every failure to load it terminates
    the program through ``print_error``. Continuing without the whitelist
    would silently allow any published command to run.

    Args:
        args: Parsed options; only ``command_file`` is read.

    Returns:
        The value stored under the ``"commands"`` key of the command file,
        or None when no command file is configured or the key is absent.

    Raises:
        SystemExit: Via ``print_error`` when the path is not a file, the
            file cannot be read, it is not valid JSON, or its top-level
            value is not an object.
    """
    if not args.command_file:
        return None

    if not os.path.isfile(args.command_file):
        print_error('Invalid path to command file')

    # print_error never returns, so ``data`` is always bound past this point.
    try:
        data = parse_json(args.command_file)
    except IOError as e:
        print_error('Unable to read command file: {0}'.format(e))
    except ValueError as e:
        print_error('Invalid JSON in command file: {0}'.format(e))

    if not isinstance(data, dict):
        print_error('Command file must contain a JSON object')

    # NOTE(review): a file without a "commands" key still yields None, which
    # the caller treats as "no whitelist" - confirm that is intended.
    return data.get('commands', None)


def get_topic(sns_conn, args):
    """Create the SNS topic named ``args.sns_topic`` and return the result.

    Args:
        sns_conn: Connection object exposing ``create_topic(name)``.
        args: Parsed options; only ``sns_topic`` is read.

    Returns:
        The value returned by ``sns_conn.create_topic``.

    Raises:
        SystemExit: Via ``print_error`` when ``create_topic`` returns a
            falsy value.
    """
    topic_name = args.sns_topic
    created = sns_conn.create_topic(topic_name)
    if created:
        return created
    print_error('Topic not created "{0}"'.format(topic_name))
    return created


def get_command(message, whitelisted_commands):
    """Extract the command to run from a queue message.

    The message body is decoded as JSON, its ``"Message"`` field is decoded
    as JSON again, and the ``"command"`` entry of that inner object is
    looked up.

    Args:
        message: Object whose ``get_body()`` returns the raw message text.
        whitelisted_commands: Mapping from command name to command
            definition, or a falsy value when no whitelist is in effect.

    Returns:
        Without a whitelist, the ``"command"`` value exactly as published
        (None if absent). With a whitelist, the whitelist entry for that
        command. None when the message is malformed or the command is not
        whitelisted.
    """
    try:
        envelope = json.loads(message.get_body())
        payload = json.loads(envelope.get('Message'))
    except (ValueError, TypeError, AttributeError):
        # Not JSON, not an object, or no "Message" field: ignore the message
        # instead of letting untrusted input take the subscriber down.
        return None

    if not isinstance(payload, dict):
        return None

    command = payload.get('command')
    # NOTE(review): an empty whitelist is treated the same as no whitelist
    # (every command is accepted) - confirm that is intended.
    if not whitelisted_commands:
        return command

    try:
        if command not in whitelisted_commands:
            return None
        return whitelisted_commands[command]
    except TypeError:
        # Unhashable command (list/dict) or a whitelist that is not a
        # mapping: treat as not whitelisted.
        return None


def run_command(command):
    """Start *command* as a subprocess and echo its output in the background.

    The call returns as soon as the process and its helper threads have
    been started. Lines from the child's stdout and stderr are printed with
    a ``STDOUT:`` / ``STDERR:`` prefix by a printer thread.

    Args:
        command: One of
            * a list - executed directly (``shell=False``);
            * any other non-dict value (normally a string) - executed
              through the shell;
            * a dict - its ``"command"`` entry is the command, an optional
              ``"shell"`` entry overrides the shell default, and all
              remaining entries are passed to ``subprocess.Popen`` as
              keyword arguments.

    Returns:
        None. A command that cannot be started (missing ``"command"``
        entry, executable not found, invalid Popen arguments) is reported
        on stderr instead of raising, so one bad command does not stop the
        caller.
    """
    kwargs = {}
    if isinstance(command, dict):
        kwargs = dict(command)
        command = kwargs.pop('command', None)
        if command is None:
            print('ERROR:', 'Command definition has no "command" entry',
                  file=sys.stderr)
            return

    # An explicit "shell" entry wins; otherwise lists run directly and
    # everything else goes through the shell.
    shell = kwargs.pop('shell', not isinstance(command, list))

    try:
        proc = subprocess.Popen(command,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=shell,
                                **kwargs
                                )
    except (OSError, ValueError, TypeError) as e:
        print('ERROR:', 'Unable to start command:', e, file=sys.stderr)
        return

    io_q = Queue.Queue()

    def stream_watcher(identifier, stream):
        # Forward every line of one pipe to the shared queue, tagged with
        # the pipe's name, then close the pipe.
        for line in stream:
            io_q.put((identifier, line))

        if not stream.closed:
            stream.close()

    threading.Thread(target=stream_watcher, name='stdout-watcher',
                     args=('STDOUT', proc.stdout)).start()
    threading.Thread(target=stream_watcher, name='stderr-watcher',
                     args=('STDERR', proc.stderr)).start()

    def printer():
        while True:
            try:
                # Block for 1 second.
                item = io_q.get(True, 1)
            except Queue.Empty:
                # No output in either stream for a second. Are we done?
                if proc.poll() is not None:
                    break
            else:
                identifier, line = item
                output = identifier + ':' + line
                print(output.strip())

    threading.Thread(target=printer, name='printer').start()


def main():
    """Subscribe an SQS queue to the SNS topic and run published commands.

    Sets up the topic, the queue and the subscription, then polls the queue
    and hands every accepted command to ``run_command``. SIGTERM and SIGINT
    delete the queue and make the loop stop; the loop also stops when SQS
    reports that the queue no longer exists.
    """
    args = get_args()
    sns_conn = get_sns_connection(args)
    sqs_conn = get_sqs_connection(args)

    print('Ensuring topic and queues exist')
    topic = get_topic(sns_conn, args)
    wait_time_seconds = int(args.sns_wait_time_seconds)
    whitelisted_commands = get_whitelisted_commands(args)

    queue = sqs_conn.create_queue(args.sqs_queue)
    queue.set_attribute('ReceiveMessageWaitTimeSeconds', wait_time_seconds)
    queue.set_message_class(boto.sqs.message.RawMessage)
    topic_arn = topic['CreateTopicResponse']['CreateTopicResult']['TopicArn']
    sns_conn.subscribe_sqs_queue(topic_arn, queue)

    nonexistent_queue = 'AWS.SimpleQueueService.NonExistentQueue'
    # Set by the signal handler so the loop ends on its next pass instead of
    # waiting for SQS to start reporting the queue as missing.
    stopping = threading.Event()

    def receive_signal(signalnum, frame):
        print("Stopping, waiting for command to finish processing")
        stopping.set()
        sqs_conn.delete_queue(queue)

    signal.signal(signal.SIGTERM, receive_signal)
    signal.signal(signal.SIGINT, receive_signal)

    print('Ready for commands')
    while not stopping.is_set():
        message = None
        try:
            message = queue.read(wait_time_seconds=wait_time_seconds)
        except boto.exception.SQSError as e:
            if e.error_code == nonexistent_queue:
                break
            # Keep polling as before, but make the failure visible.
            print('ERROR:', 'Failed to read from queue:', e, file=sys.stderr)

        if not message:
            continue

        try:
            queue.delete_message(message)
        except boto.exception.SQSError as e:
            if e.error_code == nonexistent_queue:
                break
            # The command is still run, as before; the undeleted message may
            # be delivered again later.
            print('ERROR:', 'Failed to delete message:', e, file=sys.stderr)

        command = get_command(message, whitelisted_commands)
        if command is None:
            continue

        run_command(command)


# Start the subscriber only when this file is executed as a script.
if __name__ == '__main__':
    main()
