#!/usr/bin/env python
# coding=utf-8
from __future__ import print_function

import argparse
import gzip
import json
import multiprocessing
import os
import sys
from time import sleep

from pynamodb.connection import Connection
from pynamodb.constants import (BATCH_WRITE_PAGE_LIMIT, PUT_REQUEST,
                                UNPROCESSED_KEYS)


# Default worker count: one process per available processor.
CPU_COUNT = multiprocessing.cpu_count()


# Command-line interface. Every option's destination name matches a parameter
# of main(), which receives the parsed values as keyword arguments.
parser = argparse.ArgumentParser(
    prog="dynamodb-loader",
    description="""DynamoDB Loader: restore tables dumped by dynamodb-dumper with ease.""",
)

# Connection options.
parser.add_argument('-r', '--region', type=str, default="us-east-1",
                    help="The region to connect to.")
parser.add_argument('-o', '--host', type=str,
                    help="The host url to connect to (for use with DynamoDB Local).")

# Input files; the flag itself is mandatory but may be followed by any number
# of names (an empty list is rejected later by check_files()).
parser.add_argument('-l', '--load-files', type=str, nargs='*', required=True,
                    help="The list of filenames of dump files created by ddb-dumper that you wish to load.")

# Size of the worker pool.
parser.add_argument('-p', '--parallelism', type=int, default=CPU_COUNT,
                    help="The number of processes to use (defaults to the number of processors you have).")

# Destination table (positional).
parser.add_argument('table_name', type=str,
                    help="The name of the table to load into.")


def main(host, region, table_name, parallelism, load_files):
    """Load dump files into an existing table using a pool of worker processes.

    Each file is handed to ``load_file`` in a worker process. Workers report
    progress through a shared queue: an integer for items handed to the batch
    writer and the string ``'complete'`` once a file is finished. Progress is
    printed once per second until all files are done.

    Args:
        host: Host url passed to ``Connection`` (may be ``None``).
        region: Region passed to ``Connection``.
        table_name: Name of the table to load into; it must already exist.
        parallelism: Number of worker processes in the pool.
        load_files: List of dump file names to load.

    Raises:
        SystemExit: with status -1 when ``describe_table`` returns ``None``,
            with status 1 when ``check_files`` rejects the file list or when
            at least one file failed to load.
    """
    connection = Connection(host=host, region=region)

    if connection.describe_table(table_name) is None:
        sys.stderr.writelines([
            "Table does not exist - please create first (you can use the official AWS CLI tool to do this automatically).\n"
        ])
        sys.exit(-1)

    check_files(load_files)

    queue = multiprocessing.Queue()
    pool = multiprocessing.Pool(
        processes=parallelism,
        initializer=load_init,
        initargs=[queue, host, region, table_name]
    )
    # Keep the AsyncResult handles: they are the only place where an exception
    # raised inside a worker becomes visible to this process.
    results = [pool.apply_async(load_file, [filename]) for filename in load_files]

    total_files = len(load_files)
    files_completed = 0
    items_loaded = 0
    while True:
        sleep(1)
        while not queue.empty():
            update = queue.get()
            if update == 'complete':
                files_completed += 1
            else:
                items_loaded += update

        print("{} items loaded - {}/{} files complete.".format(
            items_loaded,
            files_completed,
            total_files
        ))
        sys.stdout.flush()

        if files_completed == total_files:
            break

        # A task that raised never sends 'complete', so the counter above can
        # never reach the total. Once every task has finished and at least one
        # of them failed there is nothing left to wait for.
        if all(result.ready() for result in results) and \
                not all(result.successful() for result in results):
            break

    pool.close()
    pool.join()

    # Surface worker failures instead of silently discarding them.
    failures = 0
    for filename, result in zip(load_files, results):
        try:
            result.get()
        except Exception as error:  # top-level boundary: report, then exit
            failures += 1
            sys.stderr.write("Failed to load {}: {!r}\n".format(filename, error))

    if failures:
        sys.exit(1)

    print("Done.")


def check_files(load_files):
    """Validate the list of dump files before any loading starts.

    Writes a message to stderr and raises ``SystemExit(1)`` when the list is
    empty or when the first path that does not exist is encountered. Returns
    ``None`` when every path exists.
    """
    if len(load_files) == 0:
        sys.stderr.writelines(["No files to load from.\n"])
        raise SystemExit(1)

    for path in load_files:
        if os.path.exists(path):
            continue
        # Stop at the first missing path.
        message = "File {} does not exist.\n".format(path)
        sys.stderr.writelines([message])
        raise SystemExit(1)


