#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import ConfigParser
import errno
import os
import random
import re
import shutil
import string
import sys
import tempfile
import time
import traceback
from os import path
from subprocess import PIPE, Popen, CalledProcessError, check_call, check_output

import jinja2
import yaml

# Version of this utility; check_bin_version refuses to run unless the augploy repo's
# bin/augploy declares exactly this VERSION.
VERSION = '0.2.3'
# Per-user working directory: git mirrors live under repos/, and main() creates the
# per-run temporary directories here.
BASEDIR = path.expanduser('~/.augploy')
# Dotted variable paths that templates must not read (enforced by check_secret_vars_access).
SECRET_VARS = ['mysql.root']


class AugployError(Exception):
    """Error reported to the user as ``augploy <type> error: <message>``.

    main() catches it, prints the message and exits with status 1.
    """
    def __init__(self, type_, msg):
        self.msg = 'augploy %s error: %s' % (type_, msg)
        # Initialize the base Exception too, so `args` and repr() carry the message.
        super(AugployError, self).__init__(self.msg)

    def __str__(self):
        return self.msg


class RelaxedUndefined(jinja2.Undefined):
    """Undefined value whose (non-dunder) attribute lookups yield '' instead of failing.

    Used by check_secret_vars_access so that templates referring to variables that
    are not part of its probe context (e.g. ``<< repo.path >>``) render to empty
    strings rather than raising.
    """
    def __getattr__(self, attribute):
        # Keep dunder lookups failing, as jinja2.Undefined itself does, so protocol
        # probes such as hasattr(value, '__html__') see a missing attribute instead
        # of a bogus '' value.
        if attribute[:2] == '__':
            raise AttributeError(attribute)
        return ''


def log(msg):
    """Write *msg* to stdout followed by a newline.

    Any object is accepted and converted with str() by print (main() passes
    AugployError instances here).
    """
    # With a single argument this line behaves the same as a Python 2 print
    # statement and as the Python 3 print() function.
    print(msg)


def makedirs(dir):
    """Create directory *dir* and any missing parents, unless it already exists.

    Safe when another augploy process creates the same directory concurrently
    (several runs share BASEDIR/repos). Raises OSError if *dir* cannot be created,
    including when a non-directory already occupies that path.
    """
    try:
        os.makedirs(dir)
    except OSError as err:
        # EEXIST is fine only if what exists is a directory; anything else
        # (permissions, a regular file in the way) is a real error.
        if err.errno != errno.EEXIST or not path.isdir(dir):
            raise


def copytree(src, dst, symlinks=False, ignore=None):
    """Recursively copy directory *src* into *dst*, merging into *dst* if it exists.

    A variant of shutil.copytree that tolerates an existing destination: missing
    directories are created and regular files are (over)written with shutil.copy2.

    :param symlinks: when true, symlinks are recreated as symlinks instead of
        copying their targets' contents.
    :param ignore: optional callable ``ignore(dir, names)`` returning the names in
        *dir* to skip (same contract as shutil.copytree's *ignore*).
    :raises shutil.Error: after the whole tree has been walked, if any entry
        failed; its single argument is a list of ``(src, dst, reason)`` tuples.
    """
    names = os.listdir(src)
    ignored_names = ignore(src, names) if ignore is not None else set()

    if not os.path.isdir(dst):
        os.makedirs(dst)

    errors = []
    for name in names:
        if name in ignored_names:
            continue
        src_path = os.path.join(src, name)
        dst_name = os.path.join(dst, name)
        try:
            if symlinks and os.path.islink(src_path):
                os.symlink(os.readlink(src_path), dst_name)
            elif os.path.isdir(src_path):
                copytree(src_path, dst_name, symlinks, ignore)
            else:
                # Raises for unsupported file types (sockets, fifos, ...).
                shutil.copy2(src_path, dst_name)
        # Collect errors (including the nested copytree's) so other entries still get copied.
        except shutil.Error as err:
            nested = err.args[0] if err.args else []
            if isinstance(nested, list):
                errors.extend(nested)
            else:
                # Some shutil.Error subclasses carry a plain message, not an error list.
                errors.append((src_path, dst_name, str(err)))
        except EnvironmentError as why:
            errors.append((src_path, dst_name, str(why)))

    try:
        shutil.copystat(src, dst)
    except OSError as why:
        # WindowsError only exists on Windows; resolve it defensively so a failure
        # elsewhere is recorded instead of raising NameError.
        try:
            windows_error = WindowsError  # noqa: F821
        except NameError:
            windows_error = None
        # Copying file access times may fail on Windows; ignore that case only.
        if windows_error is None or not isinstance(why, windows_error):
            errors.append((src, dst, str(why)))
    if errors:
        raise shutil.Error(errors)


