#!/usr/bin/env python

import os, logging

__author__ = "Andrew Butterfield"
__copyright__ = "Copyright 2007-2009, The Cogent Project"
__credits__ = ["Andrew Butterfield", "Peter Maxwell", "Gavin Huttley",
                "Matthew Wakefield", "Edward Lang"]
__license__ = "GPL"
__version__ = "1.3.1"
__maintainer__ = "Gavin Huttley"
__email__ = "Gavin Huttley"
__status__ = "Production"

# Logger named 'cogent'; the MPI detection further down reports through it.
LOG = logging.getLogger('cogent')

# A flag to control if excess CPUs are worth a warning.
# NOTE(review): nothing visible in this module reads the flag; presumably
# it is consulted by callers elsewhere in the package - confirm.
inefficiency_forgiven = False

class _FakeCommunicator(object):
    """Looks like a 1-cpu MPI communicator, but isn't"""
    rank = 0
    size = 1
    
    def split(self, colour):
        return (self, self)
    def sum(self, value, dest=None):
        return value
    def max(self, value, dest=None):
        return value
    def broadcast_obj(self, obj, source):
        return obj
    def broadcast(self, array, source):
        pass
    def barrier(self):
        pass
    

# MPI is opt-out: any non-empty DONT_USE_MPI value in the environment
# (even "0", since environment values are strings) disables it.
mpi = None
if not os.environ.get('DONT_USE_MPI', 0):
    try:
        import mpi
    except ImportError:
        # No importable 'mpi' module: carry on without it.
        mpi = None
    else:
        LOG.info('MPI: %s processors' % mpi.world.size)
        if mpi.world.size == 1:
            # A world of one process is treated the same as having no MPI.
            mpi = None

# The bottom entry of the communicator stack is either the MPI world or
# the single-process stand-in.
if mpi is not None:
    get_processor_name = mpi.get_processor_name
    _ParallelisationStack = [mpi.world]
else:
    LOG.info('Not using MPI')
    def get_processor_name():
        """Return $HOSTNAME, or 'one' when that variable is not set."""
        return os.environ.get('HOSTNAME', 'one')
    _ParallelisationStack = [_FakeCommunicator()]

def push(context):
    """Put `context` on top of the communicator stack, making it the one
    returned by getCommunicator() until it is popped again."""
    stack = _ParallelisationStack
    stack.append(context)

def pop(context=None):
    """Remove the top of the communicator stack and return it.

    When `context` is supplied, assert that it is the object that was
    removed. The check runs after the removal, so on a mismatch the stack
    has already been changed.
    """
    removed = _ParallelisationStack.pop()
    assert context is None or removed is context
    return removed

def sync_random(r):
    """Give `r` the generator state broadcast from rank 0 of the current
    communicator.

    `r` needs getstate() and setstate() methods. Nothing happens when the
    current communicator has a size of 1 or less.
    """
    comm = _ParallelisationStack[-1]
    if comm.size <= 1:
        return
    shared_state = comm.broadcast_obj(r.getstate(), 0)
    r.setstate(shared_state)

def getCommunicator():
    """Return the communicator currently on top of the stack."""
    current = _ParallelisationStack[-1]
    return current

