#!/usr/bin/env python
# coding=utf-8
from __future__ import print_function

import argparse
import gzip
import json
import multiprocessing
import sys
from time import sleep

from pynamodb.connection import Connection
from pynamodb.constants import (ITEMS, LAST_EVALUATED_KEY,
                                PROVISIONED_THROUGHPUT, READ_CAPACITY_UNITS,
                                TOTAL)
from pynamodb.throttle import Throttle


# Number of processors on this machine; the default for both the number of
# scan segments and the number of worker processes.
CPU_COUNT = multiprocessing.cpu_count()

parser = argparse.ArgumentParser(
    prog="dynamodb-dumper",
    description="""DynamoDB Dumper: backup tables out of DynamoDB with ease.""",
)

# --- Where to connect ---
parser.add_argument(
    '-r', '--region',
    type=str, default="us-east-1",
    help="The region to connect to.",
)
parser.add_argument(
    '-o', '--host',
    type=str,
    help="The host url to connect to (for use with DynamoDB Local).",
)

# --- What to read, and how the work is split up ---
parser.add_argument(
    '-s', '--total-segments',
    type=int, default=CPU_COUNT,
    help="The number of segments to scan in parallel (defaults to the number of processors you have).",
)
parser.add_argument(
    '-k', '--hash-keys',
    type=str, nargs='+', default=[],
    help="The list of hash keys to dump data for (defaults to all - by setting this, 'Query' will be used as opposed to 'Scan' with parallel segments - thus --total-segments will be overridden).",
)
parser.add_argument(
    '-p', '--parallelism',
    type=int, default=CPU_COUNT,
    help="The number of processes to use (defaults to the number of processors you have).",
)

# --- Output format and read-capacity budget ---
parser.add_argument(
    '-c', '--compress',
    action='store_true',
    help="Whether output files should be compressed with gzip (default off).",
)
parser.add_argument(
    '--capacity-consumption',
    type=float, default=0.5,
    help="The amount (between 0.01 and 1.0) of the total read capacity of the table to consume (default 0.5).",
)

# --- Positional: the table to dump ---
parser.add_argument(
    'table_name',
    type=str,
    help="The name of the table to dump.",
)


def main(host, region, table_name, total_segments, hash_keys, compress, parallelism, capacity_consumption):
    """
    Dump a DynamoDB table to local files using a pool of worker processes.

    One 'dump' task is submitted per hash key when 'hash_keys' is non-empty,
    otherwise one per scan segment. Workers report progress through a shared
    queue, and a status line is printed about once per second until every
    task has reported completion.

    host                  endpoint url handed to Connection (may be None)
    region                region handed to Connection
    table_name            name of the table to dump
    total_segments        number of scan segments (unused when hash_keys is set)
    hash_keys             list of hash keys to dump; empty means scan everything
    compress              passed to the workers; selects gzip output
    parallelism           number of worker processes
    capacity_consumption  fraction of the table's read capacity to use;
                          clamped to the range 0.01 - 1.0

    Exits with status -1 if describe_table returns None. An exception raised
    inside a worker task is re-raised here after the pool is terminated.
    """
    # Clamp the requested fraction into [0.01, 1.0].
    capacity_consumption = min(1.0, max(0.01, capacity_consumption))

    connection = Connection(host=host, region=region)
    desc = connection.describe_table(table_name)
    if desc is None:
        sys.stderr.writelines(["Table does not exist."])
        sys.exit(-1)
    # Only used as the approximate total in the progress line.
    total_items = desc['ItemCount']

    # Split the read-capacity budget evenly across workers, with a floor of
    # one unit per process.
    total_capacity = desc[PROVISIONED_THROUGHPUT][READ_CAPACITY_UNITS]
    capacity_per_process = max(
        1.0,
        (capacity_consumption * total_capacity) / float(parallelism)
    )

    queue = multiprocessing.Queue()
    pool = multiprocessing.Pool(
        processes=parallelism,
        initializer=dump_init,
        initargs=[queue, capacity_per_process, host, region, compress, table_name]
    )

    # Keys or segments?
    if hash_keys:
        task_args = [[key] for key in hash_keys]
        unit = "keys"
    else:
        task_args = [[segment, total_segments] for segment in range(total_segments)]
        unit = "segments"

    num_to_do = len(task_args)
    num_complete = 0
    items_dumped = 0

    try:
        # Keep the AsyncResult handles: apply_async stores a worker's
        # exception on its result and nothing surfaces it unless we ask.
        results = [pool.apply_async(dump, args) for args in task_args]

        while True:
            sleep(1)
            # Drain whatever progress updates arrived during the last second.
            while not queue.empty():
                update = queue.get()
                if update == 'complete':
                    num_complete += 1
                else:
                    items_dumped += update

            print("{}/~{} items dumped - {}/{} {}.".format(
                items_dumped,
                total_items,
                num_complete,
                num_to_do,
                unit
            ))

            if num_complete == num_to_do:
                break

            # A task that died never posts 'complete', which would leave this
            # loop spinning forever. get() returns immediately for a finished
            # task and re-raises the worker's exception if it failed.
            for result in results:
                if result.ready():
                    result.get()
    except BaseException:
        # Includes KeyboardInterrupt: stop the workers before propagating.
        pool.terminate()
        pool.join()
        raise

    pool.close()
    pool.join()

    print("Done.")