def load_yml_file(yml_path):
    """Parse the YAML document stored at *yml_path* and return the resulting object.

    An empty file yields None.

    NOTE(review): yaml.load is called without an explicit Loader; depending on the
    PyYAML version this may construct arbitrary Python objects from tagged input.
    The files come from the repos being deployed -- consider yaml.safe_load if they
    never rely on python-specific tags.
    """
    yml_file = open(yml_path)
    try:
        return yaml.load(yml_file)
    finally:
        yml_file.close()


def dump_yml_file(obj, yml_path):
    """Serialize *obj* as a YAML document into *yml_path*, replacing its contents.

    The output starts with an explicit '---' document marker and renders
    collections in block style (default_flow_style=False).
    """
    out = open(yml_path, 'wb')
    try:
        out.write('---\n')
        out.write(yaml.dump(obj, default_flow_style=False))
    finally:
        out.close()


def deep_in_place_update(d, s):
    """Recursively merge mapping *s* into mapping *d*, in place, and return *d*.

    A dict value in *s* is merged into the dict at the same key of *d*. A fresh
    dict is created when *d* has no dict there (missing key, None, scalar, list),
    so dicts from *s* are copied, never shared. Every other value from *s*
    overwrites the one in *d* by reference.
    """
    for key, value in s.items():
        if isinstance(value, dict):
            existing = d.get(key)
            if not isinstance(existing, dict):
                # Nothing mergeable here; replace it instead of failing on item assignment.
                existing = {}
            d[key] = deep_in_place_update(existing, value)
        else:
            d[key] = value
    return d


def deep_in_place_merge(obj1, obj2):
    """Deep-merge two mappings so that both end up holding the combined content.

    On conflicting keys the value from *obj1* wins. Both arguments are modified in
    place; nothing is returned.
    """
    # First push obj1 into obj2 (obj1 wins conflicts), then copy the combined obj2
    # back so obj1 gains the keys it was missing.
    for target, source in ((obj2, obj1), (obj1, obj2)):
        deep_in_place_update(target, source)


def get_setdefault(d, key, default=None):
    """Return ``d[key]``, first storing a default there if it is missing or None.

    The stored default is *default*, or a new empty dict when *default* is None.
    """
    current = d.get(key)
    if current is not None:
        return current
    fallback = {} if default is None else default
    d[key] = fallback
    return fallback


def is_scp_style(repo_url):
    """Tell whether *repo_url* is an scp-like git address (``[user@]host:path``).

    Uses git's own rule: the address is scp-like only when a ':' appears before any
    '/'. Local paths that merely contain a colon (e.g. '/srv/app:v2') are therefore
    treated as local. Returns a match object (truthy) or None.
    """
    return re.match(r'[^/:]+:', repo_url)


