from distprocessing.scifi import *
from distprocessing import monkey
monkey.patch_multiprocessing()

import itertools
from multiprocessing import (
    managers,
    pool,
    Process,
)
import Queue
import socket
import pickle

from distprocessing.utils import slot_count
from distprocessing.exceptions import ServerUnreachableException


# Public API of this module; the other classes are internal plumbing.
__all__ = ['Cluster', 'WorkerServer']


# Authentication key used by WorkerServer and Cluster when none is supplied.
# NOTE(review): this is a well-known constant. Anyone who can reach a worker's
# port can authenticate with it, and payloads are pickled (arbitrary code
# execution), so deployments should pass an explicit secret authkey.
DEFAULT_AUTHKEY = 'default_authkey'


class WorkerServer(managers.Server):
    """Manager server meant to run on each worker host.

    It serves the ``HostManager`` type registry (which is where the
    ``_RemoteProcess`` factory is registered) using the pickle serializer.
    When ``authkey`` is None, ``DEFAULT_AUTHKEY`` is used instead.
    """

    def __init__(self, address, authkey=None, *args, **kwargs):
        # NOTE(review): ``args`` are forwarded positionally right after
        # ``address``, the slot Server.__init__ uses for ``authkey``. Any
        # non-empty ``args`` would therefore collide with the keyword
        # arguments below, so extras should be passed by keyword.
        key = DEFAULT_AUTHKEY if authkey is None else authkey
        managers.Server.__init__(self, HostManager._registry, address,
                                 authkey=key, serializer='pickle',
                                 *args, **kwargs)


class HostManager(managers.SyncManager):
    """Client-side manager connected to one remote worker server.

    Normally obtained through :meth:`from_address`, which marks the manager
    as already started so it talks to the server at ``address`` instead of
    spawning a local one.
    """

    def __init__(self, address, authkey):
        managers.SyncManager.__init__(self, address, authkey)
        # Display name; from_address() replaces it with 'Host-<host>:<port>'.
        self._name = 'Host-unknown'

    def Process(self, group=None, target=None, name=None, args=(), kwargs=None):
        """Create a process object on the remote host.

        ``target``, ``args`` and ``kwargs`` are pickled here and handed to
        the registered ``_RemoteProcess`` factory. ``group`` is accepted only
        for signature compatibility with ``multiprocessing.Process`` and is
        ignored. ``kwargs`` of None means "no keyword arguments".

        When ``name`` is None, a name of the form
        '<host part of self._name>/Process-<id:id:...>' is built from the
        remote process identity.

        Returns the object produced by ``_RemoteProcess``, with ``name`` and
        ``exitcode`` (None) attributes set on it.
        """
        # A fresh dict per call: no shared mutable default, and None must not
        # be pickled, because the target is eventually called as
        # target(*args, **kwargs).
        if kwargs is None:
            kwargs = {}
        data = pickle.dumps((target, args, kwargs))
        proc = self._RemoteProcess(data)
        if name is None:
            host = self._name.split('Host-')[-1]
            identity = ':'.join(map(str, proc.get_identity()))
            # Format host and identity as values, so a '%' in the host name
            # cannot break the formatting.
            name = '%s/Process-%s' % (host, identity)
        proc.name = name
        proc.exitcode = None
        return proc

    @classmethod
    def from_address(cls, address, authkey):
        """Return a manager bound to an already-running server at ``address``.

        The manager state is set to STARTED without calling ``start()``, so
        no local server process is created. ``address`` must be a 2-tuple
        (presumably ``(host, port)``), because it is formatted into the name.
        """
        manager = cls(address, authkey)
        manager._state.value = managers.State.STARTED
        manager._name = 'Host-%s:%s' % manager.address
        return manager

    def __repr__(self):
        return '<Host(%s)>' % self._name


class RemoteProcess(Process):
    """Process whose target/args/kwargs arrive as one pickled tuple.

    The payload is unpickled only when the process bootstraps, not at
    construction time.
    """

    def __init__(self, data):
        Process.__init__(self)
        # Pickled (target, args, kwargs) tuple.
        self._data = data

    def _bootstrap(self):
        # SECURITY: pickle.loads runs arbitrary code. The payload comes over
        # the manager connection, so the only protection is the authkey.
        target, args, kwargs = pickle.loads(self._data)
        self._target = target
        self._args = args
        self._kwargs = kwargs
        return Process._bootstrap(self)

    def get_identity(self):
        """Return this process's ``_identity`` tuple."""
        return self._identity


class DistributedPool(pool.Pool):
    """``multiprocessing.pool.Pool`` whose workers are created by a Cluster.

    Worker processes come from ``cluster.Process``. The task and result
    queues are ``_SettableQueue`` objects created through the cluster's
    manager. A ``socket.error`` during pool construction is reported as
    ``ServerUnreachableException``.
    """

    def __init__(self, cluster, processes=None, initializer=None, initargs=()):
        self._cluster = cluster
        # Pool builds its workers through self.Process; route that to the cluster.
        self.Process = cluster.Process
        # Default: one worker per cluster slot.
        size = processes or len(cluster)
        try:
            pool.Pool.__init__(self, size, initializer, initargs)
        except socket.error:
            raise ServerUnreachableException('Make sure worker host is up.')

    def _setup_queues(self):
        """Use cluster-managed settable queues instead of Pool's own queues."""
        cluster = self._cluster
        self._inqueue = cluster._SettableQueue()
        self._outqueue = cluster._SettableQueue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        """Replace all pending tasks with ``size`` ``None`` entries.

        Pool calls this hook while terminating. None is the stdlib pool
        worker's exit sentinel.
        """
        sentinels = [None] * size
        inqueue.set_contents(sentinels)