# Per-process worker state. These stay None in the parent process; in each
# pool worker they are filled in by load_init() and then read by load_file().
host = queue = region = table_name = None


def load_init(_queue, _host, _region, _table_name):
    """Pool initializer: store the shared settings in this process's globals.

    The values are kept in the module-level ``queue``, ``host``, ``region``
    and ``table_name`` names so that tasks running in the same process can
    read them without receiving them as arguments.
    """
    global queue, host, region, table_name
    queue, host, region, table_name = _queue, _host, _region, _table_name


def load_file(filename):
    """Worker task: write every item of one dump file to the target table.

    The file is read line by line, each non-blank line being one JSON
    document. Files whose name ends in ``.gz`` are opened with
    ``gzip.GzipFile``, all others with ``open``. Connection settings and the
    progress queue come from the module globals set by ``load_init``.

    For every item handed to the batch writer ``1`` is put on the progress
    queue; ``'complete'`` is put once the whole file has been processed. If
    an exception is raised, ``'complete'`` is not sent.

    Args:
        filename: Path of the dump file to load.
    """
    connection = Connection(host=host, region=region)
    opener = gzip.GzipFile if filename.endswith('.gz') else open

    with opener(filename, 'r') as infile:
        with BatchPutManager(connection, table_name) as batch:
            for line in infile:
                # Blank lines carry no item and would make json.loads raise,
                # aborting the remainder of the file.
                if not line.strip():
                    continue
                batch.put(json.loads(line))
                queue.put(1)

    queue.put('complete')


class BatchPutManager(object):
    """Buffer items for one table and write them in batch-sized pages.

    Intended to be used as a context manager: call :meth:`put` for each item;
    a page is written whenever ``max_operations`` items are buffered, and
    whatever is still buffered is written when the ``with`` block exits.
    """

    # Keys of the BatchWriteItem response as defined by the DynamoDB API.
    # Writes the service did not apply are returned under "UnprocessedItems"
    # ("UnprocessedKeys" is the BatchGetItem equivalent), one entry per write
    # shaped like {"PutRequest": {"Item": {...}}}.
    _UNPROCESSED_ITEMS = 'UnprocessedItems'
    _ITEM = 'Item'

    # Bounds (seconds) of the exponential backoff between retries.
    _INITIAL_BACKOFF = 0.05
    _MAX_BACKOFF = 5.0

    def __init__(self, connection, table_name):
        """Create an empty buffer.

        Args:
            connection: Object providing ``batch_write_item(table_name=...,
                put_items=...)``.
            table_name: Name of the table all items are written to.
        """
        self.connection = connection
        self.table_name = table_name
        self.max_operations = BATCH_WRITE_PAGE_LIMIT
        self.items = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Flush the final partial page. This also runs when the body raised,
        # so items buffered before the error are still written.
        self.commit()
        # Never suppress an exception raised inside the ``with`` block.
        return False

    def put(self, item):
        """Buffer ``item`` and write a page once the buffer is full."""
        self.items.append(item)
        if len(self.items) >= self.max_operations:
            self.commit()

    def commit(self):
        """Write all buffered items, retrying until none are left unprocessed.

        Does nothing when the buffer is empty. The buffer is cleared only
        after every item has been accepted; if ``batch_write_item`` raises,
        the buffered items are kept.
        """
        pending = self.items
        delay = self._INITIAL_BACKOFF

        while pending:
            data = self.connection.batch_write_item(
                table_name=self.table_name,
                put_items=pending,
            )
            leftovers = ((data or {}).get(self._UNPROCESSED_ITEMS) or {}).get(self.table_name)
            if not leftovers:
                break

            # Unwrap the returned write requests back into plain items, the
            # form ``put_items`` expects, before sending them again.
            pending = [
                entry[PUT_REQUEST][self._ITEM]
                for entry in leftovers
                if PUT_REQUEST in entry
            ]

            # Leftovers usually mean the table is being throttled, so back off
            # before retrying instead of immediately re-sending.
            sleep(delay)
            delay = min(delay * 2, self._MAX_BACKOFF)

        self.items = []


if __name__ == '__main__':
    # Every parsed option is named after a parameter of main(), so the
    # namespace can be expanded straight into the call.
    cli_args = parser.parse_args(sys.argv[1:])
    main(**vars(cli_args))
