#!/usr/bin/env python

# s3 bucket to bucket copy with validation (multithreaded)
# we use boto for amazon communication
# we support an s3cmd-compatible config file, and we use s3cmd in the tests, for validation,
# but there is no code dependency on s3cmd
#
# Note: throughout we use the s3 term "etag"
# This is equivalent to md5 EXCEPT for multipart files
# Multipart files have an etag that looks like a 32-character hex md5
# followed by a dash and the number of parts
#
# see KeyInfo.is_multipart()

import argparse
import boto
import logging
import os
import Queue
import random
import re
import sys
import threading
import time
import traceback


# Wall-clock time captured when this module is loaded; progress_update()
# measures elapsed seconds from this point.
start_time = time.time()
# Elapsed-seconds value stored by progress_update() each time it prints a
# status line (used to throttle output); starts at -1.
last_status_time = -1


def obfuscate_key(key):
    """Return key masked as its first four characters, "...", and its
    last four characters.  None is passed through unchanged.
    """
    if key is not None:
        head, tail = key[:4], key[-4:]
        return "...".join((head, tail))
    return None


def strip_etag(etag):
    """Return etag with every double-quote character removed."""
    pieces = etag.split('"')
    return "".join(pieces)


def mk_etag_sig(etag):
    """Return the short signature of an etag: its leading 7 characters."""
    sig_len = 7
    return etag[0:sig_len]


class S3Uri:
    """Parsed form of a "<type>://<bucket>/<object>" string.

    Any component that is absent from the input is left as an empty
    string; a string without "://" is stored entirely as the type.
    """

    def __init__(self, uri_str):
        scheme, sep, remainder = uri_str.partition("://")
        self.type = scheme
        self._bucket = ""
        self._object = ""
        if not sep:
            # no "://" present: nothing after the type to parse
            return
        # everything up to the first "/" is the bucket, the rest the object
        bucket_name, _, object_name = remainder.partition("/")
        self._bucket = bucket_name
        self._object = object_name

    def has_bucket(self):
        """Return True when a non-empty bucket name was parsed."""
        return self._bucket != ""

    def bucket(self):
        """Return the bucket name ("" if none)."""
        return self._bucket

    def object(self):
        """Return the object path within the bucket ("" if none)."""
        return self._object


class S3Config:
    """Credentials read from an s3cmd-compatible config file.

    Each line is split on whitespace.  For a line with more than one word
    whose first word is 'access_key' or 'secret_key', the last word on the
    line is stored in the attribute of the same name.  Attributes that are
    not found in the file stay None.

    filename -- path of the config file to read
    """

    def __init__(self, filename):
        self.access_key = None
        self.secret_key = None
        # "with" guarantees the file is closed even if reading raises
        with open(filename) as f:
            for line in f:
                words = line.strip().split()
                if len(words) > 1:
                    if words[0] == 'access_key':
                        self.access_key = words[-1]
                    elif words[0] == 'secret_key':
                        self.secret_key = words[-1]


class Accumulator:
    """A running total that several threads can add to."""

    def __init__(self):
        self.lock = threading.Lock()
        self.value = 0

    def increment(self, i):
        """Add i to the total while holding the lock."""
        # the context manager releases the lock even if the addition
        # raises, so one bad value cannot block every other thread
        with self.lock:
            self.value += i

    def get_value(self):
        """Return the current total (read without taking the lock)."""
        return self.value


def update_stats(i, accumulators):
    """Add i to every accumulator in accumulators."""
    for counter in accumulators:
        add = counter.increment
        add(i)


def progress_update():
    """Rewrite the one-line transfer status on stdout.

    Does nothing on a dry run, and does nothing when less than a quarter
    of a second has passed since the previous status line.  Reads the
    module-level byte counters and records the time of this print in
    last_status_time.
    """
    global last_status_time
    # never print status updates if this is a dry-run
    if options.dry_run:
        return
    elapsed = time.time() - start_time
    # at most 4 status lines per second, to avoid timing-related mess
    # on the screen
    if elapsed - last_status_time < .25:
        return
    last_status_time = elapsed
    total_bytes = bytes_transferred.get_value()
    target_bytes = bytes_transferred_to_target.get_value()
    megabyte = 1024 * 1024
    fields = (elapsed,
              target_bytes,
              target_bytes / elapsed / megabyte,
              total_bytes,
              total_bytes / elapsed / megabyte)
    status_line = ("%.2f secs %d bytes-to-target (%.2f MB/sec) %d "
                   "total-bytes-with-overhead (%.2f MB/sec)") % fields
    # carriage return moves back to column 0 so the line overwrites itself
    sys.stdout.write("\r")
    sys.stdout.write(status_line)
    sys.stdout.flush()


# This KeyInfo structure is used (rather than a boto.s3.bucket.Key)
# so we can instantiate them from a text file, such as output from s3cmd -ls
class KeyInfo:
    """Plain record for one S3 object: date, size, etag and name.

    The etag is stored with its double quotes stripped.
    """

    def __init__(self, date, size, etag, name):
        self.date, self.size = date, size
        self.etag = strip_etag(etag)
        self.name = name

    @classmethod
    def fromKey(cls, key):
        """Build a KeyInfo from an object exposing last_modified, size,
        etag and name attributes.
        """
        modified = key.last_modified
        nbytes = key.size
        clean_etag = strip_etag(key.etag)
        return KeyInfo(modified, nbytes, clean_etag, key.name)

    def __str__(self):
        fields = (self.date, self.size, self.etag, self.name)
        return "KeyInfo: date=%s size=%d etag=%s name=%s" % fields

    def is_multipart(self):
        """Return True when the etag is known and is not 32 characters
        long (see the etag note at the top of the file).
        """
        if self.etag == "unknown":
            return False
        return len(self.etag) != 32


class PartInfo:
    """One part of a split file: sequence number plus start and end
    offsets.
    """

    def __init__(self, seqno, start, end):
        self.seqno, self.start, self.end = seqno, start, end

    def __str__(self):
        fields = (self.seqno, self.start, self.end)
        return "PartInfo: seqno=%d start=%d end=%d" % fields


class Source:
    """A copy source: bucket, key info, prefix, relative path and an
    optional part_info naming one part of the file.
    """

    def __init__(self, bucket, key_info, prefix, relative_path, part_info):
        self.bucket, self.key_info = bucket, key_info
        self.prefix, self.relative_path = prefix, relative_path
        self.part_info = part_info

    def withPartInfo(self, part_info):
        """Return a new Source equal to this one except for part_info."""
        fields = (self.bucket, self.key_info, self.prefix,
                  self.relative_path, part_info)
        return Source(*fields)

    def __str__(self):
        template = ("Source: bucket=%s key_info=%s prefix=%s relative_path=%s "
                    "part_info=%s")
        return template % (self.bucket, self.key_info, self.prefix,
                           self.relative_path, self.part_info)


class Dest:
    """A copy destination: the bucket, a readable_bucket handle for the
    same destination, and a key prefix.
    """

    def __init__(self, bucket, readable_bucket, prefix):
        self.bucket, self.readable_bucket = bucket, readable_bucket
        self.prefix = prefix

    def __str__(self):
        fields = (self.bucket, self.readable_bucket, self.prefix)
        return "Dest: bucket=%s readable_bucket=%s prefix=%s" % fields


class CopyTask:
    """A unit of copy work pairing a source with a dest."""

    def __init__(self, source, dest):
        self.source, self.dest = source, dest

    def __str__(self):
        pair = (self.source, self.dest)
        return "CopyTask: source=%s dest=%s" % pair


def get_key(bucket, keyname):
    """Safely fetch a key from a bucket.

    Returns the first entry produced by bucket.list(keyname), or None
    when the listing is empty.  A 403 response is logged and also yields
    None; any other S3ResponseError is re-raised.
    """
    try:
        first = next(iter(bucket.list(keyname)), None)
    except boto.exception.S3ResponseError as e:
        if e.status != 403:
            raise
        logger.info("List access denied to file "
                    "s3://%s/%s" % (bucket.name, keyname))
        return None
    return first


