#!/usr/bin/python
# -*- coding: utf-8 -*-

import cmd
import ansible.runner
from ansible.color import stringc
from ansible.constants import DIST_MODULE_PATH
from ansible import utils
import ansible.utils.module_docs as module_docs
import sys
import os
import pwd


class AnsibleShell(cmd.Cmd):
    """Interactive shell that runs Ansible modules against inventory hosts.

    The "current directory" (``cwd``) is the host or group pattern that
    commands are executed against. Every file found under
    ``DIST_MODULE_PATH`` is exposed as a shell command; any other input is
    run through the ``shell`` module.
    """

    # NOTE: from this line on, inside the class body, ``ansible`` names the
    # Runner instance rather than the imported package. Methods are not
    # affected: they resolve ``ansible`` to the module-level package and must
    # use ``self.ansible`` to reach this Runner.
    ansible = ansible.runner.Runner()
    groups = ansible.inventory.groups_list().keys()
    hosts = ansible.inventory.groups_list()['all']
    modules = []
    # Passed as ``forks=`` to every Runner created by ``default``.
    serial = 2

    # Currently selected host/group pattern; '' means nothing is selected.
    cwd = ''

    def __init__(self):
        """Set up the prompt and register one command per discovered module."""
        # cmd.Cmd is an old-style class on Python 2, so call the base
        # initialiser explicitly instead of through super().
        cmd.Cmd.__init__(self)
        self.intro = 'Welcome to the ansible-shell.\nType help or ? to list commands.\n'
        self.set_prompt()
        self.modules = self.list_modules()
        for module in self.modules:
            # ``module=module`` binds the current name at definition time;
            # without it every lambda would see the last module of the loop.
            setattr(self, 'do_' + module, lambda arg, module=module: self.default(module + ' ' + arg))
            setattr(self, 'help_' + module, lambda module=module: self.helpdefault(module))

    def get_names(self):
        """Return attribute names of the instance, not just of the class.

        The do_*/help_* handlers created in ``__init__`` live on the
        instance, so the instance must be inspected for them to show up in
        help listings and command completion.
        """
        return dir(self)

    def cmdloop(self, intro=None):
        """Run the command loop, restarting it whenever Ctrl-C is pressed.

        ``intro`` is forwarded to ``cmd.Cmd`` semantics: when given it
        replaces ``self.intro`` for the first pass of the loop.
        """
        if intro is not None:
            self.intro = intro
        # Iterate instead of recursing so repeated interrupts do not keep
        # growing the call stack.
        while True:
            try:
                cmd.Cmd.cmdloop(self)
                return
            except KeyboardInterrupt:
                # A single blank intro line moves the prompt past the "^C".
                self.intro = " "

    def set_prompt(self):
        """Rebuild the prompt from the user name, ``cwd`` and ``serial``."""
        self.prompt = stringc(pwd.getpwuid(os.getuid())[0] + '@/' + self.cwd, 'green')
        if self.cwd in self.groups:
            # For a group, also show how many hosts it contains.
            self.prompt += stringc(' (' + str(len(self.ansible.inventory.groups_list()[self.cwd])) + ')', 'red')
        self.prompt += '[s:' + stringc(str(self.serial), 'green') + ']'
        self.prompt += '$ '

    def list_modules(self):
        """Return the basename of every file found under DIST_MODULE_PATH."""
        modules = []
        for _root, _dirs, files in os.walk(DIST_MODULE_PATH):
            modules.extend(files)

        return modules

    def default(self, arg, forceshell=False):
        """Run ``arg`` on the current host pattern and print the results.

        If the first word of ``arg`` is a known module, that module is run
        with the remaining words as its arguments; otherwise (or when
        ``forceshell`` is true) the whole line is run via the ``shell``
        module. Returns False or None, so the command loop keeps running.
        """
        if not self.cwd:
            print("No host found")
            return False

        words = arg.split()
        if not words:
            # Nothing to execute (e.g. "shell" with no command).
            return False

        if forceshell or words[0] not in self.modules:
            module = 'shell'
            module_args = arg
        else:
            module = words[0]
            module_args = ' '.join(words[1:])

        try:
            results = ansible.runner.Runner(
                pattern=self.cwd, forks=self.serial,
                module_name=module, module_args=module_args,
            ).run()
        except Exception as exc:
            # Report the failure instead of dropping it silently; the shell
            # itself stays alive.
            print("%s >>> %s" % (stringc('error', 'red'), exc))
            return False

        if results is None:
            print("No hosts found")
            return False

        for (hostname, result) in results['contacted'].items():
            if 'stderr' in result:
                if not result['stderr']:
                    print("%s\n%s" % (stringc(hostname, 'bright gray'), result.get('stdout', '')))
                else:
                    print("%s >>> %s" % (stringc(hostname, 'red'), result['stderr']))
            elif 'failed' not in result:
                print("%s\n%s" % (stringc(hostname, 'bright gray'), result))
            else:
                print("%s >>> %s" % (stringc(hostname, 'red'), result))

        # Hosts that could not be contacted are returned separately; show
        # them too so they are not silently missing from the output.
        for (hostname, result) in results.get('dark', {}).items():
            print("%s >>> %s" % (stringc(hostname, 'red'), result))

    def emptyline(self):
        """Do nothing on an empty line (cmd.Cmd would repeat the last command)."""
        return

    # No docstring on purpose: cmd.Cmd uses do_* docstrings as help text.
    # Runs the whole argument line through the ``shell`` module.
    def do_shell(self, arg):
        self.default(arg, True)

    def do_serial(self, arg):
        """Set the number of forks"""
        # Store an integer so ``forks=`` receives a number rather than the
        # raw input string, and reject anything that is not a positive count.
        try:
            forks = int(arg)
        except (TypeError, ValueError):
            forks = 0
        if forks < 1:
            print("serial requires a positive integer")
            return
        self.serial = forks
        self.set_prompt()

    def do_cd(self, arg):
        """Change active host/group"""
        if not arg or arg == '/':
            self.cwd = ''
        elif arg == '..':
            try:
                # NOTE(review): takes the second group reported for the
                # current host -- presumably the first one is 'all'; confirm
                # against the inventory API.
                self.cwd = self.ansible.inventory.groups_for_host(self.cwd)[1].name
            except Exception:
                # No parent could be determined: fall back to the root.
                self.cwd = ''
        elif arg in self.hosts or arg in self.groups:
            self.cwd = arg
        else:
            print("incorrect path")

        self.set_prompt()

    def do_list(self, arg):
        """List the hosts in the current group"""
        if arg == 'groups':
            items = self.ansible.inventory.list_groups()
        else:
            items = self.ansible.inventory.list_hosts('all' if self.cwd == '' else self.cwd)
        for item in items:
            print(item)

    # No docstring on purpose (see do_shell). Any true return value makes
    # cmd.Cmd leave the command loop.
    def do_EOF(self, args):
        sys.stdout.write('\n')
        return -1

    def do_exit(self, args):
        """Exits from the console"""
        return -1

    def helpdefault(self, module_name):
        """Print the short description and options of ``module_name``."""
        if module_name not in self.modules:
            return

        in_path = utils.plugins.module_finder.find_plugin(module_name)
        oc, _unused = ansible.utils.module_docs.get_docstring(in_path)
        if not oc:
            # Guard against a module for which no documentation was parsed.
            print("No documentation available for " + module_name)
            return

        print(stringc(oc.get('short_description', ''), 'bright gray'))
        print('Parameters:')
        options = oc.get('options') or {}
        for opt in options:
            description = (options[opt] or {}).get('description', '')
            # The description may be a list of paragraphs or a plain string;
            # indexing a string with [0] would print only its first letter.
            if isinstance(description, (list, tuple)):
                description = description[0] if description else ''
            print('  ' + stringc(opt, 'white') + ' ' + str(description))

    def complete_cd(self, text, line, begidx, endidx):
        """Complete the argument of ``cd`` with host and group names."""
        mline = line.partition(' ')[2]
        # ``text`` may be only the tail of the argument; return completions
        # with the already-typed prefix cut off.
        offs = len(mline) - len(text)

        if self.cwd == '':
            completions = list(self.hosts) + list(self.groups)
        else:
            completions = self.ansible.inventory.list_hosts(self.cwd)

        return [s[offs:] for s in completions if s.startswith(mline)]

    def completedefault(self, text, line, begidx, endidx):
        """Complete ``option=`` names for commands that are Ansible modules."""
        words = line.split()
        if not words or words[0] not in self.modules:
            # cmd.Cmd indexes the returned value, so it must be a list.
            return []

        mline = line.split(' ')[-1]
        offs = len(mline) - len(text)
        completions = self.module_args(words[0])

        return [s[offs:] + '=' for s in completions if s.startswith(mline)]

    def module_args(self, module_name):
        """Return the option names documented for ``module_name``."""
        in_path = utils.plugins.module_finder.find_plugin(module_name)
        oc, _unused = ansible.utils.module_docs.get_docstring(in_path)
        if not oc:
            return []
        return list(oc.get('options') or {})


if __name__ == '__main__':
    # Script entry point: build the shell and hand control to its
    # interactive command loop.
    shell = AnsibleShell()
    shell.cmdloop()
