#!/usr/bin/env python
# coding=utf-8
import argparse
import gzip
import json
import multiprocessing
import sys
from time import sleep

from pynamodb.connection import Connection
from pynamodb.constants import (ITEMS, LAST_EVALUATED_KEY,
                                PROVISIONED_THROUGHPUT, READ_CAPACITY_UNITS,
                                TOTAL)
from pynamodb.throttle import Throttle


def _build_parser():
    """Create the ``ddb-dumper`` command-line parser with all of its options."""
    cli = argparse.ArgumentParser(
        prog="ddb-dumper",
        description="DynamoDB Dumper: backup tables out of DynamoDB with ease.",
    )

    # Where to connect.
    cli.add_argument('-r', '--region', type=str, default="us-east-1",
                     help="The region to connect to.")
    cli.add_argument('-o', '--host', type=str,
                     help="The host url to connect to (for use with DynamoDB Local).")

    # How much work to do in parallel.
    cli.add_argument('-s', '--total-segments', type=int, default=None,
                     help="The number of segments to scan in parallel (defaults to the number of processors you have).")
    cli.add_argument('-p', '--parallelism', type=int, default=None,
                     help="The number of processes to use (defaults to the number of processors you have).")

    # Output format and read-capacity budget.
    cli.add_argument('-c', '--compress', action='store_true',
                     help="Whether output files should be compressed with gzip (default off).")
    cli.add_argument('--capacity-consumption', type=float, default=0.5,
                     help="The amount (between 0.01 and 1.0) of the total read capacity of the table to consume (default 0.5).")

    # The single required positional argument.
    cli.add_argument('table_name', type=str,
                     help="The name of the table to dump.")
    return cli


parser = _build_parser()


def main(host, region, table_name, total_segments, compress, parallelism, capacity_consumption):
    """Dump a DynamoDB table to per-segment files using a pool of worker processes.

    One ``dump`` task is submitted per scan segment. While the tasks run, this
    function polls a shared queue once per second and prints a progress line.

    :param host: Host URL passed to ``Connection`` (may be ``None``).
    :param region: Region passed to ``Connection``.
    :param table_name: Name of the table to dump.
    :param total_segments: Number of scan segments; a falsy value means "use the
        number of processors".
    :param compress: Forwarded to ``dump``; whether output files are gzipped.
    :param parallelism: Number of worker processes; a falsy value means "use the
        number of processors".
    :param capacity_consumption: Fraction of the table's provisioned read capacity
        to budget for the dump; clamped to the range [0.01, 1.0].
    :raises NameError: If ``describe_table`` returns ``None`` for the table.
    :raises Exception: Whatever exception a ``dump`` worker raised, re-raised here
        after the pool has been terminated.
    """
    cpu_count = multiprocessing.cpu_count()
    if not total_segments:
        total_segments = cpu_count

    # Treat 0 like "unset" (as total_segments does) so the per-process capacity
    # calculation below can never divide by zero.
    if not parallelism:
        parallelism = cpu_count

    # Clamp the requested fraction into [0.01, 1.0].
    capacity_consumption = min(1.0, max(0.01, capacity_consumption))

    connection = Connection(host=host, region=region)
    desc = connection.describe_table(table_name)
    if desc is None:
        raise NameError("Table does not exist.")
    total_items = desc['ItemCount']

    # Split the capacity budget evenly between workers, with a floor of one
    # unit per process.
    total_capacity = desc[PROVISIONED_THROUGHPUT][READ_CAPACITY_UNITS]
    capacity_per_process = max(
        1.0,
        (capacity_consumption * total_capacity) / float(parallelism)
    )

    queue = multiprocessing.Queue()
    pool = multiprocessing.Pool(processes=parallelism,
                                initializer=dump_init,
                                initargs=(queue, capacity_per_process))

    # Keep the AsyncResult handles: apply_async swallows worker exceptions
    # unless the result is inspected, which would leave the loop below waiting
    # forever for a 'complete' message that never arrives.
    results = [
        pool.apply_async(dump, [host, region, table_name, segment, total_segments, compress])
        for segment in range(total_segments)
    ]

    segments_complete = 0
    items_dumped = 0
    while True:
        sleep(1)
        # Drain every progress update the workers have posted so far.
        while not queue.empty():
            update = queue.get()
            if update == 'complete':
                segments_complete += 1
            else:
                items_dumped += update

        print("{}/~{} items dumped - {}/{} segments.".format(
            items_dumped,
            total_items,
            segments_complete,
            total_segments,
        ))

        # ">=" rather than "==" so a non-positive segment count cannot spin forever.
        if segments_complete >= total_segments:
            break

        # Fail fast: if any task has died, stop the pool and surface its error.
        for result in results:
            if result.ready() and not result.successful():
                pool.terminate()
                pool.join()
                result.get()  # re-raises the worker's exception

    pool.close()
    pool.join()

    print("Done.")


def dump_init(_queue, _capacity):
    """Pool initializer: stash the shared queue and capacity on the current process.

    The values are stored as ``queue`` and ``capacity`` attributes of the object
    returned by ``multiprocessing.current_process()``, where ``dump`` reads them.

    :param _queue: Queue the worker will use to report progress.
    :param _capacity: Read-capacity budget assigned to this worker.
    """
    worker = multiprocessing.current_process()
    worker.queue, worker.capacity = _queue, _capacity


def dump(host, region, table_name, segment, total_segments, compress):
    """Scan one segment of a table and write its items to a file, one JSON document per line.

    The output file is named ``<table_name>.<segment>.dump`` with ``.gz``
    appended when ``compress`` is set. After each page the number of items
    written is put on the process's queue; the string ``'complete'`` is put
    once the segment is exhausted.

    The queue and capacity are read from attributes of the current process
    (set by ``dump_init``).

    :param host: Host URL passed to ``Connection``.
    :param region: Region passed to ``Connection``.
    :param table_name: Table to scan.
    :param segment: Index of the segment this call scans.
    :param total_segments: Total number of segments the scan is split into.
    :param compress: When true, write through ``gzip.GzipFile`` instead of ``open``.
    """
    worker = multiprocessing.current_process()
    progress = worker.queue
    read_budget = worker.capacity
    connection = Connection(host=host, region=region)

    # Build "<table>.<segment>.dump", adding a "gz" suffix for compressed output.
    name_parts = [table_name, str(segment), "dump"]
    opener = open
    if compress:
        opener = gzip.GzipFile
        name_parts.append("gz")
    filename = ".".join(name_parts)

    # Consumed capacity is recorded on the throttle, but pacing between pages
    # is still the fixed sleep below.
    throttle = Throttle(read_budget)

    with opener(filename, 'w') as output:
        start_key = None
        while True:
            page = connection.scan(
                table_name=table_name,
                segment=segment,
                limit=100,
                total_segments=total_segments,
                exclusive_start_key=start_key,
                return_consumed_capacity=TOTAL
            )
            consumed = page.get('ConsumedCapacity', {}).get('CapacityUnits', 0)
            throttle.add_record(consumed)

            items = page.get(ITEMS)
            for item in items:
                output.write(json.dumps(item))
                output.write("\n")
            output.flush()

            progress.put(len(items))
            sleep(0.1)  # Replace with throttling

            # A missing/empty last-evaluated key means the segment is exhausted.
            start_key = page.get(LAST_EVALUATED_KEY)
            if not start_key:
                break

    progress.put('complete')


if __name__ == '__main__':
    # Parse the command line and hand every option to main() as a keyword argument.
    options = parser.parse_args(sys.argv[1:])
    main(**vars(options))