def get_keyinfo(bucket, keyname):
    """Safely fetch a KeyInfo from a bucket.

    Returns a KeyInfo built from the first entry of bucket.list(keyname),
    or None when the listing is empty.  On a 403 response the denial is
    logged and a placeholder KeyInfo with etag "unknown" and size 0 is
    returned; any other S3ResponseError is re-raised.
    """
    try:
        first = next(iter(bucket.list(keyname)), None)
    except boto.exception.S3ResponseError as e:
        if e.status != 403:
            raise
        logger.info("List access denied to file s3://%s/%s" % (bucket.name,
                                                               keyname))
        return KeyInfo(None, 0, "unknown", keyname)
    if first is not None:
        return KeyInfo.fromKey(first)
    return None


def should_split(ki):
    """Return a true value when ki is a multipart file or is larger than
    options.large_file_size.
    """
    multipart = ki.is_multipart()
    if multipart:
        return multipart
    return ki.size > options.large_file_size


def grant_acl(key, bucketname):
    """Apply the configured email grant (options.acl_grant) to key.

    key -- key object to receive the grant
    bucketname -- bucket name, used only in log messages

    Does nothing when no grant is configured.  An S3ResponseError from
    the grant call is reported through log_error and not re-raised.
    """
    grant = options.acl_grant
    if grant is None:
        return
    permission, email_address = grant
    target = (permission, email_address, bucketname, key.name)
    logger.debug("Try to grant permission %s to email %s for file s3://%s/%s" %
                 target)
    try:
        key.add_email_grant(permission, email_address)
        logger.info("Granted permission %s to email %s for file s3://%s/%s" %
                    target)
    except boto.exception.S3ResponseError as e:
        log_error("Couldn't grant permission %s to email %s for file "
                  "s3://%s/%s. Error: %s" % (target + (e,)))


def grant_acl_mp(done_mp):
    """Apply the configured email grant to a completed multipart upload.

    done_mp -- the completed-upload object; its bucket, bucket_name,
               key_name and etag attributes are read

    Does nothing when options.acl_grant is None.  The finished file is
    looked up in its bucket with get_key(); if it cannot be found the
    failure is reported via log_error, otherwise the grant is delegated
    to grant_acl().
    """
    if options.acl_grant is None:
        return
    logger.debug("Try to grant permission for completed multipart %s %s %s" %
                 (done_mp.bucket_name, done_mp.key_name, done_mp.etag))
    k = get_key(done_mp.bucket, done_mp.key_name)
    if k is None:
        # take every value for the message from options and done_mp: no
        # key object exists on this path
        permission, email_address = options.acl_grant
        log_error("Couldn't grant permission %s to email %s for file "
                  "s3://%s/%s. Can't list file in bucket" % (
                      permission,
                      email_address,
                      done_mp.bucket_name,
                      done_mp.key_name))
        return
    grant_acl(k, done_mp.bucket_name)


def copy_key(from_bucket_name, from_key_name, to_bucket, to_key_name,
             etag, size, readable_to_bucket, accumulators):
    """Copy a file to a bucket

    from_bucket_name -- string name of the source bucket
    from_key_name -- string name of the source file
    to_bucket -- boto.s3.bucket.Bucket object for the destination
    to_key_name -- string name to be used for the copy
    etag -- expected etag result from the copy -- see note below
    size -- the size of the file to copy
    readable_to_bucket -- boto.s3.bucket.Bucket object that is readable
                          (listable) may be under a different account
    accumulators -- stats accumulators

    Copies the source file to the destination. The source and destination
    buckets may be the same; if so, the caller should ensure that the source
    and destination file names are different.

    Tests the etag from the copy against the supplied etag and retries the
    copy if necessary up to the configured number of times (options.retries),
    until it matches.  An expected etag of "unknown" accepts any result.

    Note: the etag supplied may not be the etag of the corresponding source
    file, in the case that the source was a multipart file (len(etag) != 32).

    Returns a KeyInfo describing the destination file on success (the
    existing file if it was already present with the expected etag, or a
    synthetic KeyInfo on a dry run).  Returns None when access is denied
    (403) or when all attempts fail.
    """

    log_msg_header = ("Copy file s3://%s/%s to file "
                      "s3://%s/%s" % (from_bucket_name,
                                      from_key_name,
                                      to_bucket.name,
                                      to_key_name))
    logger.debug("%s: dst_bucket_readable: %s" %
                 (log_msg_header, "True" if readable_to_bucket else "False"))
    # For restartability: check to see whether we've already created this
    # part file.
    # Note: the parts are tagged by (a portion of) the etag for parent, so if
    # the parent has changed, we'll rebuild them all
    if readable_to_bucket:
        logger.debug("%s: looking for file in dest" % log_msg_header)
        to_keyinfo = get_keyinfo(readable_to_bucket, to_key_name)
        if to_keyinfo is not None:
            if to_keyinfo.etag == etag:
                logger.info("Validate s3://%s/%s copy s3://%s/%s %s already "
                            "exists" % (from_bucket_name,
                                        from_key_name,
                                        to_bucket.name,
                                        to_key_name,
                                        etag))
                logger.debug("%s: found dest file found with etag %s. "
                             "Skipping." % (log_msg_header, etag))
                return to_keyinfo
            else:
                logger.debug("%s: found dest file found with different etags: "
                             "source %s dest %s. Copy continues" %
                             (log_msg_header, etag, to_keyinfo.etag))

    if options.dry_run:
        print(log_msg_header)
        return KeyInfo(None, size, etag, to_key_name)

    logger.debug("%s: expecting etag %s start" % (log_msg_header, etag))
    success = False
    retries = 0
    while (not success and retries <= options.retries):
        retries += 1
        try:
            result_key = to_bucket.copy_key(to_key_name,
                                            from_bucket_name,
                                            from_key_name)
            # bytes are counted for every copy S3 performed, including
            # ones whose etag turns out not to match
            update_stats(size, accumulators)
            new_etag = strip_etag(result_key.etag)
            if etag == "unknown" or new_etag == etag:
                success = True
                logger.debug("%s: success expected etag %s "
                             "got %s" % (log_msg_header, etag, new_etag))
            else:
                logger.error("%s: mismatch expected etag %s got "
                             "%s" % (log_msg_header, etag, new_etag))
        except Exception as e:
            # not every exception carries a .status attribute; reading it
            # directly would raise AttributeError from inside this handler
            if getattr(e, "status", None) == 403:
                log_error("Error: Access denied to copy file s3://%s/%s - "
                          "skipping" % (from_bucket_name, from_key_name))
                return None
            else:
                logger.error("%s: error: %s\n%s" % (log_msg_header,
                                                    e,
                                                    traceback.format_exc()))
    if success:
        logger.info("Validate s3://%s/%s copy s3://%s/%s %s "
                    "complete" % (from_bucket_name,
                                  from_key_name,
                                  to_bucket.name,
                                  to_key_name,
                                  new_etag))
        logger.info("%s: success etag %s" % (log_msg_header, new_etag))
        grant_acl(result_key, to_bucket.name)
        return KeyInfo.fromKey(result_key)
    else:
        logger.error("%s: giving up" % log_msg_header)
        return None


