#!/usr/bin/python2 -u
import os
import os.path
import sys
import stat
import signal
import errno
import argparse
import json
import re

# Default number of seconds stop_child_process() waits, after signalling a
# child, before it escalates to SIGKILL (the value is handed to signal.alarm).
KILL_PROCESS_TIMEOUT = 5
# Number of seconds passed to kill_all_processes() at exit before it sends
# SIGKILL to whatever is still running.
KILL_ALL_PROCESSES_TIMEOUT = 5

# Verbosity thresholds: each logging helper emits its message only when
# log_level >= the helper's level. Note that errors and warnings share the
# same threshold value.
LOG_LEVEL_ERROR = 1
LOG_LEVEL_WARN = 1
LOG_LEVEL_INFO = 2
LOG_LEVEL_DEBUG = 3

# Active verbosity. Stays None until it is assigned from the parsed
# command-line options further down in this file.
log_level = None


class AlarmException(Exception):
    """Raised from the SIGALRM handler to signal that a timed wait expired."""


def error(message):
    """Write *message* to stderr if the log level admits errors."""
    if log_level < LOG_LEVEL_ERROR:
        return
    line = "*** %s\n" % message
    sys.stderr.write(line)


def warn(message):
    """Print *message* to stdout if the log level admits warnings."""
    if log_level < LOG_LEVEL_WARN:
        return
    line = "*** %s" % message
    print(line)


def info(message):
    """Print *message* to stdout if the log level admits informational output."""
    if log_level < LOG_LEVEL_INFO:
        return
    line = "*** %s" % message
    print(line)


def debug(message):
    """Print *message* to stdout if the log level admits debug output."""
    if log_level < LOG_LEVEL_DEBUG:
        return
    line = "*** %s" % message
    print(line)


def ignore_signals_and_raise_keyboard_interrupt(signame):
    """Ignore further SIGTERM/SIGINT, then raise KeyboardInterrupt(signame).

    Once a termination signal has been received, repeated signals must not
    interrupt the shutdown sequence, so both are set to SIG_IGN first.
    """
    for signo in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signo, signal.SIG_IGN)
    raise KeyboardInterrupt(signame)


def raise_alarm_exception():
    """Raise AlarmException; used as the body of the SIGALRM handler."""
    timeout = AlarmException('Alarm')
    raise timeout


def listdir(path):
    """Return the sorted entry names of directory *path*.

    Returns an empty list when *path* cannot be stat'ed or is not a
    directory.
    """
    try:
        mode = os.stat(path).st_mode
    except OSError:
        return []
    if not stat.S_ISDIR(mode):
        return []
    entries = os.listdir(path)
    entries.sort()
    return entries


def is_exe(path):
    """Return True if *path* is a regular file the current user may execute."""
    try:
        if not os.path.isfile(path):
            return False
        return os.access(path, os.X_OK)
    except OSError:
        return False


def import_envvars(clear_existing_environment=True):
    """Load environment variables from the container environment directory.

    Every entry of /var/lib/container/container_environment is read as one
    variable: the entry name is the variable name and the file content is
    the value.

    Args:
        clear_existing_environment: when true, os.environ is emptied before
            the loaded variables are applied; otherwise the loaded variables
            are merged over the existing environment.
    """
    env_dir = "/var/lib/container/container_environment"
    new_env = {}
    for envfile in listdir(env_dir):
        name = os.path.basename(envfile)
        with open(os.path.join(env_dir, envfile), "r") as f:
            value = f.read()
        # Text files often end with a trailing newline, which we
        # don't want to include in the env variable value. See
        # https://github.com/phusion/baseimage-docker/pull/49
        # Exactly one trailing newline is dropped. A plain string check is
        # used instead of the previous non-raw '\n\Z' regex literal, which
        # contained an invalid escape sequence.
        if value.endswith("\n"):
            value = value[:-1]
        new_env[name] = value
    # All files are read before the environment is touched, so a read
    # failure leaves os.environ unmodified.
    if clear_existing_environment:
        os.environ.clear()
    os.environ.update(new_env)


def export_envvars(to_dir=True):
    """Dump the current environment for other processes to consume.

    Always writes a shell script of ``export NAME=VALUE`` lines
    (container_environment.sh) and a JSON object (container_environment.json)
    under /var/lib/container. When *to_dir* is true, each variable is also
    written to its own file under /var/lib/container/container_environment/.
    """
    export_lines = []
    for name, value in os.environ.items():
        if to_dir:
            with open("/var/lib/container/container_environment/" + name, "w") as var_file:
                var_file.write(value)
        export_lines.append("export " + shquote(name) + "=" + shquote(value) + "\n")
    with open("/var/lib/container/container_environment.sh", "w") as script:
        script.write("".join(export_lines))
    with open("/var/lib/container/container_environment.json", "w") as dump:
        dump.write(json.dumps(dict(os.environ)))

