#! /usr/bin/env python

import sys
import time
import traceback
import transaction
from persistent import Persistent
from ZODB.tests.StorageTestBase import zodb_pickle

from neo.lib.util import p64
from neo.lib.protocol import CellStates
from neo.tests import DB_PREFIX
from neo.tests.benchmark import BenchmarkRunner
from neo.tests.functional import NEOCluster

# Default benchmark parameters. ReplicationBenchmark.load_options falls back
# on these when the matching command-line option is not given.
PARTITIONS = 16         # --partitions: number of partitions of the cluster
TRANSACTIONS = 1024     # --transactions: total number of transactions
OBJECTS = 1024          # --objects: total number of objects
REVISIONS = 4           # --revisions: number of revisions per object
OBJECT_SIZE = 1024      # --object-size: size of an object revision
CUT_AT = 0              # --cut-at: populate the destination up to this %

def humanize(size):
    """Format a byte count as a human-readable string.

    Values below 1024 are rendered as whole bytes. Larger values are
    divided by 1024 as many times as needed (up to three times) and
    rendered with two decimals in KB, MB or GB. GB is the largest unit,
    so bigger values are still expressed in GB.
    """
    # All scaled units use two decimals (the GB entry used to be '%2.f',
    # i.e. width 2 / precision 0, which dropped the fractional part).
    units = ['%.2f KB', '%.2f MB', '%.2f GB']
    unit = '%d bytes'
    while size >= 1024 and units:
        size /= 1024.0
        unit, units = units[0], units[1:]
    return unit % size

class DummyObject(Persistent):
    """Persistent object that only carries an opaque payload.

    The benchmark pickles one instance and stores that pickle as the data
    of every object revision.
    """

    def __init__(self, data):
        # Keep the payload: it is what gives the pickled object its size.
        # (It used to be replaced by None, so the requested object size
        # had no effect on the stored data.)
        self._data = data

