#!/usr/bin/env python
# SSH Multiplexor is a tool to run shell commands on one or more Unix hosts
# Currently xssh.py does not support input from command line for host names
# you have to use an infile, the format of the infile is one host per line
# I plan on adding more features to this, feel free to fork and do the same
# License for this file is the same as for the other stuff in my Git repo
# Apache License Version 2.0 January 2004
# Usage Example:
# xssh.py -i <infile> -c "shell-command" 
# xssh.py -i <infile> -c "shell-command" -t <ssh-timeout> -n <number-of-forks> -o <outfile>

import socket
import sys
import getpass
import paramiko
import argparse
import os

# Module-level result lists, filled in while the script runs.
# hosts whose name resolved (appended to by read_file)
dns_ok = []
# hosts whose name did not resolve (appended to by read_file)
dns_not_ok = []
# hosts on which connecting or running the command failed (ssh_to_hosts)
ssh_not_ok = []
# not referenced anywhere else in the visible code
children = []

# cheap sanity check on the raw command line: the shortest accepted form has
# sys.argv holding at least 4 entries (sys.argv[0], the script name, counts as
# one of them); argparse does the real validation of -i and -c
def check_args():
    """Abort with a usage error when too few command-line words were given.

    Reads ``sys.argv`` directly. When it has fewer than 4 entries, an error
    message is written to stderr and the process exits with status 1, so
    that calling shells and scripts can detect the failure. Otherwise the
    function returns ``None``.

    Raises:
        SystemExit: with code 1 when the argument count is insufficient.
    """
    if len(sys.argv) < 4:
        sys.stderr.write("ERROR: insufficient arguments, use -h for help\n")
        # non-zero status: this is a failure, not a successful run
        sys.exit(1)

# parsing the arguments
parser = argparse.ArgumentParser(
    description="SSH Multiplexor", epilog="Beta software, use at your own risk")
# explicit version action; the ArgumentParser(version=...) keyword is
# deprecated in Python 2.7 and no longer exists in Python 3
parser.add_argument('-v', '--version', action='version', version='0.2')
parser.add_argument('-i', action="store",
                    help="filename for reading hosts from", required=True)
parser.add_argument('-n', action="store", type=int,
                    help="number of processes to fork, defaults to 1")
# help text now states the default that is actually applied below (3)
parser.add_argument('-t', action="store", type=int,
                    help="timeout value when attempting to connect to host, defaults to 3")
parser.add_argument('-c', action="store",
                    help="command to run on remote host", required=True)
parser.add_argument('-o', action="store", help="outfile to dump output in")

args = parser.parse_args()

# the hosts file must exist before we bother prompting for a password
infile = args.i
if not os.path.isfile(infile):
    sys.stderr.write("ERROR: file %s does not exist\n" % infile)
    sys.exit(1)

# shell command to run on every host
command = args.c

# ssh connect timeout in seconds; 3 unless -t is provided (a value of 0 is
# treated as "not provided")
timeout = args.t or 3

# number of processes to fork; 1 unless -n is provided (0 means "not provided")
forks = args.n or 1

# file to append output to, or None to write to the terminal
outfile = args.o or None

# purpose is to read the input file line at a time, then check if hostname resolves
# if hostname does not resolve, skip host


def read_file(infile):
    """Read hostnames from *infile* and sort them by whether they resolve.

    The file is expected to hold one hostname per line. Surrounding
    whitespace is stripped and blank lines are skipped. Each name is looked
    up with ``socket.getaddrinfo(name, 22)``: names that resolve are appended
    to the module-level ``dns_ok`` list, the others to ``dns_not_ok``, and an
    error line is written for each failure.

    Error lines are appended to the module-level ``outfile`` when it is set,
    otherwise they go to ``sys.stdout``. The output file is opened once and
    closed before returning; ``sys.stdout`` itself is left untouched.

    Args:
        infile: path of the hosts file.
    """
    # set once for the whole run instead of once per line
    socket.setdefaulttimeout(1)
    log = open(outfile, 'a') if outfile is not None else None
    sink = log if log is not None else sys.stdout
    try:
        with open(infile, 'r') as f:
            for line in f:
                host = line.strip()
                if not host:
                    # a blank line is not a host
                    continue
                try:
                    socket.getaddrinfo(host, 22)
                except (socket.error, UnicodeError):
                    # socket.error covers gaierror (lookup failure);
                    # UnicodeError is raised for names that cannot be encoded
                    dns_not_ok.append(host)
                    sink.write('ERROR: host %s is not resolving\n' % host)
                    sink.flush()
                else:
                    dns_ok.append(host)
    finally:
        if log is not None:
            log.close()

