# -*- coding: utf-8 -*-

import socket
import struct
import logging
from collections import defaultdict
from hash_ring import HashRing

import gsocketpool

from serializer import MsgpackSerializer

# Every binary-protocol packet begins with a fixed-size header packed in
# network byte order ('!', hence no padding).  Fields, in order: magic,
# opcode, key length, extras length, data type, reserved (0 in requests) /
# status (responses), total body length, opaque, CAS.
STRUCT_HEADER = '!BBHBBHLLQ'
HEADER_SIZE = struct.calcsize(STRUCT_HEADER)  # 24 bytes

# Request layouts: the header followed by the body.  The '%ds' placeholders
# are filled in with the key / value lengths right before packing.
STRUCT_GET = STRUCT_HEADER + '%ds'          # body: key
STRUCT_SET = STRUCT_HEADER + 'LL%ds%ds'     # body: flags, expiration, key, value
STRUCT_FLUSH = STRUCT_HEADER + 'L'          # body: expiration

# Magic byte marking the direction of a packet.
REQUEST_MAGIC, RESPONSE_MAGIC = 0x80, 0x81

# Response status codes.
STATUS_SUCCESS = 0x00

# Opcodes.  The "Q" variants are the quiet forms used to pipeline batches
# in get_multi / set_multi.
COMMAND_GET = 0x00
COMMAND_SET = 0x01
COMMAND_FLUSH = 0x08
COMMAND_GETK = 0x0C
COMMAND_GETKQ = 0x0D
COMMAND_SETQ = 0x11
COMMAND_FLUSHQ = 0x18


class MemcachedKeyError(Exception):
    """Common base class for the errors raised when a cache key is rejected."""


class MemcachedKeyLengthError(MemcachedKeyError):
    """Raised when a cache key is longer than the configured maximum length."""


class MemcachedKeyTypeError(MemcachedKeyError):
    """Raised when a cache key is not a ``str`` instance."""


class MemcachedValueLengthError(Exception):
    """Raised when a serialized value exceeds the configured maximum length."""