def parse_repo_url(repo_url):
    """Split an scp-style ``[user@]host:path[#revision]`` url into its parts.

    Returns ``(url, name, revision)``:

    * *url* is everything before the first '#';
    * *name* is the path after the first ':', with trailing slashes and a trailing
      '.git' removed and '/' replaced by '-' (used as the local mirror directory name);
    * *revision* is everything after the first '#' (it may itself contain '#'), or
      'master' when absent or empty.
    """
    url, _, revision = repo_url.partition('#')
    if not revision:
        revision = 'master'
    # Split only on the first ':' so a path that itself contains ':' is kept whole.
    repo_path = url.split(':', 1)[1].rstrip('/')
    # Strip only a real '.git' suffix; names like 'my.gitlab-tools' must survive.
    if repo_path.endswith('.git'):
        repo_path = repo_path[:-len('.git')]
    repo_name = repo_path.replace('/', '-')
    return (url, repo_name, revision)


def prepare_repo(repo_url, tmp_path):
    """Export *repo_url* as a plain directory tree under *tmp_path*; return its path.

    Two forms of *repo_url* are accepted (see is_scp_style):

    * a local directory path: copied to ``tmp_path/<basename>``, keeping symlinks
      and skipping the contents of any ``.git`` directory;
    * an scp-style git url ``[user@]host:path[#revision]``: mirrored into
      ``BASEDIR/repos/<name>`` (cloned on first use, updated on later runs), and the
      requested revision (default master) extracted with ``git archive`` into
      ``tmp_path/<name>``.

    :raises CalledProcessError: if a git or tar command fails.
    """
    log('prepare repo %s ...' % repo_url)

    if not is_scp_style(repo_url):
        # Resolve first so '.', '..' or a trailing slash still give a real directory
        # name, instead of copying into tmp_path itself (or into its parent).
        local_path = path.abspath(repo_url)
        tmp_repo_path = path.join(tmp_path, path.basename(local_path))
        copytree(local_path, tmp_repo_path, symlinks=True,
                 ignore=lambda src, names: names if path.basename(src) == '.git' else [])
        return tmp_repo_path

    repos_path = path.join(BASEDIR, 'repos')
    makedirs(repos_path)

    repo_url, repo_name, repo_revision = parse_repo_url(repo_url)
    repo_path = path.join(repos_path, repo_name)
    tmp_repo_path = path.join(tmp_path, repo_name)

    # Commands get cwd= instead of os.chdir(): changing the process working directory
    # would break relative local paths given for repos prepared afterwards. Argument
    # lists (no shell) keep urls and revisions from being shell-interpreted.
    if not path.isdir(repo_path):
        log('    clone ...')
        check_output(['git', 'clone', '--mirror', repo_url, repo_path])
    else:
        log('    update ...')
        check_output(['git', 'remote', 'update'], cwd=repo_path)

    makedirs(tmp_repo_path)
    log('    checkout %s ...' % repo_revision)
    # Equivalent of `git archive REV | (cd DEST && tar xf -)`, but git's exit status is
    # checked too. In a shell pipeline only tar's status counts, so a bad revision could
    # go unnoticed if tar accepts the empty input.
    archive = Popen(['git', 'archive', repo_revision], cwd=repo_path, stdout=PIPE)
    try:
        check_output(['tar', 'xf', '-'], cwd=tmp_repo_path, stdin=archive.stdout)
    finally:
        # Closing our copy of the pipe lets git see EPIPE if tar stopped reading early.
        archive.stdout.close()
        archive_status = archive.wait()
    if archive_status != 0:
        raise CalledProcessError(archive_status, 'git archive %s' % repo_revision)
    return tmp_repo_path


def check_bin_version(ap_repo_path):
    """Refuse to run when the augploy repo's declared version differs from VERSION.

    Reads the first ``VERSION = '<digits and dots>'`` line of ``bin/augploy`` inside
    *ap_repo_path* and compares it with this module's VERSION.

    :raises AugployError: if that file is missing, declares no version, or declares
        a different version.
    """
    bin_path = path.join(ap_repo_path, 'bin', 'augploy')
    if not path.isfile(bin_path):
        raise AugployError('bin', 'augploy executable(%s) not found in augploy repo' % bin_path)

    with open(bin_path) as bin_file:
        augploy_content = bin_file.read()
    version_matches = re.findall(r"^VERSION\s*=\s*'([\.\d]+)'$", augploy_content, re.MULTILINE)
    if not version_matches:
        raise AugployError('bin', 'cannot find VERSION in %s' % bin_path)

    if version_matches[0] != VERSION:
        raise AugployError('bin', 'this utility is outdated, please run `pip install -U augploy` to upgrade')