def getSplitCommunicators(jobs):
    """Divide the current communicator for running `jobs` tasks at once.

    Returns a pair (next, sub). The number of groups is the smaller of
    `jobs` and the size of the current communicator:
      - one group: a single-process stand-in and the current communicator;
      - as many groups as processes: the current communicator and a
        single-process stand-in;
      - otherwise: two results of splitting the current communicator,
        coloured by rank // groups and by rank % groups respectively.

    `jobs` must be greater than zero (checked with an assertion).
    """
    comm = getCommunicator()
    assert jobs > 0
    group_count = min(jobs, comm.size)

    if group_count == 1:
        # No division at all: everything stays in the second communicator.
        return (_FakeCommunicator(), comm)

    if group_count == comm.size:
        # One group per process: nothing is left over for the second.
        return (comm, _FakeCommunicator())

    across = comm.split(comm.rank // group_count)
    within = comm.split(comm.rank % group_count)
    return (across, within)


# These two classes would ideally be a single generator function, but a
# 'yield' can't be wrapped in a 'try ... finally', and the clean-up needs
# to run even when a loop is left early via 'break' etc.
class _ParallelIter(object):
    """Iterator over this process's share of `values`.

    Creating the iterator pushes `leftover` onto the communicator stack;
    it is popped again when the iterator is finalised, which is what makes
    loops that exit early (e.g. via 'break') clean up after themselves.

    Arguments:
        - values: an indexable sequence with a length
        - cpus: communicator whose rank gives the first index taken and
          whose size gives the step between successive indices
        - leftover: communicator made current while the iterator exists
    """
    def __init__(self, values, cpus, leftover):
        self.values = values
        self.leftover = leftover
        self.cpus = cpus
        self.i = cpus.rank
        push(self.leftover)

    def __iter__(self):
        # Iterators must return themselves from __iter__ so they can be
        # passed to iter() or used directly in a 'for' statement.
        return self

    def next(self):
        """Return the next value belonging to this process.

        Waits on the `leftover` barrier first, then raises StopIteration
        once the index has run past the end of `values`.
        """
        self.leftover.barrier()
        if self.i >= len(self.values):
            raise StopIteration
        result = self.values[self.i]
        self.i += self.cpus.size
        return result

    # Same method under the name used by the newer iterator protocol.
    __next__ = next

    def __del__(self):
        # Wait on the `cpus` barrier, then make the previous communicator
        # current again.
        self.cpus.barrier()
        pop(self.leftover)
    

class localShareOf(object):
    """Iterable giving each process its own portion of a task list.

    Usage:  for task in localShareOf(tasks): ...

    The communicators are worked out once, at construction, from the
    number of values; every iteration then starts a fresh _ParallelIter.
    An empty `values` trips the 'jobs > 0' assertion in
    getSplitCommunicators.
    """
    def __init__(self, values):
        self.values = values
        communicators = getSplitCommunicators(len(values))
        self.cpus, self.leftover = communicators

    def __iter__(self):
        share = _ParallelIter(self.values, self.cpus, self.leftover)
        return share
    

class ParaRandom:
    """Wraps a random number generator so that parallel processes draw
    from different positions of one shared series.

    The wrapped generator is expected to yield the same series on every
    process (e.g. because it is seeded identically). Each instance first
    discards `rank` values, and every call to random() returns the next
    value and then discards a further `num_proc` values. Processes whose
    ranks differ (and are all below `num_proc`) therefore never return the
    same position of the series.

    Only a .random() method is needed for drawing, and this class provides
    one itself, so an instance can be wrapped again for nested parallelism.

    Warning: drawing directly from the generator after handing it to this
    class, without resetting the seed, can repeat values already returned
    here or disrupt the phasing.

    Arguments:
        - random_number: a random number generator with a .random() method
          that generates a value between 0.0 and 1.0 (its .seed() method
          is used by seed())
        - num_proc: the number of processes
        - rank: the rank of the current process
    """
    def __init__(self, random_number, num_proc = 1, rank = 0 ):
        self._rng = random_number
        self._num_proc = num_proc
        self._rank = rank
        self._skip_to_rank()

    def _skip_to_rank(self):
        """Discard `rank` values to reach this process's first position."""
        for _ in range(self._rank):
            self._rng.random()

    def random(self):
        """Return this process's next value from the shared series."""
        r = self._rng.random()
        # Discard num_proc values so the next call lands on this process's
        # next position (successive positions are num_proc + 1 apart).
        for _ in range(self._num_proc):
            self._rng.random()
        return r

    def seed(self, arg):
        """Re-seed the wrapped generator and restore this process's phase.

        Seeding restarts the underlying series, so the rank offset has to
        be applied again; otherwise identically seeded processes would all
        return the same values.
        """
        self._rng.seed(arg)
        self._skip_to_rank()
    

# True only on the process whose rank in the current communicator is 0;
# evaluated once, when this module is imported.
# NOTE(review): presumably used to restrict output to a single process -
# confirm against callers.
output_cpu = getCommunicator().rank == 0