class MemcacheConnection(object):
    """Memcache connection.

    Speaks the memcached binary protocol to one or more servers; every key is
    routed to a single server through a ``HashRing`` built from ``hosts``.

    Usage:
        >>> from gmemcache import MemcacheConnection
        >>> client = MemcacheConnection(['127.0.0.1:11211'])
        >>> client.is_connected()
        True
        >>> client.set_multi({'key1': 'value1', 'key2': 'value2'})
        True
        >>> client.get_multi(['key1', 'key2', 'key3'])
        {'key1': u'value1', 'key2': u'value2'}
        >>> client.close()

    :param list hosts: Hostnames, each in ``'host:port'`` form.
    :param int timeout: (optional) Socket timeout in seconds.
    :param bool lazy: (optional) If set to True, the socket connection is not
        established until you specifically call :func:`open() <gmemcache.MemcacheConnection.open>`.
    :param serializer: (optional) The class used to serialize the value to be cached.
        :class:`MsgpackSerializer <gmemcache.MsgpackSerializer>` is used by default.
        Pass ``None`` to store and return raw ``str`` values unchanged.
    :param int max_key_length: (optional) The maximum length of the cache key.
    :param int max_value_length: (optional) The maximum length of the serialized cache value.
    """

    def __init__(self, hosts, timeout=5, lazy=False,
                 serializer=MsgpackSerializer,
                 max_key_length=250, max_value_length=1000**2):
        self._hosts = hosts
        self._timeout = timeout
        # serializer=None is supported: _serialize/_deserialize then pass
        # values through untouched (calling None() here used to crash).
        self._serializer = serializer() if serializer is not None else None
        self._max_key_length = max_key_length
        self._max_value_length = max_value_length

        self._sockets = None  # dict host -> socket while open, None otherwise
        self._ring = HashRing(self._hosts)

        if not lazy:
            self.open()

    def open(self):
        """Opens a connection to every host.

        If any host cannot be reached, the sockets opened so far by this call
        are closed again and the error is re-raised, so the client is never
        left half-connected.
        """

        sockets = {}
        try:
            for host in self._hosts:
                sockets[host] = self._connect(host)
        except BaseException:
            self._close_sockets(sockets)
            raise

        self._sockets = sockets

    def close(self):
        """Closes the connection to every host. Safe to call when not open."""

        if self._sockets:
            self._close_sockets(self._sockets)

        self._sockets = None

    @staticmethod
    def _close_sockets(sockets):
        # Close each socket independently so one failure cannot leak the rest.
        for sock in sockets.values():
            try:
                sock.close()
            except Exception:
                logging.exception('Failed to close the memcache connection')

    def _connect(self, host):
        """Opens a TCP socket to ``host`` (``'host:port'``) with the configured timeout."""

        (server, port) = host.split(':')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(self._timeout)
            sock.connect((server, int(port)))
        except BaseException:
            # Don't leak the file descriptor when the connection attempt fails.
            sock.close()
            raise

        return sock

    def is_connected(self):
        """Returns whether the connection has already been established.

        :rtype: bool
        """

        return bool(self._sockets)

    def reconnect(self, host):
        """Reconnects the existing connection to the specified host.

        Failures are logged rather than raised (best effort).

        :param str host: The target host
        """

        try:
            self._sockets[host].close()
        except Exception:
            logging.exception('Failed to close the memcache connection: %s', host)

        try:
            self._sockets[host] = self._connect(host)
        except Exception:
            logging.exception('Failed to connect to the memcached: %s', host)

    def get(self, key):
        """Retrieves a record with the specified key.

        :param str key: Cache key.
        :returns: The deserialized value, or None if the server answers with
            any non-success status (e.g. the key does not exist).
        :raises MemcachedKeyError: If the key is not a str or is too long.
        """

        self._assert_open()
        self._validate_key(key)

        host = self._ring.get_node(key)
        sock = self._sockets[host]
        try:
            sock.sendall(self._pack_get(COMMAND_GET, key))
            resp = self._get_response(sock)
        except BaseException:
            # Whatever interrupted the exchange, the stream may be mid-message:
            # reset it before propagating.
            self.reconnect(host)
            raise

        # Check the status before touching the body: error replies carry no
        # extras and may have an empty body.
        if resp['status'] != STATUS_SUCCESS:
            return None

        (_, value) = self._split_body(resp)
        return self._deserialize(value)

    def get_multi(self, keys):
        """Retrieves multiple records with the specified keys.

        Keys whose host cannot be reached, or that miss, are omitted from the
        result.

        :param list keys: Cache keys.
        :rtype: dict
        """

        self._assert_open()

        if not keys:
            return {}

        divided_keys = defaultdict(list)
        for key in keys:
            self._validate_key(key)
            divided_keys[self._ring.get_node(key)].append(key)

        msgs = {}
        for (host, host_keys) in divided_keys.items():
            # Pipeline the batch: every key but the last uses the quiet GETKQ;
            # the final GETK always gets a reply, marking the end of the batch.
            packets = [self._pack_get(COMMAND_GETKQ, key) for key in host_keys[:-1]]
            packets.append(self._pack_get(COMMAND_GETK, host_keys[-1]))
            msgs[host] = b''.join(packets)

        (_, replies) = self._exchange_batches(msgs, COMMAND_GETK)

        ret = {}
        for responses in replies.values():
            for resp in responses:
                if resp['status'] == STATUS_SUCCESS:
                    (key, value) = self._split_body(resp)
                    ret[key] = self._deserialize(value)

        return ret

    def set(self, key, value, lifetime=0):
        """Saves a record to the cache.

        :param str key: Cache key.
        :param value: Value to be cached.
        :param int lifetime: The number of seconds until the records will expire.
        :rtype: bool
        :raises MemcachedKeyError: If the key is not a str or is too long.
        :raises MemcachedValueLengthError: If the serialized value is too long.
        """

        self._assert_open()
        self._validate_key(key)

        host = self._ring.get_node(key)
        sock = self._sockets[host]

        packed_value = self._serialize(value)
        self._validate_value(packed_value)

        try:
            sock.sendall(self._pack_set(COMMAND_SET, key, packed_value, lifetime))
            resp = self._get_response(sock)
        except BaseException:
            # The stream may be mid-message: reset it before propagating.
            self.reconnect(host)
            raise

        return resp['status'] == STATUS_SUCCESS

    def set_multi(self, data, lifetime=0):
        """Saves multiple records to the cache.

        :param dict data: Records to be cached.
        :param int lifetime: The number of seconds until the records will expire.
        :rtype: bool
        :returns: True only if every record was sent and acknowledged as
            stored; False if any host failed or any record was rejected.
        """

        self._assert_open()

        if not data:
            return True

        divided_data = defaultdict(dict)
        for (key, value) in data.items():
            self._validate_key(key)
            divided_data[self._ring.get_node(key)][key] = value

        msgs = {}
        for (host, host_data) in divided_data.items():
            items = list(host_data.items())
            last = len(items) - 1
            packets = []
            for (n, (key, value)) in enumerate(items):
                packed_value = self._serialize(value)
                self._validate_value(packed_value)
                # Quiet SETQ for all but the last record; the final SET always
                # gets a reply, marking the end of the batch.
                opcode = COMMAND_SET if n == last else COMMAND_SETQ
                packets.append(self._pack_set(opcode, key, packed_value, lifetime))
            msgs[host] = b''.join(packets)

        (failed_hosts, replies) = self._exchange_batches(msgs, COMMAND_SET)

        if failed_hosts:
            # A batch that could not be sent or fully read back may not have
            # been stored, so this must not be reported as success.
            return False

        return all(resp['status'] == STATUS_SUCCESS
                   for responses in replies.values()
                   for resp in responses)

    def flush_all(self):
        """Disable all stored items on every host.

        :rtype: bool
        :returns: True only if every host acknowledged the flush.
        """

        self._assert_open()

        # The 4-byte expiration extra (0: no delay) must be declared in both
        # extlen and bodylen; otherwise the server would treat those bytes as
        # the start of the next request.
        request = struct.pack(STRUCT_FLUSH,
                              REQUEST_MAGIC,
                              COMMAND_FLUSH,
                              0, 4, 0, 0, 4, 0, 0,
                              0)

        retval = True
        for host in self._hosts:
            sock = self._sockets[host]
            try:
                sock.sendall(request)
                resp = self._get_response(sock)
            except BaseException:
                self.reconnect(host)
                raise

            if resp['status'] != STATUS_SUCCESS:
                retval = False

        return retval

    def _assert_open(self):
        assert self._sockets is not None, 'The connection has not been created'

    @staticmethod
    def _pack_get(opcode, key):
        """Builds a GET-family (GET/GETK/GETKQ) request for ``key``."""

        return struct.pack(STRUCT_GET % (len(key),),
                           REQUEST_MAGIC,
                           opcode,
                           len(key), 0, 0, 0, len(key), 0, 0, key)

    @staticmethod
    def _pack_set(opcode, key, packed_value, lifetime):
        """Builds a SET/SETQ request; extras are 8 bytes (flags=0, expiration)."""

        return struct.pack(STRUCT_SET % (len(key), len(packed_value)),
                           REQUEST_MAGIC,
                           opcode,
                           len(key), 8, 0, 0, len(key) + len(packed_value) + 8, 0, 0,
                           0, lifetime, key, packed_value)

    @staticmethod
    def _split_body(resp):
        """Splits a response body into ``(key, value)``, skipping the extras."""

        content = resp['content'] or b''
        key_start = resp['extlen']
        value_start = key_start + resp['keylen']
        return (content[key_start:value_start], content[value_start:])

    def _exchange_batches(self, msgs, final_opcode):
        """Sends each host its pipelined batch, then reads its replies.

        :param dict msgs: Host -> concatenated request packets.
        :param int final_opcode: Opcode of the (non-quiet) last request of
            every batch; reading from a host stops once its reply arrives.
        :returns: ``(failed_hosts, replies)``. ``failed_hosts`` is the set of
            hosts whose send or receive failed (their connection has been
            re-established); ``replies`` maps every host to the list of
            responses read from it, possibly partial for a failed host.
        """

        failed_hosts = set()
        replies = dict((host, []) for host in msgs)
        # Hosts whose stream may still hold unsent requests or unread replies.
        unfinished = set(msgs)

        try:
            for (host, msg) in msgs.items():
                try:
                    self._sockets[host].sendall(msg)
                except Exception:
                    logging.exception('An error has occurred while sending a request to memcached')
                    failed_hosts.add(host)
                    unfinished.discard(host)
                    self.reconnect(host)

            for host in msgs:
                if host in failed_hosts:
                    continue

                sock = self._sockets[host]
                opcode = -1
                while opcode != final_opcode:
                    try:
                        resp = self._get_response(sock)
                    except Exception:
                        logging.exception('An error has occurred while receiving a response from memcached')
                        failed_hosts.add(host)
                        self.reconnect(host)
                        break

                    opcode = resp['opcode']
                    replies[host].append(resp)

                unfinished.discard(host)

        except BaseException:
            # Interrupted (e.g. KeyboardInterrupt, killed greenlet) with
            # requests in flight: reset every connection that may still hold
            # unread replies so later calls don't read stale responses.
            for host in unfinished:
                self.reconnect(host)
            raise

        return (failed_hosts, replies)

    def _serialize(self, value):
        if self._serializer:
            return self._serializer.serialize(value)
        else:
            return value

    def _deserialize(self, value):
        if self._serializer:
            return self._serializer.deserialize(value)
        else:
            return value

    def _get_response(self, sock):
        """Reads one binary-protocol response from ``sock``.

        :returns: dict with ``status``, ``opcode``, ``keylen``, ``extlen``,
            ``bodylen`` and ``content`` (the raw body, or None if empty).
        :raises IOError: If the connection closes early or the header lacks
            the response magic byte (the stream is out of sync).
        """

        header = self._read(HEADER_SIZE, sock)
        (magic, opcode, keylen, extlen, _datatype, status, bodylen, _opaque, _cas) = \
            struct.unpack(STRUCT_HEADER, header)

        # A real check rather than an assert, which python -O would strip.
        if magic != RESPONSE_MAGIC:
            raise IOError('Invalid response magic: 0x%02x' % (magic,))

        extra_content = self._read(bodylen, sock) if bodylen else None

        return dict(status=status,
                    opcode=opcode,
                    keylen=keylen,
                    extlen=extlen,
                    bodylen=bodylen,
                    content=extra_content)

    def _read(self, size, sock):
        """Reads exactly ``size`` bytes from ``sock``; IOError if it closes first."""

        chunks = []
        remaining = size
        while remaining > 0:
            data = sock.recv(remaining)
            if not data:
                raise IOError('Connection closed')
            chunks.append(data)
            remaining -= len(data)

        return b''.join(chunks)

    def _validate_key(self, key):
        if not isinstance(key, str):
            raise MemcachedKeyTypeError('Key must be str()')

        if len(key) > self._max_key_length:
            raise MemcachedKeyLengthError(
                'Key length must be less than %d' % self._max_key_length)

    def _validate_value(self, val):
        if len(val) > self._max_value_length:
            raise MemcachedValueLengthError(
                'Value length must be less than %d' % self._max_value_length
            )