def parse_include(config, config_dir):
    """Resolve ``include`` directives anywhere inside *config*, in place.

    A dict holding ``include: [path, ...]`` has each listed file (relative to
    *config_dir*) loaded, its own includes resolved, and deep-merged into it. Keys
    already in the dict win, and earlier includes win over later ones. The
    ``include`` key is then removed. Nested dicts and lists are processed
    recursively; list items are resolved in place, and nothing is spliced into the
    list itself.

    Returns *config* when it is a dict that had an ``include`` key, otherwise None.

    :raises AugployError: if ``include`` is not a list, a listed file does not
        exist, or a file does not contain a mapping.
    """
    if isinstance(config, dict):
        include_configs = []
        for key, value in config.items():
            if key == 'include':
                if not isinstance(value, list):
                    raise AugployError('config', 'include path(%s) should be a list' % value)

                for include_path in value:
                    include_path = path.join(config_dir, include_path)
                    if not path.isfile(include_path):
                        raise AugployError('config', 'include path(%s) not found' % include_path)

                    include_config = load_yml_file(include_path)
                    if include_config is None:
                        # An empty file includes nothing.
                        include_config = {}
                    if not isinstance(include_config, dict):
                        raise AugployError('config', 'include path(%s) should contain a mapping' % include_path)
                    parse_include(include_config, config_dir)
                    include_configs.append(include_config)
            elif isinstance(value, (dict, list)):
                parse_include(value, config_dir)

        if 'include' not in config:
            return None

        for include_config in include_configs:
            deep_in_place_merge(config, include_config)

        # Also drop an empty `include: []`, so it doesn't leak into the final config.
        config.pop('include', None)
        return config
    elif isinstance(config, list):
        for item in config:
            parse_include(item, config_dir)
    return None


def check_secret_vars_access(str_, variable_start_string, variable_end_string,
                             block_start_string='{%', block_end_string='%}',
                             comment_start_string='{#', comment_end_string='#}'):
    """Raise AugployError if template text *str_* outputs any of SECRET_VARS.

    The text is rendered once per secret var, with a context containing only that
    var set to a random marker; other undefined names render as '' (RelaxedUndefined).
    If the marker shows up in the output, the template reads the secret.

    All delimiters must match the ones the real renderer uses. Otherwise a block such
    as ``<% set x = mysql.root %><< x >>`` is not parsed and escapes detection, and
    literal text like ``${#x}`` is misread as an unterminated '{#' comment.

    NOTE(review): this is heuristic -- an expression that transforms the value before
    printing it (e.g. ``<< mysql.root|reverse >>``) is not detected.
    """
    if isinstance(str_, bytes):
        # Byte strings (Python 2 str, PyYAML dumps, file contents) are decoded as UTF-8;
        # text is used as is.
        str_ = str_.decode('utf-8')

    env = jinja2.Environment(
        block_start_string=block_start_string, block_end_string=block_end_string,
        variable_start_string=variable_start_string, variable_end_string=variable_end_string,
        comment_start_string=comment_start_string, comment_end_string=comment_end_string,
        undefined=RelaxedUndefined)
    tpl = env.from_string(str_)

    for var in SECRET_VARS:
        fake_value = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
        # Build {'a': {'b': fake_value}} for var 'a.b'.
        fake_var_dict = fake_value
        for var_key in reversed(var.split('.')):
            fake_var_dict = {var_key: fake_var_dict}
        if fake_value in tpl.render(fake_var_dict):
            raise AugployError('config', 'accessing of secret var(%s) is denied' % var)