# Finds the first character that is not safe to leave unquoted in a shell word.
_find_unsafe = re.compile(r'[^\w@%+=:,./-]').search


def shquote(s):
    """Return *s* escaped so a POSIX shell reads it back as one literal word.

    Empty (or otherwise falsy) input becomes ``''``; strings made up only of
    safe characters are returned unchanged; everything else is wrapped in
    single quotes.
    """
    if not s:
        return "''"
    needs_quoting = _find_unsafe(s) is not None
    if not needs_quoting:
        return s

    # A single quote cannot appear inside single quotes, so each one closes
    # the quoting, is emitted inside double quotes, and reopens the quoting:
    # $'b  ->  '$'"'"'b'
    escaped = s.replace("'", "'\"'\"'")
    return "'" + escaped + "'"


def waitpid_reap_other_children(pid):
    """Wait for child *pid* to exit, reaping any other children meanwhile.

    Args:
        pid: process ID of the child to wait for.

    Returns:
        The wait status of *pid* as reported by os.waitpid, or None when
        *pid* is not (or is no longer) a waitable child of this process.

    Raises:
        OSError: for waitpid failures other than ECHILD/ESRCH on the
            initial probe, and for any failure while waiting on children.
    """
    try:
        # Non-blocking probe: tells us whether pid is a child at all and
        # whether it has already exited.
        this_pid, status = os.waitpid(pid, os.WNOHANG)
    except OSError as e:
        if e.errno == errno.ECHILD or e.errno == errno.ESRCH:
            return None
        else:
            raise
    # If the probe already reaped pid, this_pid == pid and we must return
    # right away: pid can never be reported by waitpid a second time, so
    # waiting for it again would fail with ECHILD or block on unrelated
    # children. Otherwise (this_pid == 0) keep reaping whatever child exits
    # until the one we are interested in shows up.
    while this_pid != pid:
        this_pid, status = os.waitpid(-1, 0)
    return status


def stop_child_process(name, pid, signo=signal.SIGTERM, time_limit=KILL_PROCESS_TIMEOUT):
    """Signal a child process to stop, wait for it, and SIGKILL it on timeout.

    Args:
        name: human-readable name of the process, used only in log messages.
        pid: process ID of the child to stop.
        signo: signal sent first (SIGTERM by default).
        time_limit: seconds to wait for the child to exit before it is
            killed with SIGKILL.

    The timeout is implemented with signal.alarm and therefore depends on a
    SIGALRM handler that raises AlarmException. All OSError failures from
    kill/waitpid are ignored: stopping the child is best effort.
    """
    info("Shutting down %s (PID %d)..." % (name, pid))
    try:
        os.kill(pid, signo)
    except OSError:
        # Signalling failed (e.g. the process is already gone); carry on
        # and let the wait below sort it out.
        pass
    # Arm the timeout; the pending alarm interrupts the wait below by
    # raising AlarmException from the signal handler.
    signal.alarm(time_limit)
    try:
        try:
            waitpid_reap_other_children(pid)
        except OSError:
            pass
    except AlarmException:
        warn("%s (PID %d) did not shut down in time. Forcing it to exit." % (name, pid))
        try:
            os.kill(pid, signal.SIGKILL)
        except OSError:
            pass
        try:
            # Reap the killed child; no alarm is pending any more, so this
            # wait is not time-limited.
            waitpid_reap_other_children(pid)
        except OSError:
            pass
    finally:
        # Always cancel a still-pending alarm so it cannot fire later in
        # unrelated code.
        signal.alarm(0)


def run_command_killable(*argv):
    """Run a command to completion, stopping it if we get interrupted.

    *argv* is the full argument vector; argv[0] is looked up on PATH. If any
    exception (including KeyboardInterrupt) arrives while waiting, the child
    is stopped and the exception re-raised. A non-zero or unknown wait
    status is logged and terminates this program with exit code 1.
    """
    program = argv[0]
    child_pid = os.spawnvp(os.P_NOWAIT, program, argv)
    try:
        wait_status = waitpid_reap_other_children(child_pid)
    except BaseException:
        warn("An error occurred. Aborting.")
        stop_child_process(program, child_pid)
        raise
    if wait_status == 0:
        return
    if wait_status is None:
        error("%s exited with unknown status\n" % program)
    else:
        error("%s failed with status %d\n" % (program, os.WEXITSTATUS(wait_status)))
    sys.exit(1)