class MemcachePoolConnection(MemcacheConnection, gsocketpool.Connection):
    """Memcache connection wrapper class for `gsocketpool <https://github.com/studio-ousia/gsocketpool>`_.

    Usage:
        >>> import gsocketpool.pool
        >>> from gmemcache import MemcachePoolConnection
        >>> client_pool = gsocketpool.pool.Pool(MemcachePoolConnection, dict(hosts=['127.0.0.1:11211']))
        >>> with client_pool.connection() as client:
        ...     client.set('key1', 'value1')
        ...     client.get('key1')
        ...
        True
        u'value1'

    The connection is always opened eagerly (``lazy=False``).

    :param list hosts: Hostnames.
    :param int timeout: (optional) Timeout.
    :param serializer: (optional) The class used to serialize the value to be cached.
        :class:`MsgpackSerializer <gmemcache.MsgpackSerializer>` is used by default.
    :param int max_key_length: (optional) The maximum length of the cache key.
    :param int max_value_length: (optional) The maximum length of the cache value.
    """

    def __init__(self, hosts, timeout=5, serializer=MsgpackSerializer,
                 max_key_length=250, max_value_length=1000**2):
        # Forward the length limits too; they used to be accepted here but
        # silently dropped, so the base-class defaults always applied.
        MemcacheConnection.__init__(self, hosts, timeout, lazy=False,
                                    serializer=serializer,
                                    max_key_length=max_key_length,
                                    max_value_length=max_value_length)