class Cluster(managers.SyncManager):
    """A set of worker hosts presented as a flat list of process slots.

    :param clienthost: address for this manager's own (local) server.
    :param hostlist: worker server addresses; one ``Host`` is made per entry.
    :param authkey: key used to connect to the worker hosts; falls back to
        ``DEFAULT_AUTHKEY`` when falsy.
    """

    def __init__(self, clienthost, hostlist, authkey=None):
        managers.SyncManager.__init__(self, address=clienthost)
        self._hostlist = [Host(address) for address in hostlist]
        # Overwrites the key SyncManager.__init__ stored; start() also passes
        # it to every HostManager.
        self._authkey = authkey or DEFAULT_AUTHKEY

    def start(self):
        """Start the local manager server and connect to every worker host.

        Each host gets a ``HostManager`` plus a ``Process`` factory. Hosts
        without an explicit slot count get ``slot_count()`` slots (NOTE(review):
        presumably the *local* CPU count; confirm it matches the remote host).
        Slots are then handed out round-robin by :meth:`Process`.
        """
        managers.SyncManager.start(self)

        for host in self._hostlist:
            host.manager = HostManager.from_address(host.hostname, self._authkey)
            host.Process = host.manager.Process
            host.slots = host.slots or slot_count()

        self._slotlist = [Slot(host) for host in self._hostlist
                          for _ in range(host.slots)]
        self._slot_iterator = itertools.cycle(self._slotlist)
        # SyncManager.start() leaves a ``shutdown`` callable on the instance,
        # which shadows Cluster.shutdown. Keep it under another name and drop
        # the instance attribute so the class-level method is what callers get.
        self._base_shutdown = self.shutdown
        del self.shutdown

    def shutdown(self):
        """Shut down the local manager server (requires a prior start())."""
        self._base_shutdown()

    def Process(self, group=None, target=None, name=None, args=(), kwargs=None):
        """Create a process on the next slot, cycling round-robin.

        Arguments mirror ``multiprocessing.Process``. A ``kwargs`` of None
        becomes a fresh empty dict per call instead of one shared default.
        Requires :meth:`start` to have been called.
        """
        if kwargs is None:
            kwargs = {}
        slot = next(self._slot_iterator)
        return slot.Process(group=group, target=target, name=name,
                            args=args, kwargs=kwargs)

    def Pool(self, processes=None, initializer=None, initargs=()):
        """Return a ``DistributedPool`` whose workers run on this cluster."""
        return DistributedPool(self, processes, initializer, initargs)

    def __getitem__(self, i):
        return self._slotlist[i]

    def __len__(self):
        # Total number of slots across all hosts (valid after start()).
        return len(self._slotlist)

    def __iter__(self):
        return iter(self._slotlist)


class SettableQueue(Queue.Queue):
    """``Queue.Queue`` whose entire contents can be replaced in one step.

    Note: items loaded through :meth:`set_contents` are not counted in
    ``unfinished_tasks``, so ``task_done()``/``join()`` accounting does not
    see them.
    """

    def empty(self):
        """Return True if no items are queued (read without taking the lock)."""
        return not self.queue

    def full(self):
        """Return True if a bounded queue holds at least ``maxsize`` items.

        Uses ``>=`` rather than ``==`` because :meth:`set_contents` may load
        more than ``maxsize`` items.
        """
        return 0 < self.maxsize <= len(self.queue)

    def set_contents(self, contents):
        """Atomically replace every queued item with ``contents``.

        Wakes all threads waiting in ``get()``, since there may now be items,
        and in ``put()``, since a bounded queue may now have room.
        """
        # not_empty and not_full share the queue's single mutex, so holding
        # one condition's lock allows notifying the other.
        with self.not_empty:
            self.queue.clear()
            self.queue.extend(contents)
            self.not_empty.notify_all()
            self.not_full.notify_all()


class Slot(object):
    """One process slot on a host; ``Process`` is that host's factory."""

    def __init__(self, host):
        self.Process = host.Process
        self.host = host


class Host(object):
    """A worker server address and the number of process slots it offers.

    ``hostname`` is passed as the address to ``HostManager.from_address``
    by ``Cluster.start``. A ``slots`` of None lets ``Cluster.start`` choose
    a count. That method also attaches ``manager`` and ``Process`` attributes.
    """

    def __init__(self, hostname, slots=None):
        self.slots = slots
        self.hostname = hostname
        

# Worker servers serve HostManager's registry, so the remote-process factory
# lives there. The settable queue is created by the cluster's own manager.
HostManager.register(typeid='_RemoteProcess', callable=RemoteProcess)
Cluster.register(typeid='_SettableQueue', callable=SettableQueue)
