#!/usr/bin/env python

# Copyright (c) 2014 Anchor Systems

# https://github.com/anchor/elasticsearch-scripts

"""
NAME

    es-shard-count - print the nodes in the Elasticsearch cluster, sorted
    by descending shard count, along with the number of shards (primary
    and replica) each is serving.

SYNOPSIS

    es-shard-count [HOST[:PORT]]
"""

# Connection defaults: main() starts from these values and only replaces
# them with whatever the optional HOST[:PORT] argument supplies.
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 9200

from datetime import datetime
from esadmin import Connection
import logging
import os
import sys

# Module-level logger named after this module; nothing else in this file
# refers to it.
logger = logging.getLogger(__name__)

def usage(argv):
    """Write a short usage message to standard error.

    argv -- the program's argument list.  Only argv[0] is read; its
            basename is shown as the program name.
    """
    message = "Usage:\n\t%s [HOST[:PORT]]\n" % os.path.basename(argv[0])
    # sys.stderr.write() replaces the Python 2-only "print >>stream"
    # statement, which raises TypeError on Python 3.  The extra "\n"
    # reproduces the newline that the print statement appended.
    sys.stderr.write(message + "\n")

def _parse_address(address):
    """Split a HOST[:PORT] string into a (host, port) tuple.

    Without a ":" the whole string is the host and the port is
    DEFAULT_PORT.  Raises ValueError when the text after the first ":"
    is not an integer, which also covers addresses with several ":".
    """
    host, colon, port_text = address.partition(":")
    if not colon:
        return host, DEFAULT_PORT
    return host, int(port_text)


def _count_shards(es_state):
    """Count primary and replica shard copies per node name.

    es_state -- the value returned for '/_cluster/state'; only its
                'nodes' and 'routing_table' -> 'indices' entries are read.

    Returns a (primaries, replicas) pair of dicts keyed by node name.
    Every node listed under 'nodes' appears in both dicts, so a node
    holding no shards is reported with zero counts.
    """
    nodes = es_state['nodes']
    name_by_esid = {}
    primaries = {}
    replicas = {}
    for esid in nodes:
        name = nodes[esid]['name']
        name_by_esid[esid] = name
        primaries[name] = 0
        replicas[name] = 0

    for index in es_state['routing_table']['indices'].values():
        for copies in index['shards'].values():
            for copy in copies:
                # Copies with no node (presumably unassigned) are skipped.
                if copy['node'] is None:
                    continue
                node = name_by_esid[copy['node']]
                if copy['primary']:
                    primaries[node] += 1
                else:
                    replicas[node] += 1

    return primaries, replicas


def main(argv=None):
    """Print each cluster node with its shard counts, busiest node first.

    argv -- argument list (defaults to sys.argv).  argv[1], if present,
            is a HOST[:PORT] address; otherwise DEFAULT_HOST and
            DEFAULT_PORT are used.

    Returns 0 on success, or 2 after printing the usage message when the
    address cannot be parsed.
    """
    if argv is None:
        argv = sys.argv

    host = DEFAULT_HOST
    port = DEFAULT_PORT
    # Read the address from the argv parameter (not sys.argv) so callers
    # that pass their own argument list are honoured.
    if len(argv) > 1:
        try:
            host, port = _parse_address(argv[1])
        except ValueError:
            usage(argv)
            return 2

    with Connection(host, port) as conn:
        # NOTE(review): the result of master() is not used here; the call
        # is kept because its side effects, if any, are not visible from
        # this file.
        conn.master()
        es_state = conn.get('/_cluster/state')

    primaries, replicas = _count_shards(es_state)

    # (load, name) tuples sort by total shard count, then by node name;
    # reverse=True puts the most heavily loaded node first.
    nodes_by_load = sorted(
        ((primaries[name] + replicas[name], name) for name in primaries),
        reverse=True)

    for load, node in nodes_by_load:
        print("%s: %d (%d primary, %d replica)" % (
            node, load, primaries[node], replicas[node]))

    return 0

if __name__ == '__main__':
    # Script entry point: set up INFO-level logging with a
    # "LEVEL:<tab>message" layout, then exit with main()'s return value.
    logging.basicConfig(
        format='%(levelname)s:\t%(message)s',
        level=logging.INFO,
    )
    exit_status = main(sys.argv)
    sys.exit(exit_status)