def parse_vars(config, repo_path=None):
    """Render the Jinja2 expressions in *config* against its global_vars; return the result.

    If *config* has a 'repo' section, ``path_local`` (= *repo_path*) and ``path``
    (= /srv/<repo name>) are added to it, and it is exposed to templates as ``repo``.
    The whole config is dumped to YAML, checked for access to SECRET_VARS, rendered
    and parsed again, so expressions may appear anywhere in the config.
    """
    global_vars = get_setdefault(config, 'global_vars')
    if 'repo' in config:
        repo = config['repo']
        repo['path_local'] = repo_path
        repo['path'] = '/srv/%s' % repo['name']
        global_vars['repo'] = repo

    # Change the default delimiters '{{', '}}' etc. to avoid conflicts with yaml syntax,
    # so items can be referenced without quotes and referring to dicts/lists is possible.
    # The secret-var check must parse the text with exactly the same delimiters.
    delimiters = {
        'block_start_string': '<%', 'block_end_string': '%>',
        'variable_start_string': '<<', 'variable_end_string': '>>',
        'comment_start_string': '<#', 'comment_end_string': '#>',
    }
    config_str = yaml.dump(config)
    check_secret_vars_access(config_str, **delimiters)
    rendered = jinja2.Environment(**delimiters).from_string(config_str).render(global_vars)
    # NOTE(review): yaml.load without an explicit Loader; the input is the rendered config.
    return yaml.load(rendered)


def parse_config(config_dir, config_path, repo_path=None):
    """Load the config file *config_path* and return it fully resolved.

    Includes are looked up relative to *config_dir*; *repo_path* is passed on to
    parse_vars as the local path of the deployed repo.
    """
    raw_config = load_yml_file(config_path)
    parse_include(raw_config, config_dir)  # resolves includes in place
    return parse_vars(raw_config, repo_path)


def _collect_host_names(hosts):
    """Return the names of all hosts in *hosts*, descending into nested groups, in order."""
    # Same layout parse_inventory_and_vars consumes: an item is a host (has 'name')
    # or a sub-group (has 'group' and 'hosts').
    names = []
    for item in hosts or []:
        if 'name' in item:
            # '%s' keeps names textual (YAML may load e.g. a bare number as an int).
            names.append('%s' % item['name'])
        else:
            names.extend(_collect_host_names(item.get('hosts')))
    return names


def sort_deploys(deploys):
    """Order *deploys* so the repo is deployed before the deploys that need it.

    Returns the deploys without ``deploy_repo: true``, then a synthetic ``repo``
    deploy targeting every host (sorted, de-duplicated, nested groups included)
    of the deploys that have it, then those deploys. Original order is kept
    within each part.

    :raises AugployError: if two deploys share the same type.
    """
    deploy_types = set()
    deploys_dependon_repo = []
    deploys_not_dependon_repo = []
    hosts_to_deploy_repo = set()
    for deploy in deploys:
        deploy_types.add(deploy['type'])
        if deploy.get('deploy_repo') is True:
            deploys_dependon_repo.append(deploy)
            # Walk the hosts structure instead of scraping a YAML dump, which also
            # matched keys like 'db_name:' and flow-style tails like 'h1, port: 22}'.
            hosts_to_deploy_repo.update(_collect_host_names(deploy['hosts']))
        else:
            deploys_not_dependon_repo.append(deploy)

    if len(deploy_types) < len(deploys):
        raise AugployError('config', 'multiple deploys of same type in one single config file is not supported')

    deploy_of_repo = {
        'type': 'repo',
        # Sorted so the generated inventory is identical from run to run.
        'hosts': [{'name': host} for host in sorted(hosts_to_deploy_repo)],
    }
    return deploys_not_dependon_repo + [deploy_of_repo] + deploys_dependon_repo