def run_command_killable_and_import_envvars(*argv):
    """Run a command, then reload and re-dump the environment.

    The environment is replaced by the contents of the environment directory
    (which the command may have modified) and the shell/JSON dumps are
    regenerated without rewriting the per-variable files.
    """
    run_command_killable(*argv)
    import_envvars(clear_existing_environment=True)
    export_envvars(to_dir=False)


def kill_all_processes(time_limit):
    """Terminate every process we may signal and wait for our children.

    Sends SIGTERM to all processes (kill -1), then reaps children until none
    remain. If that takes longer than *time_limit* seconds, SIGKILL is sent
    to all processes instead. Depends on a SIGALRM handler that raises
    AlarmException.
    """
    info("Killing all processes...")
    try:
        os.kill(-1, signal.SIGTERM)
    except OSError:
        pass
    signal.alarm(time_limit)
    try:
        # Keep reaping until waitpid reports there are no children left.
        while True:
            try:
                os.waitpid(-1, 0)
            except OSError as wait_error:
                if wait_error.errno != errno.ECHILD:
                    raise
                break
    except AlarmException:
        warn("Not all processes have exited in time. Forcing them to exit.")
        try:
            os.kill(-1, signal.SIGKILL)
        except OSError:
            pass
    finally:
        # Cancel the alarm if it has not fired yet.
        signal.alarm(0)


def run_startup_files(my_initd_path, level=0):
    """Run the startup scripts found in *my_initd_path*, in sorted order.

    Executable files are run one by one, re-importing the environment after
    each. At the top level (level 0) a non-empty subdirectory whose name is
    contained in this machine's node name is processed the same way, one
    level deep. After the top-level pass, /etc/rc.local is run if it is
    executable.
    """
    for entry in listdir(my_initd_path):
        candidate = os.path.join(my_initd_path, entry)
        if is_exe(candidate):
            info("Running %s..." % candidate)
            run_command_killable_and_import_envvars(candidate)
            continue
        # Host-specific scripts: only descend into directories that match
        # (a substring of) the node name, and never recurse further.
        if listdir(candidate) and entry in os.uname()[1] and level == 0:
            run_startup_files(candidate, level + 1)
    if level != 0:
        return
    # Run /etc/rc.local.
    if is_exe("/etc/rc.local"):
        info("Running /etc/rc.local...")
        run_command_killable_and_import_envvars("/etc/rc.local")


def start_supervisord():
    """Spawn /usr/bin/supervisord without waiting for it; return its PID."""
    supervisord = "/usr/bin/supervisord"
    info("Booting supervisord daemon...")
    child_pid = os.spawnl(os.P_NOWAIT, supervisord, supervisord)
    info("Supervisord started as PID %d" % child_pid)
    return child_pid


def wait_for_supervisord_or_interrupt(pid):
    """Wait for supervisord (*pid*) to exit or for a KeyboardInterrupt.

    Returns:
        (True, wait_status) if supervisord exited on its own, or
        (False, None) if the wait was cut short by KeyboardInterrupt.
    """
    try:
        return True, waitpid_reap_other_children(pid)
    except KeyboardInterrupt:
        return False, None


def shutdown_supervisord_services():
    """Ask supervisord to stop all of its services via supervisorctl."""
    debug("Begin shutting down supervisord services...")
    stop_all_command = "/usr/bin/supervisorctl stop all"
    os.system(stop_all_command)


def install_insecure_key():
    """Run the helper script that installs the insecure SSH key for root."""
    info("Installing insecure SSH key for user root")
    helper = "/usr/local/sbin/enable_insecure_key"
    run_command_killable(helper)