def copy_part(from_bucket_name, from_key_name, start, end, to_bucket,
              to_key_name, accumulators):
    """Copy a part of a source file to a bucket as a single, multipart file.

    from_bucket_name -- string name of the source bucket
    from_key_name -- string name of the source file
    start -- 0-based offset of first byte to copy
    end -- 0-based offset of last byte to copy
    to_bucket -- boto.s3.bucket.Bucket object
    to_key_name -- string name to be used for the copy
    accumulators -- stats accumulators

    Copies the specified part of the specified source file to the
    destination, as a single file.  The source and destination buckets may
    be the same; if so, the caller should ensure that the source and
    destination file names are different.

    We get the etag of the part - which is a normal md5 - as a result of the
    s3 copy_part_from_key call.  Note that this is _different_ from the
    etag of the result file, which will end in -1 (since the result is a
    multipart file with a single part).

    The whole sequence (initiate, copy part, complete, grant) is retried
    until it succeeds or more than options.retries retries have been made.

    Returns (True, etag) or (False, _) depending on success
    """
    # byte count of the inclusive range [start, end]
    size = end - start + 1
    log_msg_header = ("Copy part s3://%s/%s[%d,%d] to file "
                      "s3://%s/%s" % (from_bucket_name,
                                      from_key_name,
                                      start,
                                      end,
                                      to_bucket.name,
                                      to_key_name))
    logger.debug("%s: start" % log_msg_header)
    success = False
    retries = 0
    while (not success and retries <= options.retries):
        retries += 1
        # reset per attempt so a failed attempt never reports a stale etag
        # NOTE(review): this_etag is only bound inside the loop; if the
        # loop body never runs the final return would fail - confirm
        # options.retries can never be negative
        this_etag = ""
        try:
            # each attempt starts a fresh multipart upload
            # NOTE(review): nothing here aborts `mp` when an attempt
            # fails - confirm whether incomplete uploads need cancelling
            mp = to_bucket.initiate_multipart_upload(to_key_name)
            # the byte range is copied in as part number 1 of the upload
            part_key = mp.copy_part_from_key(from_bucket_name,
                                             from_key_name,
                                             1,
                                             start=start,
                                             end=end)
            update_stats(size, accumulators)
            this_etag = strip_etag(part_key.etag)
            logger.debug("%s: copy part success etag %s" % (log_msg_header,
                                                            this_etag))
            done_mp = mp.complete_upload()
            logger.info("Validate s3://%s/%s temp s3://%s/%s [%d,%d] %s "
                        "complete" % (from_bucket_name,
                                      from_key_name,
                                      to_bucket.name,
                                      to_key_name,
                                      start,
                                      end,
                                      this_etag))
            logger.debug("%s: success etag %s" % (log_msg_header, this_etag))
            # the grant is inside the try: if it raises, the attempt counts
            # as failed and the whole sequence is retried
            grant_acl_mp(done_mp)
            success = True
        except Exception as e:
            logger.error("%s: error: %s\n%s" % (log_msg_header,
                                                e,
                                                traceback.format_exc()))
    if success:
        logger.info("%s: etag %s" % (log_msg_header, this_etag))
    else:
        logger.error("%s: giving up" % log_msg_header)
    return (success, this_etag)


def make_partsdir(keyinfo):
    """Return the key prefix "<name>/<etag signature>/parts/" under which
    the work files for keyinfo are stored.
    """
    owner = keyinfo.name
    signature = mk_etag_sig(keyinfo.etag)
    return "%s/%s/parts/" % (owner, signature)


def make_mp_tmp_name():
    """Return the fixed file name used for a reassembled multipart temp
    file.
    """
    tmp_name = "temp-mp"
    return tmp_name


def copy_part_validate(source, dest):
    """Copy part of a source file to a bucket as a single, non-multipart file.

    source -- Source object (partinfo must be specified)
    dest -- Dest object

    We first make a multipart copy of the single part to a temp file, to
    establish the md5 of this part.  We then make a single part copy of
    that temp file, so that we have a simple file with the correct etag.
    This makes it easy to restart.

    If the part file is already listed in dest.readable_bucket it is reused
    without copying.  On a dry run nothing is copied and a Source with a
    placeholder KeyInfo (etag "unknown") is returned.

    If we fail, we return None

    If we succeed, we return a Source object representing the new
    single-part file (for further copying).

    NOTE(review): the temp file is not deleted anywhere in this function -
    confirm whether cleanup happens elsewhere.
    """
    # key names relative to dest.prefix: <partsdir>/part-NNNNN
    base_to_key_name = (make_partsdir(source.key_info) +
                        ("part-%05d" % source.part_info.seqno))
    to_key_name = dest.prefix + base_to_key_name
    log_msg_header = ("Copy part s3://%s/%s[%d,%d] to file s3://%s/ "
                      "%s %s" % (source.bucket.name,
                                 source.key_info.name,
                                 source.part_info.start,
                                 source.part_info.end,
                                 dest.bucket.name,
                                 dest.prefix,
                                 base_to_key_name))
    logger.debug("%s: start" % log_msg_header)

    # byte count of the inclusive range [start, end]
    size = source.part_info.end - source.part_info.start + 1

    # check to see whether we've already created this part file
    # (any listed file with this name is accepted; its etag is not compared)
    if dest.readable_bucket:
        to_keyinfo = get_keyinfo(dest.readable_bucket, to_key_name)
        if to_keyinfo is not None:
            logger.info("%s: part file found. Skipping." % log_msg_header)
            newSource = Source(dest.bucket,
                               to_keyinfo,
                               dest.prefix,
                               base_to_key_name,
                               None)
            return newSource

    if options.dry_run:
        print log_msg_header
        # make a fake keyinfo
        keyinfo = KeyInfo("now", size, "unknown", to_key_name)
        newSource = Source(dest.bucket,
                           keyinfo,
                           dest.prefix,
                           base_to_key_name,
                           None)
        return newSource

    # copy to the temp file first
    base_temp_key = (make_partsdir(source.key_info) +
                     ("temp-%05d" % source.part_info.seqno))
    temp_key = dest.prefix + base_temp_key

    # only the total-with-overhead counter is passed for these work copies
    (success, this_etag) = copy_part(source.bucket.name,
                                     source.key_info.name,
                                     source.part_info.start,
                                     source.part_info.end,
                                     dest.bucket,
                                     temp_key,
                                     [bytes_transferred])
    if not success:
        return None

    logger.info("%s: temp copy succeeded." % log_msg_header)
    # make the one-part copy of the temp file
    # this_etag (the part's md5) is the etag the plain copy must produce
    keyinfo = copy_key(dest.bucket.name,
                       temp_key,
                       dest.bucket,
                       to_key_name,
                       this_etag,
                       size,
                       dest.readable_bucket,
                       [bytes_transferred])

    if keyinfo is None:
        return None

    # fill in the size when the returned KeyInfo does not carry one
    if keyinfo.size is None:
        keyinfo.size = size
    logger.info("%s: one-part copy succeeded %s." %
                (log_msg_header, str(keyinfo)))
    newSource = Source(dest.bucket,
                       keyinfo,
                       dest.prefix,
                       base_to_key_name,
                       None)
    return newSource


def do_copy_part_validate(dest, output_q, source):
    """Run copy_part_validate for source and, when it yields a result,
    put the pair (source, resulting part Source) on output_q.
    """
    partSource = copy_part_validate(source, dest)
    if not partSource:
        # failed parts are simply not reported
        return
    logger.debug("do_copy_part_validate source %s" % str(source))
    logger.debug("do_copy_part_validate partSource %s" % str(partSource))
    result = (source, partSource)
    output_q.put(result)


def split_worker_action(input_q, output_q, dest):
    """Worker loop for the split phase.

    input_q -- queue of Source objects, one per part to copy
    output_q -- queue that receives (source, partSource) results
    dest -- Dest object for the work files

    Processes items until input_q is empty.  An exception raised while
    handling one item is logged and the worker moves on to the next item.
    """
    # keep picking up work items until the queue is empty
    while not input_q.empty():
        try:
            source = input_q.get(True, 1)
        except Queue.Empty:
            # another worker took the last item between empty() and get()
            logger.warn("Queue is empty")
            return
        try:
            do_copy_part_validate(dest, output_q, source)
        except Exception as e:
            logger.error("Split worker failed on %s: %s\n%s" %
                         (source, e, traceback.format_exc()))
        finally:
            # always mark the item done: a missed task_done() would make
            # the Queue.join() in the phase driver block forever
            input_q.task_done()
        progress_update()


def do_split(input_q, output_q, dest):
    """Return a zero-argument callable that runs split_worker_action on
    the given queues and dest (suitable as a thread target).
    """
    return lambda: split_worker_action(input_q, output_q, dest)