def gen_common_paly():
    """Return a new play that applies the ``common`` role to all hosts."""
    play = dict(name='setup common environment', hosts='all')
    play['roles'] = ['common']
    return play


def gen_repo_play(config):
    """Build the ansible play that deploys the application repo to the 'repo' group.

    The play has one role per entry of ``config['repo']['engines']``. Its tasks, in
    order: rsync the local export to ``repo.path``, chown it to www-data, render each
    entry of ``repo.templates``, then run each entry of ``repo.build_steps`` inside
    ``repo.path``. Missing or empty engines/templates/build_steps lists are allowed.

    Side effect: each engine's ``vars`` are deep-merged into ``config['global_vars']``
    (created if needed).

    :raises AugployError: if a template reads a secret var.
    """
    repo = config['repo']
    repo_path_local = repo['path_local']
    repo_path = repo['path']
    roles = []
    tasks = []

    for engine in repo.get('engines') or []:
        roles.append({'role': engine['type']})
        if 'vars' in engine:
            deep_in_place_update(get_setdefault(config, 'global_vars'), engine['vars'])

    tasks.append({
        'name': 'repo | rsync',
        'rsync_repo': 'src=%s/ dest=%s' % (repo_path_local, repo_path),
        'register': 'rsync_repo_result'
    })
    tasks.append({
        'name': 'repo | set permission',
        'file': 'path=%s owner=www-data group=www-data recurse=yes' % repo_path
    })

    for template in repo.get('templates') or []:
        src = template['src']
        dest = template['dest']

        # Read with a context manager so the handle is closed promptly.
        with open(path.join(repo_path_local, src)) as template_file:
            template_str = template_file.read()
        check_secret_vars_access(template_str, '{{', '}}')

        tasks.append({
            'name': 'repo | templates | from %s to %s' % (src, dest),
            'template': 'src=%s/%s dest=%s/%s' % (repo_path_local, src, repo_path, dest)
        })

    for step in repo.get('build_steps') or []:
        tasks.append({
            'name': 'repo | build_steps | %s' % step['name'],
            'shell': 'chdir=%s %s' % (repo_path, step['shell'])
        })

    return {
        'name': 'setup repo',
        'hosts': 'repo',
        'roles': roles,
        'tasks': tasks
    }


def parse_inventory_and_vars(group, inventory, group_vars, host_vars, group_chain=None):
    """Register *group*, and recursively its sub-groups, in the ansible inventory.

    *group* is ``{'group': name, 'hosts': [...], 'vars': {...}}`` ('vars' optional).
    Each item of 'hosts' is either a host ``{'name': ..., 'vars': {...}}`` or a
    nested group of the same shape. A group lists either hosts or sub-groups, never both.

    :param inventory: RawConfigParser that receives one section per group: '<name>'
        listing hosts, or '<name>:children' listing sub-groups.
    :param group_vars: updated in place, group name -> that group's vars.
    :param host_vars: updated in place, host name -> vars dict with an extra
        '__groups' list of the groups (outermost first) the host was found in.
    :param group_chain: names of the enclosing groups, outermost first (used by the
        recursion; callers normally omit it).
    :raises AugployError: if a group mixes hosts and sub-groups.
    """
    group_name = group['group']
    hosts = group['hosts']
    # Build a new list: appending to the caller's list would leak this group's name
    # into the chains of its later sibling groups.
    group_chain = (group_chain or []) + [group_name]

    if 'vars' in group:
        # Doesn't support multiple deploys of same type in one single config file.
        group_vars[group_name] = group['vars']

    host_names = []
    sub_group_names = []
    for item in hosts:
        if 'name' in item:
            host_name = item['name']
            host_names.append(host_name)
            hv = get_setdefault(host_vars, host_name)
            # Accumulated across groups/deploys; run_playbook merges group vars in this order.
            hv['__groups'] = hv.get('__groups', []) + group_chain
            if 'vars' in item:
                host_vars[host_name] = deep_in_place_update(hv, item['vars'])
        else:
            sub_group_names.append(item['group'])
            parse_inventory_and_vars(item, inventory, group_vars, host_vars, group_chain)

    if host_names and sub_group_names:
        raise AugployError('config', 'host lists and group lists are not compatible')

    if host_names:
        section, members = group_name, host_names
    else:
        section, members = '%s:children' % group_name, sub_group_names
    inventory.add_section(section)
    for member in members:
        inventory.set(section, member)