def main(args):
    """Boot the container: environment, supervisord, startup files, main command.

    Args:
        args: the parsed command-line namespace; reads skip_supervisord,
            enable_insecure_key, skip_startup_files and main_command.

    This function does not return normally: it always ends in sys.exit()
    (or a propagating exception). The exit status is that of supervisord
    when no main command was given, otherwise that of the main command;
    an unknown wait status is reported as 1.
    """
    # Merge the environment directory into our environment, then dump the
    # result for other processes.
    import_envvars(False)
    export_envvars()

    # supervisor_pid is only bound when supervisord is actually started;
    # every later use is guarded by the same condition or by main_command
    # being empty.
    if not args.skip_supervisord:
        supervisor_pid = start_supervisord()

    if args.enable_insecure_key:
        install_insecure_key()

    if not args.skip_startup_files:
        run_startup_files("/var/lib/container/my_init.d")

    supervisor_exited = False
    exit_code = None
    try:
        exit_status = None
        if len(args.main_command) == 0:
            # No main command: supervisord itself is the foreground process.
            supervisor_exited, exit_code = wait_for_supervisord_or_interrupt(supervisor_pid)
            if supervisor_exited:
                if exit_code is None:
                    info("Supervisord exited with unknown status")
                    exit_status = 1
                else:
                    exit_status = os.WEXITSTATUS(exit_code)
                    info("Supervisord exited with status %d" % exit_status)
        else:
            info("Running %s..." % " ".join(args.main_command))
            pid = os.spawnvp(os.P_NOWAIT, args.main_command[0], args.main_command)
            try:
                exit_code = waitpid_reap_other_children(pid)
                if exit_code is None:
                    info("%s exited with unknown status." % args.main_command[0])
                    exit_status = 1
                else:
                    exit_status = os.WEXITSTATUS(exit_code)
                    info("%s exited with status %d." % (args.main_command[0], exit_status))
            except KeyboardInterrupt:
                # The interrupt is not re-raised: the command is stopped and
                # exit_status stays None, so sys.exit(None) below exits with
                # a success status.
                stop_child_process(args.main_command[0], pid)
            except BaseException:
                warn("An error occurred. Aborting.")
                stop_child_process(args.main_command[0], pid)
                raise
        sys.exit(exit_status)
    finally:
        # Runs on every way out of the try block, including the SystemExit
        # raised by sys.exit() above.
        if not args.skip_supervisord:
            shutdown_supervisord_services()
            if not supervisor_exited:
                stop_child_process("supervisor daemon", supervisor_pid)

# Parse options.
parser = argparse.ArgumentParser(description='Initialize the system.')
parser.add_argument('main_command', metavar='MAIN_COMMAND', type=str, nargs='*',
                    help='The main command to run. (default: supervisor)')
parser.add_argument('--enable-insecure-key', dest='enable_insecure_key',
                    action='store_true',
                    help='Install the insecure SSH key')
parser.add_argument('--skip-startup-files', dest='skip_startup_files',
                    action='store_true',
                    help='Skip running /etc/my_init.d/* and /etc/rc.local')
parser.add_argument('--skip-supervisor', dest='skip_supervisord',
                    action='store_true',
                    help='Do not run supervisor services')
# Opt-in flag: passing it ENABLES the kill-all step at exit, so the help
# text must describe that (it previously claimed the opposite).
parser.add_argument('--kill-all-on-exit', dest='kill_all_on_exit',
                    action='store_true',
                    help='Kill all processes on the system upon exiting')
# Without --quiet the log level is INFO; with it, only messages at the
# warning/error threshold are printed.
parser.add_argument('--quiet', dest='log_level',
                    action='store_const', const=LOG_LEVEL_WARN, default=LOG_LEVEL_INFO,
                    help='Only print warnings and errors')
args = parser.parse_args()
log_level = args.log_level

# Without supervisord and without a main command there would be nothing to
# run or wait for.
if args.skip_supervisord and len(args.main_command) == 0:
    error("When --skip-supervisor is given, you must also pass a main command.")
    sys.exit(1)

# Run main function.
# SIGTERM/SIGINT are turned into KeyboardInterrupt (further signals are then
# ignored); SIGALRM is turned into AlarmException so timed waits can be
# interrupted.
signal.signal(signal.SIGTERM, lambda signum, frame: ignore_signals_and_raise_keyboard_interrupt('SIGTERM'))
signal.signal(signal.SIGINT, lambda signum, frame: ignore_signals_and_raise_keyboard_interrupt('SIGINT'))
signal.signal(signal.SIGALRM, lambda signum, frame: raise_alarm_exception())
try:
    main(args)
except KeyboardInterrupt:
    warn("Init system aborted.")
    # sys.exit rather than the bare exit() helper: the latter is injected by
    # the site module for interactive use and is not guaranteed to exist.
    sys.exit(2)
finally:
    # Runs on every exit path, including the SystemExit raised by main().
    if args.kill_all_on_exit:
        kill_all_processes(KILL_ALL_PROCESSES_TIMEOUT)