# this routine does the ssh using paramiko


def ssh_to_hosts(hostlist, username, passwd):
    """Run the module-level ``command`` on every host in *hostlist* over SSH.

    For each host a new paramiko ``SSHClient`` is created (unknown host keys
    are accepted automatically), connected with password authentication and
    the module-level ``timeout``, and the command's standard output is
    written line by line. Hosts on which connecting or running the command
    fails are appended to the module-level ``ssh_not_ok`` list and an error
    line is written instead.

    Output is appended to the module-level ``outfile`` when it is set,
    otherwise it goes to ``sys.stdout``. The output file is opened once per
    call and closed on return. Every line is flushed immediately so that
    several processes appending to the same file interleave whole lines.

    Args:
        hostlist: iterable of hostnames to connect to.
        username: login name used on every host.
        passwd: password used on every host.
    """
    log = open(outfile, 'a') if outfile is not None else None
    sink = log if log is not None else sys.stdout
    try:
        for host in hostlist:
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                ssh.connect(host, username=username, password=passwd,
                            timeout=timeout)
                _stdin, remote_out, _stderr = ssh.exec_command(command)
                for line in remote_out:
                    sink.write(line.strip('\n') + '\n')
                    sink.flush()
            except Exception:
                # best effort: one bad host must not stop the remaining
                # hosts; Ctrl-C and SystemExit still propagate
                ssh_not_ok.append(host)
                sink.write('ERROR:ssh failed on %s\n' % host)
                sink.flush()
            finally:
                # close only after all output has been read, and also when
                # there was no output or the connection failed
                ssh.close()
    finally:
        if log is not None:
            log.close()
		    

# if we are forking more than once, multiprocess is called to run one child
# process per sub-list of hosts; the sub-lists themselves are built by the
# caller with plain list slicing (no numpy involved)
# pids of the forked children, recorded by the parent so it can wait on them
childpid = []


def multiprocess(processestocreate):
    """Fork *processestocreate* children, one per entry of ``splithostlist``.

    Child number ``i`` runs ``ssh_to_hosts`` on ``splithostlist[i]`` with the
    module-level ``username`` and ``passwd`` and then exits with status 0.
    The parent records each child's pid in the module-level ``childpid``
    list and returns once all children have been started; it does not wait
    for them here.

    Args:
        processestocreate: number of children to fork; must not exceed
            ``len(splithostlist)``.
    """
    for counter in range(processestocreate):
        # Flush pending output before forking: anything still buffered would
        # be copied into the child and written a second time when it exits.
        sys.stdout.flush()
        sys.stderr.flush()
        pid = os.fork()
        if pid:
            # in parent
            childpid.append(pid)
        else:
            # in child
            ssh_to_hosts(splithostlist[counter], username, passwd)
            sys.exit(0)

#**************************************************************
# Workflow commences here
#**************************************************************

check_args()
passwd = getpass.getpass('Enter password-> ')
username = getpass.getuser()
read_file(infile)

# take the size of the dns_ok list, let's say that is 5, and if we are asked
# to fork 2 times, then create sub-lists of at most 3 hosts each
# sizeofeachlist is the size of a sub-list, rounded UP so that we never end up
#   with more sub-lists than requested forks, and never 0 (a zero slice step
#   raises ValueError when there are fewer hosts than forks)
# splithostlist is a list that contains sub-lists, as in [[h1,h2,h3],[h4,h5]]
# processestocreate is computed by getting the length of splithostlist

if forks > 1:
    hostlist = dns_ok
    sizeofeachlist = max(1, -(-len(hostlist) // forks))
    splithostlist = [hostlist[x:x + sizeofeachlist]
                     for x in range(0, len(hostlist), sizeofeachlist)]
    processestocreate = len(splithostlist)
    multiprocess(processestocreate)
else:
    ssh_to_hosts(dns_ok, username, passwd)

# wait for EVERY child before exiting, and remember whether any of them
# terminated abnormally or with a non-zero status
failed_children = 0
for pid in childpid:
    pid, status = os.waitpid(pid, 0)
    if not (os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0):
        failed_children += 1

if failed_children:
    sys.exit(1)