def run_playbook(args, ansible_args, tmp_path):
    """Prepare the repos, turn the selected config into ansible inputs, run ansible-playbook.

    1. Export the repo to deploy (unless --no_repo) and the augploy repo into
       *tmp_path*, and check the augploy repo's bin version.
    2. Build ``merged_configs`` from the deployed repo's ``augploy/`` directory and the
       augploy repo's ``configs/``. The latter is copied second, so it overwrites
       same-named files. Then load *args.config_file* from it.
    3. Write the inventory, one host_vars file per host, and the playbook into the
       exported augploy repo.
    4. Run ansible-playbook from there, forwarding *ansible_args* unchanged.

    :raises AugployError: on configuration problems.
    :raises CalledProcessError: if an external command (git, ansible-playbook, ...) fails.
    """
    repo_path = None if args.no_repo else prepare_repo(args.repo, tmp_path)
    ap_repo_path = prepare_repo(args.ap_repo, tmp_path)
    check_bin_version(ap_repo_path)

    # Parse config file.
    config_dir = path.join(ap_repo_path, 'merged_configs')
    if repo_path is not None:
        repo_augploy_path = path.join(repo_path, 'augploy')
        if path.isdir(repo_augploy_path):
            copytree(repo_augploy_path, config_dir, symlinks=True)
    copytree(path.join(ap_repo_path, 'configs'), config_dir, symlinks=True)

    config_path = path.join(config_dir, args.config_file)
    if not path.isfile(config_path):
        raise AugployError('config', 'config file(%s) not found' % args.config_file)
    # Drop only a literal '.yml' suffix: rstrip('.yml') strips any trailing
    # '.', 'y', 'm' or 'l' characters ('deploy.yml' -> 'deplo').
    config_name = path.basename(args.config_file)
    if config_name.endswith('.yml'):
        config_name = config_name[:-len('.yml')]
    config = parse_config(config_dir, config_path, repo_path)

    # Parse deploys to generate inventory, host_vars and playbook.
    deploys = config['deploys']
    sorted_deploys = sort_deploys(deploys) if 'repo' in config else deploys

    inventory = ConfigParser.RawConfigParser(allow_no_value=True)
    group_vars = {}
    host_vars = {}
    plays = [gen_common_paly()]
    for deploy in sorted_deploys:
        deploy_type = deploy['type']
        group = {
            'group': deploy_type,
            'hosts': deploy['hosts']
        }
        if 'vars' in deploy:
            group['vars'] = deploy['vars']
        parse_inventory_and_vars(group, inventory, group_vars, host_vars)

        if deploy_type == 'repo':
            plays.append(gen_repo_play(config))
        else:
            play = load_yml_file(path.join(ap_repo_path, 'plays', '%s.yml' % deploy_type))
            # A play file holds either a single play or a list of plays.
            if isinstance(play, dict):
                plays.append(play)
            else:
                plays.extend(play)

    inventory_path = path.join(ap_repo_path, '%s_hosts' % config_name)
    with open(inventory_path, 'wb') as inventory_file:
        inventory.write(inventory_file)

    # ansible-playbook does not merge nested dicts across global, group and host vars,
    # so merge them recursively here, in that order (later wins), into one host_vars
    # file per host.
    host_vars_dir = path.join(ap_repo_path, 'host_vars')
    global_vars = config.get('global_vars') or {}
    for host, vars_ in host_vars.items():
        merged_vars = deep_in_place_update({}, global_vars)
        for group_name in vars_.pop('__groups'):
            # `or {}`: a group declared with an empty `vars:` loads as None.
            deep_in_place_update(merged_vars, group_vars.get(group_name) or {})
        deep_in_place_update(merged_vars, vars_)

        if merged_vars:
            # git does not track empty directories, so host_vars/ may be missing.
            makedirs(host_vars_dir)
            dump_yml_file(merged_vars, path.join(host_vars_dir, '%s.yml' % host))

    playbook_path = path.join(ap_repo_path, '%s.yml' % config_name)
    dump_yml_file(plays, playbook_path)

    # Actually run the playbook, from the exported augploy repo. The command is an
    # argument list (no shell), so forwarded options whose values contain spaces,
    # e.g. --extra-vars "a=1 b=2", reach ansible-playbook intact.
    ansible_cmd = ['ansible-playbook', '-u', 'root', '-i', inventory_path, playbook_path]
    ansible_cmd.extend(ansible_args)
    log('run: %s' % ' '.join(ansible_cmd))
    check_call(ansible_cmd, cwd=ap_repo_path)