# Per-process worker state. Every name starts out as None and is filled in by
# dump_init() when the pool starts a worker; dump() then reads these values.
queue = capacity = host = region = compress = table_name = None


def dump_init(_queue, _capacity, _host, _region, _compress, _table_name):
    """
    Pool initializer: copy the given settings into this module's globals so
    that code running later in the same process can read them.
    """
    global queue, capacity, host, region, compress, table_name
    queue, capacity = _queue, _capacity
    host, region = _host, _region
    compress, table_name = _compress, _table_name


def dump(part, total_segments=None):
    """
    Dump one hash key or one scan segment of the table to a local file.

    'part' may be the hash_key if we are dumping just a few hash_keys - else
    it will be the segment number. 'total_segments' is None in the hash_key
    case and the overall segment count otherwise.

    Items are written one JSON document per line to
    '<table_name>.<part>.dump' (with '.gz' appended when compressing). After
    each batch the batch size is put on the shared queue, and the string
    'complete' is put once the whole part has been written.

    Reads the module-level host, region, table_name, capacity, compress and
    queue values set by dump_init().
    """
    connection = Connection(host=host, region=region)

    filename = ".".join([table_name, str(part), "dump"])
    if compress:
        opener = gzip.GzipFile
        filename += ".gz"
    else:
        opener = open

    dumper = BatchDumper(connection, table_name, capacity, part, total_segments)

    # Both writers are opened in binary mode and fed bytes. On Python 3 a
    # GzipFile opened with 'w' is a binary stream, so writing the str that
    # json.dumps returns raised TypeError whenever compression was enabled.
    with opener(filename, 'wb') as output:
        while dumper.has_items:
            items = dumper.get_items()

            for item in items:
                # json.dumps emits ASCII by default, so this encode is safe
                # on both Python 2 and Python 3.
                output.write(json.dumps(item).encode('utf-8'))
                output.write(b"\n")
            output.flush()

            # Progress update: number of items in this batch.
            queue.put(len(items))

    queue.put('complete')


class BatchDumper(object):
    """
    Fetches a table's items one page at a time, either by scanning a single
    segment or by querying a single hash key.

    'has_items' starts out True and becomes False once a response carries no
    last-evaluated key, i.e. there are no further pages to request.
    """

    def __init__(self, connection, table_name, capacity, part, total_segments):
        """
        connection      object providing scan() and query()
        table_name      name of the table to read
        capacity        capacity value handed to Throttle
        part            segment number (scan) or hash key (query)
        total_segments  segment count for a scan; None selects query mode
        """
        self.connection = connection
        self.table_name = table_name
        self.throttle = Throttle(capacity)
        self.part = part
        self.total_segments = total_segments
        self.has_items = True
        self.last_evaluated_key = None

    def get_items(self):
        """
        Fetch the next page, record its consumed capacity with the throttle,
        advance the paging state and return the page's items as a list.
        """
        data = self.get_data()

        # NOTE(review): only add_record() is called on the throttle here;
        # confirm whether Throttle also needs an explicit throttle() call
        # before it actually slows requests down.
        capacity = data.get('ConsumedCapacity', {}).get('CapacityUnits', 0)
        self.throttle.add_record(capacity)

        self.last_evaluated_key = data.get(LAST_EVALUATED_KEY)
        self.has_items = (self.last_evaluated_key is not None)

        # Fall back to an empty list so callers can always iterate and len()
        # the result, even if the response carries no items entry.
        return data.get(ITEMS) or []

    def get_data(self):
        """
        Issue one scan (segment mode) or query (hash-key mode) request for up
        to 100 items, resuming after the last evaluated key, and return the
        raw response.
        """
        if self.total_segments is not None:
            return self.connection.scan(
                table_name=self.table_name,
                segment=self.part,
                total_segments=self.total_segments,
                exclusive_start_key=self.last_evaluated_key,
                return_consumed_capacity=TOTAL,
                limit=100,
            )
        else:
            return self.connection.query(
                table_name=self.table_name,
                hash_key=self.part,
                exclusive_start_key=self.last_evaluated_key,
                # Ask for consumed capacity here as well, as the scan branch
                # does; without it get_items() always records 0 in
                # hash-key mode.
                return_consumed_capacity=TOTAL,
                limit=100,
            )

if __name__ == '__main__':
    # The parsed option names match main()'s parameter names one-to-one, so
    # the namespace can be splatted straight into the call.
    cli_args = parser.parse_args()
    main(**vars(cli_args))