class ReplicationBenchmark(BenchmarkRunner):
    """ Test replication process

    A two-storage cluster is populated while the last storage gets stopped
    at some point (see --cut-at); that storage is then restarted and the
    time needed until no partition cell is out of date is measured.
    """

    def add_options(self, parser):
        """Register the command-line options of this benchmark on `parser`."""
        add_option = parser.add_option
        add_option('', '--transactions', help="Total number of transactions")
        add_option('', '--objects', help="Total number of objects")
        add_option('', '--revisions', help="Number of revisions per object")
        add_option('', '--partitions', help="Number of partition")
        add_option('', '--object-size', help="Size of an object revision")
        add_option('', '--cut-at', help="Populate the destination up to this %")

    def load_options(self, options, args):
        """Build the configuration dict from the parsed options.

        Missing options fall back on the module-level defaults. The process
        exits with an error message when the parameters are inconsistent.
        """
        transactions = int(options.transactions or TRANSACTIONS)
        objects = int(options.objects or OBJECTS)
        revisions = int(options.revisions or REVISIONS)
        cut_at = int(options.cut_at or CUT_AT)
        # Reject values that would otherwise end in a ZeroDivisionError,
        # either in the checks below or later in populate().
        if min(transactions, objects, revisions) < 1:
            sys.exit('Invalid parameters (need positive values)')
        # populate() writes the same number of objects in every transaction
        # and rewrites each group of objects in `revisions` consecutive
        # transactions, hence both divisibility requirements.
        if (objects * revisions) % transactions != 0 or \
           transactions % revisions != 0:
            sys.exit('Invalid parameters (need multiples)')
        if not 0 <= cut_at <= 100:
            sys.exit('Invalid parameters (cut-at must be between 0 and 100)')
        return dict(
            partitions = int(options.partitions or PARTITIONS),
            transactions = transactions,
            objects = objects,
            revisions = revisions,
            object_size = int(options.object_size or OBJECT_SIZE),
            cut_at = cut_at,
        )

    def time_it(self, method, *args, **kw):
        """Call `method(*args, **kw)` and return the elapsed wall-clock
        time in seconds. The result of `method` is discarded."""
        start = time.time()
        method(*args, **kw)
        return time.time() - start

    def start(self):
        """Run the benchmark.

        Returns a `(summary, content)` tuple: `summary` is the string built
        by buildReport() and `content` is the formatted traceback if a step
        failed, or an empty string otherwise.
        """
        config = self._config
        # build a neo
        neo = NEOCluster(
            db_list=['%s_replication_%u' % (DB_PREFIX, i) for i in xrange(2)],
            clear_databases=True,
            partitions=config.partitions,
            replicas=1,
            master_count=1,
            verbose=False,
        )
        neo.start()
        # A timing left to None tells buildReport() which step failed.
        p_time = r_time = None
        content = ''
        try:
            try:
                p_time = self.time_it(self.populate, neo)
                neo.expectOudatedCells(config.partitions)
                # populate() stopped the last storage: restart it so that it
                # has to catch up.
                storage = neo.getStorageProcessList()[-1]
                storage.start()
                neo.expectRunning(storage, delay=0.1)
                print("Source storage populated in %.3f secs" % p_time)
                # NOTE(review): the extra 0.1 presumably accounts for the
                # expectRunning delay above - confirm.
                r_time = self.time_it(self.replicate, neo) + 0.1
            except Exception:
                # Report the failure instead of propagating it, so that a
                # (partial) report can still be built.
                content = traceback.format_exc()
        finally:
            neo.stop()
        return self.buildReport(p_time, r_time), content

    def replicate(self, neo):
        """Wait until the partition table has no out-of-date cell left.

        The partition table is polled once per second. An Exception is
        raised if out-of-date cells remain after one hour.
        """
        def count_outdated_cells():
            row_list = neo.neoctl.getPartitionRowList()[1]
            return sum(1 for row in row_list for cell in row[1]
                       if cell[1] == CellStates.OUT_OF_DATE)
        deadline = time.time() + 3600
        while count_outdated_cells() > 0:
            if time.time() > deadline:
                raise Exception('Replication takes too long')
            time.sleep(1)

    def buildReport(self, p_time, r_time):
        """Record the benchmark parameters and timings with add_status()
        and return the one-line summary.

        `p_time` and `r_time` are the population and replication times in
        seconds; None means that the corresponding step failed, in which
        case a failure message is returned instead of the summary.
        """
        add_status = self.add_status
        config = self._config
        cut_at = config.cut_at
        objects = config.objects
        revisions = config.revisions
        object_size = config.object_size
        partitions = config.partitions
        objects_revisions = revisions * objects
        objects_space = objects_revisions * object_size
        add_status('Partitions', partitions)
        add_status('Transactions', config.transactions)
        add_status('Objects', objects)
        add_status('Revisions', revisions)
        add_status('Cut at', '%d%%' % cut_at)
        add_status('Object size', humanize(object_size))
        add_status('Objects space', humanize(objects_space))
        if p_time is None:
            return 'Populate failed'
        add_status('Population time', '%.3f secs' % p_time)
        if r_time is None:
            return 'Replication failed'
        bandwidth = objects_space / r_time
        add_status('Replication time', '%.3f secs' % r_time)
        add_status('Time per partition', '%.3f secs' % (r_time / partitions))
        add_status('Time per object', '%.3f secs' % (r_time / objects_revisions))
        add_status('Global bandwidth', '%s/sec' % humanize(bandwidth))
        summary = "%d%% of %s replicated at %s/sec" % (100 - cut_at,
            humanize(objects_space), humanize(bandwidth))
        return summary

    def populate(self, neo):
        """Fill the cluster and stop the last storage along the way.

        Every transaction stores the same number of objects. A group of
        consecutive OIDs is rewritten in `revisions` consecutive
        transactions before moving on to the next group, so that each
        object ends up with `revisions` revisions. The last storage is
        stopped at the first transaction boundary where at least `cut_at`
        percent of the object revisions have been written, or at the end
        if that threshold was never reached.
        """
        print("Start populate")
        db, conn = neo.getZODBConnection(compress=False)
        storage = conn._storage
        config = self._config
        cut_at = config.cut_at
        transactions = config.transactions
        revisions = config.revisions
        objects_revisions = config.objects * revisions
        objects_per_transaction = objects_revisions // transactions

        data = zodb_pickle(DummyObject("-" * config.object_size))
        base_oid = 1
        # Serial passed to store(): the previous transaction when it wrote
        # the same objects, zero when the objects are new.
        prev = p64(0)
        progress = 0
        cutted = False
        for tidx in xrange(transactions):
            # '>=' rather than '==': the integer percentage can jump over
            # cut_at between two transaction boundaries.
            if not cutted and (100 * progress) // objects_revisions >= cut_at:
                print("Cut at %d%%" % (cut_at, ))
                neo.getStorageProcessList()[-1].stop()
                cutted = True
            txn = transaction.Transaction()
            txn.description = "Transaction %s" % tidx
            storage.tpc_begin(txn)
            for oid in xrange(base_oid, base_oid + objects_per_transaction):
                storage.store(p64(oid), prev, data, '', txn)
            progress += objects_per_transaction
            storage.tpc_vote(txn)
            prev = storage.tpc_finish(txn)
            if (tidx + 1) % revisions == 0:
                # The current group got all its revisions: switch to fresh
                # OIDs, which have no previous revision yet.
                base_oid += objects_per_transaction
                prev = p64(0)
        if not cutted:
            # The threshold was not reached at any transaction boundary:
            # the destination holds everything, stop it now.
            neo.getStorageProcessList()[-1].stop()

def main(args=None):
    """Script entry point: create the replication benchmark and run it.

    `args` is accepted but not used by this function.
    """
    benchmark = ReplicationBenchmark()
    benchmark.run()

# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()