def main(origin_args):
    """Command-line entry point: parse *origin_args* and run the requested deployment.

    Options augploy does not recognize are forwarded to ansible-playbook. Always ends
    with sys.exit(): 0 on success, the failing command's status on
    CalledProcessError, and 1 on any other error. The per-run temp dir is removed
    in every case.
    """
    # Args parsing.
    parser = argparse.ArgumentParser(
        description='augploy - AUGmentum dePLOYment automation tool, powered by ansible',
        epilog='''
                repo url format:
                1. scp style, eg. git@git.augmentum.com.cn:ops/augploy.git#master, '#master' part is optional,
                can be used to specify git revision: branch name or commit id, default is master
                2. local absolute directory path, eg. /home/user/workspace/ops/augploy,
                git revision is not supported in this type.
                ''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args_repo_group = parser.add_mutually_exclusive_group()
    args_repo_group.add_argument(
        '-r', '--repo', default=os.getcwd(),
        help='url of the repo to deploy, default is current working directory for quick test')
    args_repo_group.add_argument(
        '-n', '--no_repo', action='store_true',
        help='to omit the default value of -r(--repo), in case there is no repo need to deploy')
    parser.add_argument(
        '-R', '--ap_repo', default='git@git.augmentum.com.cn:ops/augploy.git',
        help='url of augploy repo in which ansible playbooks is stored, \
              this option is for dev purpose, otherwise you should use the default value')
    parser.add_argument(
        'config_file',
        help='config file path, relative to augploy directory in the repo to deploy, \
              or relative to root directory in augploy repo')
    parser.add_argument('-V', action='version', version='%%(prog)s %s' % VERSION)

    args, ansible_args = parser.parse_known_args(origin_args)

    # Each run gets its own uniquely named temp dir for exported repos, so several
    # config files can be deployed at the same time without conflicts.
    makedirs(BASEDIR)
    tmp_path = tempfile.mkdtemp(prefix='%d-' % time.time(), dir=BASEDIR)
    exit_code = 0

    try:
        run_playbook(args, ansible_args, tmp_path)
    except CalledProcessError as err:
        exit_code = err.returncode
        log(err)
    except AugployError as err:
        exit_code = 1
        log(err)
    except Exception:
        exit_code = 1
        traceback.print_exc()
    finally:
        # Best-effort cleanup; a failure here must not replace the real outcome.
        shutil.rmtree(tmp_path, ignore_errors=True)
    # Exit outside `finally`: calling sys.exit() there would turn an uncaught
    # KeyboardInterrupt (or SystemExit) into exit status 0.
    sys.exit(exit_code)

if __name__ == '__main__':
    # Script entry point: everything after the program name goes to main(), which
    # separates augploy's own options from those forwarded to ansible-playbook.
    main(sys.argv[1:])