def fill_split_queue(part_size, q, sources):
    """Queue one Source per part for every source that should be split.

    part_size -- maximum number of bytes in a part
    q -- queue receiving a Source with part_info set for each part
    sources -- iterable of Source objects to consider

    Parts are numbered from 1 and carry inclusive byte offsets.
    """
    for source in sources:
        info = source.key_info
        if not should_split(info):
            continue
        total = info.size
        # ceiling division: number of part_size chunks covering total
        numparts = ((total - 1) // part_size) + 1
        start = 0
        for seqno in range(1, numparts + 1):
            # the last part may be shorter than part_size
            end = min(start + part_size - 1, total - 1)
            q.put(source.withPartInfo(PartInfo(seqno, start, end)))
            start += part_size


def split_phase(sources, dest, splitresults):
    """Copy every file that needs splitting into per-part work files.

    sources -- Source objects to consider
    dest -- Dest object for the part files
    splitresults -- dict filled in place, mapping a source key name to
                    (source, parts), where parts maps a part seqno to
                    (partinfo, partSource)

    Returns early, before anything further is printed, when no source
    needs splitting.
    """
    sys.stdout.write("Split phase\n")
    split_work = Queue.Queue()
    split_results = Queue.Queue()

    fill_split_queue(options.part_size, split_work, sources)
    if split_work.qsize() == 0:
        return

    # never start more threads than there are work items
    num_threads = min(options.num_threads, split_work.qsize())
    # for debugging, we can run without the thread infrastructure
    if num_threads == 0:
        split_worker_action(split_work, split_results, dest)
    else:
        workers = []
        for i in range(num_threads):
            worker = threading.Thread(target=do_split(split_work,
                                                      split_results,
                                                      dest))
            logger.debug("starting worker %d" % i)
            worker.daemon = True
            worker.start()
            workers.append(worker)
        # keep the status line fresh while items are still queued
        while not split_work.empty():
            progress_update()
            time.sleep(.25)
        # then wait for the items already taken by workers to finish
        split_work.join()

    logger.debug("Validating split")

    # build a tree of results: keyinfo.name -> partinfo.seqno
    while not split_results.empty():
        (source, partSource) = split_results.get(True, 1)
        keyinfo = source.key_info
        partinfo = source.part_info
        if not keyinfo.name in splitresults:
            parts = dict()
            # the first Source seen for a name stands for the whole file,
            # so its part_info is cleared (partinfo was saved above)
            source.part_info = None
            splitresults[keyinfo.name] = (source, parts)
        _, parts = splitresults[keyinfo.name]
        parts[partinfo.seqno] = (partinfo, partSource)

    # dump the collected tree to the debug log
    logger.debug("Splitresults")
    for name in splitresults:
        source, parts = splitresults[name]
        logger.debug("Splitresults name %s source %s" % (name, str(source)))
        for seqno in sorted(parts):
            (partinfo, partSource) = parts[seqno]
            logger.debug("   %s %s" % (str(partinfo), str(partSource)))

    logger.debug("Completed split phase")
    sys.stdout.write("\nMain copy phase\n")


def copy_worker_action(input_q):
    """Worker loop for the main copy phase.

    input_q -- queue of CopyTask objects

    Processes tasks until input_q is empty, copying each source file to
    dest.prefix + relative_path with copy_key().  An exception raised
    while handling one task is logged and the worker moves on.
    """
    logger.debug("Copying worker start")
    while not input_q.empty():
        try:
            copy_task = input_q.get(True, 1)
        except Queue.Empty:
            # another worker took the last item between empty() and get()
            logger.warn("Queue is empty")
            return
        try:
            from_bucket_name = copy_task.source.bucket.name
            from_key_name = copy_task.source.key_info.name
            to_key_name = (copy_task.dest.prefix +
                           copy_task.source.relative_path)
            # these copies land in the target, so both counters are updated
            accumulators = [bytes_transferred, bytes_transferred_to_target]
            to_bucket = copy_task.dest.bucket
            etag = copy_task.source.key_info.etag
            size = copy_task.source.key_info.size
            readable_to_bucket = copy_task.dest.readable_bucket
            copy_key(from_bucket_name, from_key_name, to_bucket,
                     to_key_name, etag, size, readable_to_bucket,
                     accumulators)
        except Exception as e:
            logger.error("Copy worker failed on %s: %s\n%s" %
                         (copy_task, e, traceback.format_exc()))
        finally:
            # always mark the item done: a missed task_done() would make
            # the Queue.join() in the phase driver block forever
            input_q.task_done()
        progress_update()


def do_copy(input_q):
    """Return a zero-argument callable that runs copy_worker_action on
    input_q (suitable as a thread target).
    """
    return lambda: copy_worker_action(input_q)


def fill_copy_queue(q, sources, splitresults, dest):
    """Queue a CopyTask for every source that does not need splitting.

    splitresults is accepted but not used here.
    """
    whole_files = (s for s in sources if not should_split(s.key_info))
    for whole in whole_files:
        q.put(CopyTask(whole, dest))


def copy_phase(sources, splitresults, dest):
    """Copy every file that does not need splitting straight to dest.

    sources -- Source objects to consider
    splitresults -- passed through to fill_copy_queue
    dest -- Dest object for the copies

    Returns immediately when there is nothing to copy.
    """
    work_q = Queue.Queue()
    fill_copy_queue(work_q, sources, splitresults, dest)
    pending = work_q.qsize()
    if pending == 0:
        return

    # never start more threads than there are work items
    num_threads = min(options.num_threads, pending)
    if num_threads != 0:
        started = []
        for n in range(num_threads):
            t = threading.Thread(target=do_copy(work_q))
            logger.debug("starting worker %d" % n)
            t.daemon = True
            t.start()
            started.append(t)
        # refresh the status line while items are still queued, then wait
        # for the in-flight ones
        while not work_q.empty():
            progress_update()
            time.sleep(.25)
        work_q.join()
    else:
        # for debugging, we can run without the thread infrastructure
        copy_worker_action(work_q)

    logger.debug("Completed copy phase")


def copy_part_to_mp(mp_dst_key_name, bucket, mp, partinfo, partSource):
    """Copy one previously split part file into a multipart upload.

    mp_dst_key_name -- name of the file being reassembled (for logging)
    bucket -- bucket holding the part file
    mp -- the multipart upload object receiving the part
    partinfo -- PartInfo giving the part's seqno and original offsets
    partSource -- Source describing the part file to copy from

    The part file is copied in full as part number partinfo.seqno.  The
    etag returned by S3 is compared with the part file's etag, and the
    copy is retried until they match or more than options.retries
    retries have been made.  On a dry run only the log header is printed.

    Returns None; the outcome is reported through the log only.
    """
    from_key_name = partSource.key_info.name
    etag = partSource.key_info.etag
    seqno = partinfo.seqno
    log_msg_header = ("Reassemble file s3://%s/%s[%d] from "
                      "s3://%s/%s" % (bucket.name,
                                      mp_dst_key_name,
                                      seqno,
                                      bucket.name,
                                      from_key_name))
    logger.debug("%s: start" % log_msg_header)
    if options.dry_run:
        print(log_msg_header)
        return
    # the part file holds just this part, so its own byte range starts
    # at 0 and ends at (length - 1)
    end = partinfo.end - partinfo.start
    size = end + 1
    success = False
    retries = 0
    while (not success and retries <= options.retries):
        retries += 1
        try:
            part_key = mp.copy_part_from_key(bucket.name,
                                             from_key_name,
                                             seqno,
                                             start=0,
                                             end=end)
            update_stats(size, [bytes_transferred])
            new_etag = strip_etag(part_key.etag)
            if new_etag == etag:
                success = True
                logger.debug("%s: success etag %s" % (log_msg_header, etag))
            else:
                logger.error("%s: mismatch expected etag %s got "
                             "%s" % (log_msg_header, etag, new_etag))
            # report the etag S3 actually returned for this attempt
            logger.debug("%s: copy part success etag %s" % (log_msg_header,
                                                            new_etag))
        except Exception as e:
            logger.error("%s: error: %s\n%s" % (log_msg_header,
                                                e,
                                                traceback.format_exc()))
    if success:
        logger.info("%s: etag %s" % (log_msg_header, etag))
    else:
        logger.error("%s: giving up" % log_msg_header)
    logger.debug("%s: end" % log_msg_header)


def reassembly_worker_action(dst_bucket, input_q):
    """Worker loop for the reassembly phase.

    dst_bucket -- bucket holding the part files and the reassembled file
    input_q -- queue of (mp_dst_key_name, mp, partinfo, partSource) tuples

    Processes items until input_q is empty.  An exception raised while
    handling one item is logged and the worker moves on to the next item.
    """
    while not input_q.empty():
        try:
            (mp_dst_key_name, mp, partinfo, partSource) = input_q.get(True, 1)
        except Queue.Empty:
            # another worker took the last item between empty() and get()
            logger.warn("Queue is empty")
            return
        try:
            copy_part_to_mp(mp_dst_key_name, dst_bucket, mp, partinfo,
                            partSource)
        except Exception as e:
            logger.error("Reassembly worker failed on %s %s: %s\n%s" %
                         (mp_dst_key_name, partinfo, e,
                          traceback.format_exc()))
        finally:
            # always mark the item done: a missed task_done() would make
            # the Queue.join() in the phase driver block forever
            input_q.task_done()
        progress_update()


def do_reassembly(dst_bucket, input_q):
    """Return a zero-argument callable that runs reassembly_worker_action
    on dst_bucket and input_q (suitable as a thread target).
    """
    return lambda: reassembly_worker_action(dst_bucket, input_q)


def fill_reassembly_queue(q, mp_uploads, splitresults, dest, work_dest):
    """Start a multipart upload for each multi-part file and queue its
    parts.

    q -- queue receiving (mp_dst_key_name, mp, partinfo, partSource)
    mp_uploads -- dict filled in place: source key name -> upload object
                  (None on a dry run)
    splitresults -- results of the split phase
    dest -- Dest object for the final files
    work_dest -- Dest object whose prefix holds the work files

    Files with a single part are skipped here.
    """
    for source, parts in splitresults.values():
        keyinfo = source.key_info
        logger.debug(str(keyinfo))
        if len(parts) == 1:
            # we'll collect these at the very end
            continue
        if keyinfo.size > options.reassemble_large_file_size:
            # large files are reassembled directly at their final name
            mp_dst_key_name = dest.prefix + keyinfo.name
        else:
            # reassemble the parts to a multipart temp file, so we can later
            # copy them into a single-part file
            mp_dst_key_name = (work_dest.prefix +
                               make_partsdir(keyinfo) + make_mp_tmp_name())
        mp = None
        if not options.dry_run:
            logger.debug("Started mp upload %s", mp_dst_key_name)
            mp = dest.bucket.initiate_multipart_upload(mp_dst_key_name)
        mp_uploads[keyinfo.name] = mp
        for partinfo, partSource in parts.values():
            q.put((mp_dst_key_name, mp, partinfo, partSource))


def reassembly_phase(splitresults, dest, work_dest):
    """Reassemble split files from their parts using multipart uploads.

    splitresults -- results of the split phase
    dest -- Dest object for the final files
    work_dest -- Dest object whose prefix holds the work files

    Returns early when no file has parts queued for reassembly.
    """
    sys.stdout.write("\nReassembly phase\n")
    reassembly_work = Queue.Queue()
    # source key name -> multipart upload started for it
    mp_uploads = dict()

    fill_reassembly_queue(reassembly_work,
                          mp_uploads,
                          splitresults,
                          dest,
                          work_dest)
    logger.debug("reassembly work q size is %d" % reassembly_work.qsize())
    if reassembly_work.qsize() == 0:
        return

    # never start more threads than there are work items
    num_threads = min(options.num_threads, reassembly_work.qsize())
    # for debugging, we can run without the thread infrastructure
    if num_threads == 0:
        reassembly_worker_action(dest.bucket, reassembly_work)
    else:
        workers = []
        for i in range(num_threads):
            worker = threading.Thread(target=do_reassembly(dest.bucket,
                                                           reassembly_work))
            logger.debug("starting worker %d" % i)
            worker.daemon = True
            worker.start()
            workers.append(worker)
        # keep the status line fresh while items are still queued
        while not reassembly_work.empty():
            progress_update()
            time.sleep(.25)
        # then wait for the items already taken by workers to finish
        reassembly_work.join()

    # on a dry run the upload objects are None, so there is nothing to
    # complete
    if not options.dry_run:
        for name in mp_uploads:
            logger.debug("completing upload %s" % mp_uploads[name].key_name)
            # TODO: check for success here
            done_mp = mp_uploads[name].complete_upload()
            logger.debug("completed upload %s" % name)
            grant_acl_mp(done_mp)


def copy_key_no_validation(from_bucket_name, from_key_name, to_bucket,
                           to_key_name, accumulators):
    """Copy a file to a bucket without checking the resulting etag.

    from_bucket_name -- string name of the source bucket
    from_key_name -- string name of the source file
    to_bucket -- bucket object for the destination
    to_key_name -- string name to be used for the copy
    accumulators -- stats accumulators (not used by this function)

    The copy is retried after an exception until it succeeds or more than
    options.retries retries have been made.  On a dry run only the log
    header is printed.

    Returns True on success (or dry run), False otherwise.
    """
    names = (from_bucket_name, from_key_name, to_bucket.name, to_key_name)
    log_msg_header = "Copy key s3://%s/%s to file s3://%s/%s" % names
    logger.debug("%s: begin" % log_msg_header)

    if options.dry_run:
        print(log_msg_header)
        return True

    attempts = 0
    success = False
    while attempts <= options.retries and not success:
        attempts += 1
        try:
            result_key = to_bucket.copy_key(to_key_name,
                                            from_bucket_name,
                                            from_key_name)
            new_etag = strip_etag(result_key.etag)
            logger.debug("%s: success etag %s " % (log_msg_header, new_etag))
            success = True
        except Exception as e:
            logger.error("%s: error: %s\n%s" % (log_msg_header,
                                                e,
                                                traceback.format_exc()))
    if not success:
        logger.error("%s: giving up" % log_msg_header)
        return success
    logger.info("%s: etag %s" % (log_msg_header, new_etag))
    grant_acl(result_key, to_bucket.name)
    return success


def final_copy_worker_action(input_q):
    """Drain input_q, performing one final copy per queued work item.

    Each item is a (from_bucket_name, from_key_name, to_bucket, to_key_name)
    tuple which is handed to copy_key_no_validation().  The function returns
    when the queue is empty; it is used both as the body of the worker
    threads and directly when threading is disabled.

    task_done() is called for every item taken from the queue, even when the
    copy raises, so that a join() on the queue cannot block forever because
    of a failed item.  An unexpected exception is reported through
    log_error() and the worker moves on to the next item.
    """
    while not input_q.empty():
        try:
            (from_bucket_name, from_key_name, to_bucket,
             to_key_name) = input_q.get(True, 1)
        except Queue.Empty:
            # another worker took the last item between empty() and get()
            logger.warning("Queue is empty")
            return
        accumulators = [bytes_transferred]
        try:
            copy_key_no_validation(from_bucket_name,
                                   from_key_name,
                                   to_bucket,
                                   to_key_name,
                                   accumulators)
        except Exception as e:
            # copy_key_no_validation only guards the copy itself; anything
            # raised after it (e.g. while granting the acl) ends up here
            log_error("Final copy of s3://%s/%s failed: %s\n%s" %
                      (from_bucket_name,
                       from_key_name,
                       e,
                       traceback.format_exc()))
        finally:
            input_q.task_done()
        progress_update()


def do_final_copy(input_q):
    """Return a no-argument callable that runs the final copy on input_q.

    The returned callable is suitable as a threading.Thread target.
    """
    return lambda: final_copy_worker_action(input_q)


def fill_final_queue(q, splitresults, dest, work_dest):
    """Queue the final copy for every split file that still needs one.

    splitresults maps a name to a (source, parts) pair.  For each entry a
    (from_bucket_name, from_key_name, to_bucket, to_key_name) tuple is put
    on q; both ends of the copy are in dest.bucket.  Entries with more than
    one part whose size exceeds options.reassemble_large_file_size are
    skipped.
    """
    for name in splitresults:
        source, parts = splitresults[name]
        keyinfo = source.key_info
        logger.debug(str(keyinfo))

        if len(parts) == 1:
            # NOTE(review): looking up index 1 in a one-element container
            # presumably means parts is keyed by 1-based part number -
            # confirm against the code that builds splitresults
            _, part_source = parts[1]
            src_key = part_source.key_info.name
        elif keyinfo.size > options.reassemble_large_file_size:
            # the big files are already done
            continue
        else:
            parts_location = work_dest.prefix + make_partsdir(keyinfo)
            src_key = parts_location + make_mp_tmp_name()

        dst_key_name = dest.prefix + source.key_info.name
        work_item = (dest.bucket.name, src_key, dest.bucket, dst_key_name)
        q.put(work_item)


def final_phase(splitresults, dest, work_dest):
    """Run the final copies queued by fill_final_queue().

    Does nothing when no final copy is needed.  Otherwise the work is
    handled by up to options.num_threads daemon threads while this thread
    refreshes the progress display, or inline when the thread count is 0.
    """
    final_work = Queue.Queue()
    fill_final_queue(final_work, splitresults, dest, work_dest)

    pending = final_work.qsize()
    logger.debug("final work q size is %d" % pending)
    if pending == 0:
        return

    sys.stdout.write("\nFinal phase\n")

    # never start more threads than there are work items
    num_threads = min(options.num_threads, pending)

    # for debugging, we can run without the thread infrastructure
    if num_threads == 0:
        final_copy_worker_action(final_work)
        return

    workers = []
    for index in range(num_threads):
        thread = threading.Thread(target=do_final_copy(final_work))
        logger.debug("starting worker %d" % index)
        thread.daemon = True
        thread.start()
        workers.append(thread)

    # keep the progress display alive until every item has been picked up,
    # then wait for the items still in flight
    while not final_work.empty():
        progress_update()
        time.sleep(.25)
    final_work.join()


def setup_options():
    """Parse the command line into the module-level ``options`` namespace.

    Besides the checks argparse performs itself, two values are validated
    here so that bad input produces a usage error instead of a hang or a
    traceback later on:

    * the thread count (-t) must not be negative; 0 is allowed and makes
      the phases run without worker threads
    * the logging level (-l) must name a level of the logging module
      (case-insensitive)

    On invalid arguments argparse prints a usage message and exits.
    """
    global options
    parser = argparse.ArgumentParser(description="Multithreaded multipart "
                                     "copier for Amazon S3")

    parser.add_argument("source_bucket", help="source bucket/path")
    parser.add_argument("dest_bucket", nargs='?',
                        help="destination bucket/path")

    parser.add_argument("-n", "--dry-run", dest="dry_run", action='store_true',
                        default=False,
                        help="do no work but report what work would be done")

    parser.add_argument("-f", "--file", nargs="+",
                        help="source file[s] to copy")
    parser.add_argument("-p", "--prefix", nargs="+",
                        help="source prefix[es] to copy")

    parser.add_argument("-F", "--files",
                        help="file containing a list of files to copy")
    parser.add_argument("-P", "--prefixes",
                        help="file containing a list of prefixes to copy")

    parser.add_argument("-a", help="AWS Access Key", dest="aws_access_key")
    parser.add_argument("-k", help="AWS Secret Key", dest="aws_secret_key")
    parser.add_argument("-c", "--config_file", help="s3cmd-format config file",
                        dest="s3cfg_file", default="~/.s3cfg")
    parser.add_argument("-d", "--dest-config",
                        dest="dest_s3cfg_file", default=None,
                        help="s3cmd-format config file for destination bucket "
                             "only")

    parser.add_argument("--acl-grant", dest="acl_grant",
                        default=None, help="acl to grant as PERMISSION:EMAIL")

    parser.add_argument("-t", help="number of threads (default: 40)",
                        dest="num_threads", type=int, default=40)

    parser.add_argument("-l", help="logging level",
                        dest="log_level", default="CRITICAL")
    parser.add_argument("-L", help="logging file (appended)",
                        dest="log_dest", default="STDOUT")

    # TODO: test and support these options
    parser.add_argument("--clean-mpcopies", dest="clean_mpcopies",
                        action='store_true', default=False,
                        help=argparse.SUPPRESS)
    parser.add_argument("--s3-prefixes", action='store_true',
                        dest="s3_prefixes", default=False,
                        help=argparse.SUPPRESS)
    parser.add_argument("-w", "--work-bucket", help=argparse.SUPPRESS)
    parser.add_argument("--retries", dest="retries",
                        type=int, default=3, help=argparse.SUPPRESS)
    parser.add_argument("--dest-prefix", help=argparse.SUPPRESS)
    parser.add_argument("--work-prefix", help=argparse.SUPPRESS)
    parser.add_argument("--reassemble-large-file", help=argparse.SUPPRESS,
                        dest="reassemble_large_file_size", type=int,
                        default=5 * (2 ** 30))  # 5GB
    parser.add_argument("--large-file", help=argparse.SUPPRESS,
                        dest="large_file_size", type=int,
                        default=5 * (2 ** 30))  # 5GB
    parser.add_argument("--part-size", help=argparse.SUPPRESS,
                        dest="part_size", type=int, default=2 ** 26)  # 64 Meg

    options = parser.parse_args()

    # with a negative count no worker is ever started, and the phases would
    # wait forever for their work queue to drain
    if options.num_threads < 0:
        parser.error("number of threads (-t) must be 0 or more - found %d" %
                     options.num_threads)

    # the level name is later resolved with getattr(logging, NAME); make sure
    # it really is one of the numeric level constants
    level = getattr(logging, options.log_level.upper(), None)
    if not isinstance(level, int):
        parser.error("unknown logging level (-l): %s" % options.log_level)


def read_s3cfg(s3cfg_file):
    """Read the access and secret key from an s3cmd-format config file.

    A leading ~ in s3cfg_file is expanded.  Returns the tuple
    (access_key, secret_key); either element is None when the file does not
    define it.  When the file does not exist an error is reported through
    log_error() and (None, None) is returned.
    """
    config_path = os.path.expanduser(s3cfg_file)
    if os.path.isfile(config_path):
        s3config = S3Config(config_path)
        credentials = (s3config.access_key, s3config.secret_key)
        # only shortened keys are written to the log
        logger.debug("read_s3cfg: access %s secret %s" %
                     (obfuscate_key(credentials[0]),
                      obfuscate_key(credentials[1])))
        return credentials

    log_error("Error: s3 config file %s not found" % s3cfg_file)
    return (None, None)


def init_boto(aws_access_key, aws_secret_key):
    """Return a boto S3 connection created with the given credentials.

    The keys are written to the debug log in shortened form only.
    """
    shortened = (obfuscate_key(aws_access_key),
                 obfuscate_key(aws_secret_key))
    logger.debug("Connecting to S3 access %s secret %s" % shortened)
    logger.info("Connecting to S3")
    connection = boto.connect_s3(aws_access_key, aws_secret_key)
    return connection


def log_error(msg):
    """Report an error and mark the run as failed.

    Sets the module-level global_return_code to 1, logs msg at ERROR level
    and writes it, followed by a newline, to stderr.
    """
    global global_return_code
    # the script exits with this value (see the sys.exit calls)
    global_return_code = 1
    logger.error(msg)
    sys.stderr.write(msg + "\n")


def validate_bucket(bucket_uri_string, bucket_type, s3, connection_valid):
    """Validate a bucket uri and, if possible, open the bucket.

    bucket_uri_string -- the uri as given by the user (s3://bucket[/path])
    bucket_type       -- role of the bucket ("source", "destination", ...),
                         used in messages only
    s3                -- boto S3 connection; not used when connection_valid
                         is False
    connection_valid  -- False when the credentials are already known to be
                         unusable; then only the form of the uri is checked

    Returns the tuple
        (bucket_valid, connection_valid, bucket, path, bucket_readable)
    where
      bucket_valid     is False for a malformed uri or a bucket that does
                       not exist
      connection_valid is the value passed in, cleared when S3 rejects the
                       access key or the secret key
      bucket           is the boto bucket, or None when it could not be
                       obtained; for an access-denied bucket it is obtained
                       with validate=False
      path             is the object part of the uri (None for a malformed
                       uri)
      bucket_readable  is True only when get_bucket() succeeded

    Problems are reported through log_error().  S3 errors other than the
    ones handled below are re-raised.
    """
    bucket_valid = False
    uri = None
    try:
        uri = S3Uri(bucket_uri_string)
        bucket_valid = uri.type == "s3" and uri.has_bucket()
    except ValueError:
        bucket_valid = False

    if not bucket_valid:
        log_error("Parameter error: %s bucket must be an S3 URI. Found: %s" %
                  (bucket_type, bucket_uri_string))
        return (bucket_valid, connection_valid, None, None, False)

    bucket_name = uri.bucket()

    if not connection_valid:
        # we discovered an invalid connection earlier - we're just validating
        # the form of the uri
        return (bucket_valid, connection_valid, None, uri.object(), False)

    bucket = None
    bucket_readable = False
    try:
        bucket = s3.get_bucket(bucket_name)
        bucket_readable = True
    except boto.exception.S3ResponseError as e:
        if e.error_code == 'InvalidAccessKeyId':
            connection_valid = False
            log_error("Parameter error: invalid access key")
        elif e.error_code == 'SignatureDoesNotMatch':
            connection_valid = False
            log_error("Parameter error: invalid secret key")
        elif e.error_code == 'NoSuchBucket':
            # a well-formed uri for a bucket that does not exist is not a
            # usable bucket: report it as invalid so that callers do not
            # go on with bucket set to None
            bucket_valid = False
            log_error("Parameter error: %s bucket %s does not exist" %
                      (bucket_type, bucket_name))
        elif e.error_code == 'AccessDenied':
            # the bucket exists but these credentials may not read it; keep
            # an unvalidated handle so that it can still be written to
            logger.info("%s bucket %s access denied" %
                        (bucket_type.capitalize(), bucket_name))
            bucket = s3.get_bucket(bucket_name, validate=False)
        else:
            raise
    return (bucket_valid,
            connection_valid,
            bucket,
            uri.object(),
            bucket_readable)


def main():
    """Parse the arguments, validate the configuration and run the copy.

    Steps:
      1. parse the command line and set up logging
      2. find the credentials and validate the source, destination and work
         buckets, collecting as many configuration problems as possible
         before giving up
      3. build the list of sources from the file and prefix arguments
      4. run the split phase (only when some file has to be split), the copy
         phase, and then the reassembly and final phases

    The outcome is reported through the module-level global_return_code
    (main itself returns None).  The function calls
    sys.exit(global_return_code) on configuration errors, after
    --clean-mpcopies, and when there is nothing to copy.
    """
    global logger
    global bytes_transferred, bytes_transferred_to_target
    global global_return_code

    global_return_code = 0

    setup_options()

    logger = logging.getLogger("S3 MultiThreaded Copy")
    if options.log_dest == "STDOUT":
        log_dst = logging.StreamHandler(sys.stdout)
    else:
        log_dst = logging.FileHandler(options.log_dest)

    # every placeholder needs its leading '%': without it the handler would
    # print the literal text "(message)s" instead of the log message
    formatter = logging.Formatter('%(asctime)-15s: %(levelname)-8s '
                                  '%(message)s')
    log_dst.setFormatter(formatter)
    log_dst.setLevel(getattr(logging, options.log_level.upper()))
    logger.addHandler(log_dst)
    logger.setLevel(getattr(logging, options.log_level.upper()))
    # this is set to false because s3cmd Config screws with the top-level
    # logger somehow
    logger.propagate = False

    bytes_transferred = Accumulator()
    bytes_transferred_to_target = Accumulator()

    # parameter and configuration validation
    # we try to validate as much as possible (not give up after the first
    # problem)
    args_valid = True

    # find the primary credentials (a second set may be present for the
    # destination bucket)
    # keys found in the config file take precedence over -a/-k.  A missing
    # config file is only an error when it is actually needed: if both keys
    # were given on the command line we do not report it (read_s3cfg would
    # otherwise set a failing return code for a run that works)
    cli_keys_given = (options.aws_access_key is not None and
                      options.aws_secret_key is not None)
    if cli_keys_given and not os.path.isfile(
            os.path.expanduser(options.s3cfg_file)):
        (access_key, secret_key) = (None, None)
    else:
        (access_key, secret_key) = read_s3cfg(options.s3cfg_file)
    if access_key is None:
        access_key = options.aws_access_key
    if secret_key is None:
        secret_key = options.aws_secret_key

    if secret_key is None or access_key is None:
        s3 = None
        connection_valid = False
    else:
        s3 = init_boto(access_key, secret_key)
        connection_valid = True  # until we find otherwise

    # validate the source bucket. this may also invalidate the connection
    source_bucket_name = options.source_bucket
    _ = validate_bucket(source_bucket_name, "source", s3, connection_valid)
    src_bucket_valid = _[0]
    connection_valid = _[1]
    src_bucket = _[2]
    src_path = _[3]
    src_bucket_readable = _[4]

    # a missing bucket object is a configuration error as well (same check
    # as for the destination and work buckets below)
    if (src_bucket is None or not src_bucket_valid or
            not connection_valid):
        args_valid = False

    # TODO: This cleanup should be moved somewhere else and better supported
    if options.clean_mpcopies:
        if src_bucket is None:
            # nothing to clean up without a usable source bucket
            log_error("Resolve configuration errors and try again")
            sys.exit(global_return_code)
        for mp in src_bucket.list_multipart_uploads():
            logger.info("Canceling mp upload %s", mp.key_name)
            mp.cancel_upload()
        sys.exit(global_return_code)

    # if no destination bucket is specified, we use the source,
    if options.dest_bucket is None:
        dst_bucket_valid = src_bucket_valid
        dst_bucket = src_bucket
        dst_path = ""
        dst_bucket_readable = src_bucket_readable
        # the source bucket may be unavailable (the error was reported
        # above); keep going so that the remaining arguments are checked too
        if dst_bucket is not None:
            dst_bucket_name = dst_bucket.name
        else:
            dst_bucket_name = source_bucket_name
    else:
        # validate the destination bucket. this may also invalidate the
        # connection
        dst_bucket_name = options.dest_bucket
        _ = validate_bucket(dst_bucket_name,
                            "destination",
                            s3,
                            connection_valid)
        dst_bucket_valid = _[0]
        connection_valid = _[1]
        dst_bucket = _[2]
        # no path is returned for a malformed uri
        dst_path = _[3] or ""
        dst_bucket_readable = _[4]

        if not dst_bucket or not dst_bucket_valid or not connection_valid:
            args_valid = False

    logger.debug("dst_path %s" % dst_path)

    # we combine the path portion of the bucket uri (if any) with the
    # dest_prefix parameter (if any)
    if options.dest_prefix is None:
        if len(dst_path) > 0:
            dst_prefix = dst_path
        elif src_bucket_valid and dst_bucket_name == source_bucket_name:
            dst_prefix = "temp/dest/"
        else:
            dst_prefix = ""
    else:
        dst_prefix = dst_path + options.dest_prefix

    # add a trailing slash to the destination (unless s3 semantics requested)
    if (len(dst_prefix) > 0 and
            dst_prefix[-1] != "/" and not
            options.s3_prefixes):
        dst_prefix += "/"

    logger.debug("dst_prefix %s" % dst_prefix)

    # if no work bucket is specified, we use the destination
    if options.work_bucket is None:
        # note: if no dest was specified, then this is the source also
        work_bucket_valid = dst_bucket_valid
        work_bucket = dst_bucket
        work_path = ""
        work_bucket_name = source_bucket_name
    else:
        # validate the work bucket. this may also invalidate the connection
        work_bucket_name = options.work_bucket
        _ = validate_bucket(work_bucket_name, "work", s3, connection_valid)
        work_bucket_valid = _[0]
        connection_valid = _[1]
        work_bucket = _[2]
        # no path is returned for a malformed uri
        work_path = _[3] or ""

        if not work_bucket or not work_bucket_valid or not connection_valid:
            args_valid = False

    logger.debug("src_bucket_valid %s" % str(src_bucket_valid))

    # same scheme as for the destination prefix: a path in the work uri
    # wins, otherwise "temp/" keeps the work files apart when the work
    # bucket is shared with the source or the destination
    if options.work_prefix is None:
        if (len(work_path)) > 0:
            work_prefix = work_path
        elif (src_bucket_valid and
                (work_bucket_name == source_bucket_name
                    or work_bucket_name == dst_bucket_name)):
            work_prefix = "temp/"
        else:
            work_prefix = ""
    else:
        work_prefix = work_path + options.work_prefix

    if dst_bucket_valid:
        if dst_bucket_readable:
            readable_dst_bucket = dst_bucket
        else:
            readable_dst_bucket = None
            secondary_access_key = None
            secondary_secret_key = None
            # try to find another set of credentials that can read the dest
            # bucket
            if not options.dest_s3cfg_file is None:
                logger.debug("Trying specified dest-config: %s" %
                             options.dest_s3cfg_file)
                (secondary_access_key, secondary_secret_key) = read_s3cfg(
                    options.dest_s3cfg_file)
            elif not options.s3cfg_file == "~/.s3cfg":
                # if special primary creds were used, try these default creds
                # as secondary
                logger.debug("Trying default s3cfg as dest-config: %s" %
                             "~/.s3cfg")
                (secondary_access_key,
                 secondary_secret_key) = read_s3cfg("~/.s3cfg")
            if secondary_access_key is None or secondary_secret_key is None:
                logger.debug("Dest bucket is unreadable and can't find "
                             "secondary access %s, %s" %
                             (obfuscate_key(secondary_access_key),
                              obfuscate_key(secondary_secret_key)))
            else:
                secondary_s3 = init_boto(secondary_access_key,
                                         secondary_secret_key)
                # without an explicit destination the destination is the
                # source bucket, so that is the uri to validate
                if options.dest_bucket is None:
                    secondary_uri = options.source_bucket
                else:
                    secondary_uri = options.dest_bucket
                _ = validate_bucket(secondary_uri,
                                    "destination using secondary credentials",
                                    secondary_s3,
                                    True)
                secondary_connection_valid = _[1]
                readable_dst_bucket = _[2]
                dst_bucket_readable = _[4]

                if not secondary_connection_valid:
                    logger.debug("Dest bucket is unreadable but secondary "
                                 "credentials are invalid  %s, %s" %
                                 (obfuscate_key(secondary_access_key),
                                  obfuscate_key(secondary_secret_key)))
                elif not dst_bucket_readable:
                    logger.debug("Dest bucket is unreadable but secondary "
                                 "credentials also don't have access %s, %s" %
                                 (obfuscate_key(secondary_access_key),
                                  obfuscate_key(secondary_secret_key)))
                else:
                    logger.debug("Dest bucket is unreadable but secondary "
                                 "credentials work: %s, %s" %
                                 (obfuscate_key(secondary_access_key),
                                  obfuscate_key(secondary_secret_key)))

    if options.acl_grant is not None:
        try:
            [permission, email] = options.acl_grant.split(':')
            ok = ['READ', 'WRITE', 'READ_ACP', 'WRITE_ACP', 'FULL_CONTROL']
            if not permission.upper() in ok:
                args_valid = False
                log_error("Acl grant permission must be one of %s - "
                          "found %s" % (ok, permission))
            # from here on the grant is kept as a (PERMISSION, email) tuple
            options.acl_grant = (permission.upper(), email)
        except ValueError:
            args_valid = False
            log_error("Acl grant must be in the form of PERMISSION:EMAIL "
                      "- found %s" % options.acl_grant)

    if (src_bucket_valid):
        logger.info("Source bucket name %s src path %s" %
                    (source_bucket_name, src_path))
        logger.info("Source uri: s3://%s/%s" % (source_bucket_name, src_path))
        logger.info("Source bucket readable: %s" % (src_bucket_readable))
    if (dst_bucket_valid):
        logger.info("Dest uri: s3://%s/%s" % (dst_bucket_name, dst_prefix))
        logger.info("Dest bucket readable: %s" % (dst_bucket_readable))
    if (work_bucket_valid):
        logger.info("Work uri: s3://%s/%s (if needed)" %
                    (work_bucket_name, work_prefix))

    if not args_valid:
        log_error("Resolve configuration errors and try again")
        sys.exit(global_return_code)

    print("Source uri: s3://%s/%s" % (src_bucket.name, src_path))
    print("Dest uri: s3://%s/%s" % (dst_bucket.name, dst_prefix))

    if not dst_bucket_readable:
        readable_dst_bucket = None
    dest = Dest(dst_bucket, readable_dst_bucket, dst_prefix)

    sources = []

    if options.file is not None:
        for f in options.file:
            ki = get_keyinfo(src_bucket, src_path + f)
            if ki is not None:
                sources.append(Source(src_bucket, ki, src_path, f, None))

    # list files: one entry per line, lines starting with '#' are skipped
    if options.files is not None:
        with open(os.path.expanduser(options.files)) as file_list:
            for line in file_list:
                if line.startswith('#'):
                    continue
                f = line.rstrip('\r\n')
                ki = get_keyinfo(src_bucket, src_path + f)
                if ki is not None:
                    sources.append(Source(src_bucket, ki, src_path, f, None))

    prefixes = options.prefix if options.prefix is not None else []
    if options.prefixes is not None:
        with open(os.path.expanduser(options.prefixes)) as prefix_list:
            for line in prefix_list:
                if line.startswith('#'):
                    continue
                prefixes.append(line.rstrip('\r\n'))

    # if no files or prefixes were explicitly specified, then we copy the
    # entire bucket
    if len(sources) == 0 and len(prefixes) == 0:
        prefixes.append("")

    for p in prefixes:
        prefix = src_path + p
        # last_component is the number of leading characters of each key
        # name that go into the Source's path argument; the remainder is
        # passed as its file name
        # note: with boto, the list() command never throws a permission error -
        # but one may be thrown when you consume the list
        if prefix == "" or prefix == "*":
            keys = src_bucket.list()
            last_component = 0
        else:
            if prefix[-2:] == "/*":
                prefix = prefix[:-1]
                last_component = len(prefix)
            elif prefix[-1:] == "*":
                prefix = prefix[:-1]
                last_component = 1 + prefix.rfind("/")
            else:
                if prefix[-1:] != "/":
                    prefix = prefix + "/"
                last_component = 1 + prefix[:-1].rfind("/")
            keys = src_bucket.list(prefix)
        try:
            for key in keys:
                ki = KeyInfo.fromKey(key)
                sources.append(Source(src_bucket,
                               ki,
                               prefix[:last_component],
                               ki.name[last_component:],
                               None))
        except boto.exception.S3ResponseError as e:
            if e.status == 403:
                if prefix == "":
                    log_error("Error: Access denied to list bucket "
                              "s3://%s" % src_bucket.name)
                else:
                    log_error("Error: Access denied to list bucket prefix "
                              "s3://%s/%s" % (src_bucket.name, prefix))
            else:
                raise

    if options.dry_run:
        # we do everything without threading if this is a dry run - easier to
        # follow the output
        options.num_threads = 0

    if len(sources) == 0:
        print("Nothing to copy: quitting")
        sys.exit(global_return_code)

    some_files_to_split = False
    for source in sources:
        if should_split(source.key_info):
            logger.info("Found file to split %s", str(source.key_info))
            some_files_to_split = True
            break

    splitresults = dict()
    if some_files_to_split:
        workdest = Dest(work_bucket, work_bucket, work_prefix)
        print("Work uri: s3://%s/%s" % (work_bucket.name, work_prefix))
        logger.info("Split phase")
        split_phase(sources, workdest, splitresults)

    # copy all the small files and the parts to the real target bucket
    # these are the (potential) cross-data-center copy operations

    logger.info("Copy phase")
    copy_phase(sources, splitresults, dest)

    # reassemble multipart files into single files

    if some_files_to_split:
        logger.info("Reassembly phase")
        reassembly_phase(splitresults, dest, workdest)
        final_phase(splitresults, dest, workdest)

    logger.info("Done")

    progress_update()
    sys.stdout.write("\n")

if __name__ == '__main__':
    # main() has no return value: it reports the outcome of the run through
    # the module-level global_return_code, which becomes the exit status
    main()
    sys.exit(global_return_code)
