#!/usr/bin/env python
#Copyright (c) 2012 Yahoo! Inc. All rights reserved.
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. See accompanying LICENSE file.
"""
 Python based ssh multiplexer optimized for map operations
 command line utility.
"""
from __future__ import print_function
import os
import sys
import optparse
import getpass
import sshmap
import hostlists


def _shebang_command(script_path):
    """Return the text after ``#!`` on the first line of *script_path*.

    Returns None when the first line does not start with ``#!``.  The file
    is closed before returning.
    """
    with open(script_path) as script:
        first_line = script.readline().strip()
    if first_line.startswith('#!'):
        return first_line[2:]
    return None


def _terminal_columns(default=80):
    """Return the terminal width reported by ``stty size``.

    Falls back to *default* when the width cannot be determined, e.g. when
    the process is not attached to a terminal and ``stty`` prints nothing.
    """
    try:
        with os.popen('stty size', 'r') as stty:
            # Output is "<rows> <columns>" when attached to a terminal.
            size = stty.read().split()
        columns = int(size[1])
    except (IndexError, ValueError, OSError):
        return default
    return columns if columns > 0 else default


if __name__ == "__main__":
    parser = optparse.OptionParser(
        usage="%prog [options] host_range [command]")
    parser.add_option("--output_json", dest="output_json",
                      default=False, action="store_true",
                      help="Output in JSON format")
    parser.add_option("--output_base64", dest="output_base64", default=False,
                      action="store_true", help="Output in base64 format")
    parser.add_option("--summarize_failed", dest="summarize_failed",
                      default=False, action="store_true",
                      help="Print a list of hosts that failed at the end")
    parser.add_option("--aggregate_output", "--collapse",
                      dest="aggregate_output", default=False,
                      action="store_true", help="Aggregate identical list")
    parser.add_option("--only_output", dest="only_output", default=False,
                      action="store_true",
                      help="Only print lines for hosts that return output")
    parser.add_option("--jobs", "-j", dest="jobs", default=65, type="int",
                      help="Number of parallel commands to execute")
    parser.add_option("--timeout", dest="timeout", type="int", default=0,
                      help="Timeout, or 0 for no timeout")
    parser.add_option("--sort", dest="sort", default=False,
                      action="store_true",
                      help="Print output sorted in the order listed")
    parser.add_option("--shuffle", dest="shuffle", default=False,
                      action="store_true",
                      help="Shuffle (randomize) the order of hosts")
    parser.add_option("--print_rc", dest="print_rc", default=False,
                      action="store_true",
                      help="Print the return code value")
    parser.add_option("--match", dest="match", default=None,
                      help="Only show host output if the string is found "
                           "in the output")
    parser.add_option("--runscript", "--run_script", dest="runscript",
                      default=None,
                      help="Run a script on all hosts.  The command value is "
                           "the shell to pass the script to on the remote "
                           "host.")
    parser.add_option("--callback_script", dest="callback_script",
                      default=None,
                      help="Script to process the output of each host.  The "
                           "hostname will be passed as the first argument "
                           "and the stdin/stderr from the host will be "
                           "passed as stdin/stderr of the script")
    parser.add_option("--no_status", dest="show_status", default=True,
                      action="store_false",
                      help="Don't show a status count as the command "
                           "progresses")
    parser.add_option("--sudo", dest="sudo", default=False,
                      action="store_true",
                      help="Use sudo to run the command as root")
    parser.add_option("--password", dest="password", default=False,
                      action="store_true", help="Prompt for a password")

    (options, args) = parser.parse_args()

    # The first positional argument is the host range; without it there is
    # nothing to run against, so fail with a usage message, not a traceback.
    if not args:
        parser.error("a host range is required")

    # When only a host range was given along with --runscript, take the
    # remote command from the script's "#!" line if it has one.
    if len(args) == 1 and options.runscript:
        shebang = _shebang_command(options.runscript)
        if shebang is not None:
            args.append(shebang)

    # Remember whether --password was given before the attribute is reused
    # below to hold the actual password string.
    prompt_for_password = options.password

    # Default option values
    options.password = None
    options.username = getpass.getuser()
    options.output = True

    # Create our callback pipeline based on the options passed
    callback = [sshmap.callback.summarize_failures]
    if options.match:
        callback.append(sshmap.callback.filter_match)
    if options.output_base64:
        callback.append(sshmap.callback.filter_base64)
    if options.output_json:
        callback.append(sshmap.callback.filter_json)
    if options.callback_script:
        callback.append(sshmap.callback.exec_command)
    elif options.aggregate_output:
        callback.append(sshmap.callback.aggregate_output)
    else:
        callback.append(sshmap.callback.output_prefix_host)
    if options.show_status:
        callback.append(sshmap.callback.status_count)

    # Get the password if the options passed indicate it might be needed.
    # The SSHMAP_SUDO_PASSWORD environment variable, when set, is used
    # instead of prompting (for both --sudo and --password).
    # We really need to add a password file option.
    if options.sudo or prompt_for_password:
        if options.sudo:
            prompt = 'Enter sudo password for user '
        else:
            prompt = 'Enter password for user '
        options.password = os.environ.get('SSHMAP_SUDO_PASSWORD')
        if options.password is None:
            options.password = getpass.getpass(
                prompt + getpass.getuser() + ': ')

    command = ' '.join(args[1:])
    host_range = args[0]
    results = sshmap.run(
        host_range, command, username=options.username,
        password=options.password, sudo=options.sudo,
        timeout=options.timeout, script=options.runscript, jobs=options.jobs,
        sort=options.sort,
        shuffle=options.shuffle, output_callback=callback, parms=vars(options)
    )

    if options.aggregate_output:
        aggregate_hosts = results.setting('aggregate_hosts')
        collapsed_output = results.setting('collapsed_output')
        if aggregate_hosts and collapsed_output:
            # Separator rules are drawn two characters narrower than the
            # terminal.
            rule_width = _terminal_columns() - 2
            for md5 in aggregate_hosts.keys():
                print("=" * rule_width)
                print(','.join(hostlists.compress(aggregate_hosts[md5])))
                print("-" * rule_width)
                stdout, stderr = collapsed_output[md5]
                if len(stdout):
                    print(''.join(stdout))
                if len(stderr):
                    print('\n'.join(stderr), file=sys.stderr)

    if options.summarize_failed and 'failures' in results.parm.keys() and \
            len(results.parm['failures']):
        print(
            'SSH Failed to: %s' % hostlists.compress(results.parm['failures'])
        )
