# -*- coding: utf-8 -*-

import time
import uuid

from gmemcache import MemcacheConnection
from gmemcache.gmemcache import MemcachedKeyLengthError, MemcachedKeyTypeError, MemcachedValueLengthError
from gmemcache.serializer import BaseSerializer

from nose.tools import *

# memcached instance every test in this module talks to.
MEMCACHED_SERVER = '127.0.0.1:11211'

# Shared connection, created by _setup_connection and torn down by
# _drop_connection for tests decorated with @with_setup.
_conn = None

# Limits passed explicitly to MemcacheConnection by the length tests.
# 250 is the key-length limit of the memcached text protocol.
_max_key_length = 250
# NOTE(review): presumably kept a little under 1 MiB to leave headroom below
# memcached's default item-size limit -- confirm before changing.
_max_value_length = 1000 * 1000


def _setup_connection():
    """Setup hook: open a fresh shared connection to the test server."""
    global _conn
    servers = [MEMCACHED_SERVER]
    _conn = MemcacheConnection(servers)


def _drop_connection():
    """Teardown hook: close the shared connection and clear the global.

    The global is cleared *before* ``close()`` runs, so it never keeps
    pointing at a half-closed connection if ``close()`` raises. A ``None``
    connection (already dropped) is tolerated instead of raising
    ``AttributeError``.
    """
    global _conn
    conn, _conn = _conn, None
    if conn is not None:
        conn.close()


def test_open_lazy():
    """A lazy connection stays disconnected until ``open()`` is called."""
    lazy_conn = MemcacheConnection([MEMCACHED_SERVER], lazy=True)
    connected_before_open = lazy_conn.is_connected()
    ok_(not connected_before_open)

    lazy_conn.open()
    connected_after_open = lazy_conn.is_connected()
    ok_(connected_after_open)

    lazy_conn.close()


def test_close():
    """After ``close()`` an opened connection reports itself disconnected."""
    client = MemcacheConnection([MEMCACHED_SERVER])
    client.open()
    ok_(client.is_connected())

    client.close()
    still_connected = client.is_connected()
    ok_(not still_connected)


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_get():
    """``get`` returns a stored value, and ``None`` for a key never stored."""
    key = uuid.uuid1().hex
    missing_key = uuid.uuid1().hex
    _conn.set(key, 'value')

    eq_('value', _conn.get(key))
    # A cache miss yields None rather than raising.
    eq_(None, _conn.get(missing_key))


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_get_multi():
    """``get_multi`` returns exactly the requested keys that exist."""
    key1, key2, key3, key4 = [uuid.uuid1().hex for _ in range(4)]
    expected = {key1: 'value1', key2: 'value2', key3: 'value3'}
    _conn.set_multi(dict(expected))

    eq_(expected, _conn.get_multi([key1, key2, key3]))

    # key4 was never written, so it must simply be absent from the result.
    eq_({key1: 'value1'}, _conn.get_multi([key1, key4]))


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_set():
    """``set`` reports success and the stored value can be read back."""
    key = uuid.uuid1().hex
    # set() returns a truthy value on success (as test_with_exceeded_key_length
    # also relies on); check it rather than ignoring the result.
    ok_(_conn.set(key, 'value'))
    eq_('value', _conn.get(key))


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_set_with_lifetime():
    """A value stored with ``lifetime=1`` is gone two seconds later."""
    key = uuid.uuid1().hex
    lifetime = 1

    _conn.set(key, 'value', lifetime=lifetime)
    eq_('value', _conn.get(key))

    # Wait one extra second past the lifetime as a safety margin.
    # NOTE(review): presumably to absorb memcached's coarse expiry clock.
    time.sleep(lifetime + 1)
    eq_(None, _conn.get(key))


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_set_multi():
    """``set_multi`` reports success and every pair can be read back."""
    key1 = uuid.uuid1().hex
    key2 = uuid.uuid1().hex
    key3 = uuid.uuid1().hex
    # set_multi() returns a truthy value on success (as
    # test_with_exceeded_key_length also relies on); check it here too.
    ok_(_conn.set_multi({key1: 'value1', key2: 'value2', key3: 'value3'}))
    eq_({key1: 'value1', key2: 'value2', key3: 'value3'},
        _conn.get_multi([key1, key2, key3]))


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_set_multi_with_lifetime():
    """``set_multi`` honours ``lifetime`` just like ``set`` does."""
    key = uuid.uuid1().hex
    lifetime = 1

    _conn.set_multi({key: 'value'}, lifetime=lifetime)
    eq_('value', _conn.get(key))

    # One extra second past the lifetime as a safety margin.
    time.sleep(lifetime + 1)
    eq_(None, _conn.get(key))


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_with_invalid_key_type():
    """Every read/write entry point rejects a unicode key."""
    bad_key = u'key'
    # (callable, positional args) pairs, checked in the original order.
    calls = [
        (_conn.get, (bad_key,)),
        (_conn.get_multi, ([bad_key],)),
        (_conn.set, (bad_key, 'value')),
        (_conn.set_multi, ({bad_key: 'value'},)),
    ]
    for method, args in calls:
        assert_raises(MemcachedKeyTypeError, method, *args)


def test_with_exceeded_key_length():
    """Keys up to ``max_key_length`` work; one character longer is rejected.

    Covers get/get_multi/set/set_multi on both sides of the boundary.
    """
    key1 = '*' * (_max_key_length - 1)
    key2 = '*' * _max_key_length
    key3 = '*' * (_max_key_length + 1)

    conn = MemcacheConnection([MEMCACHED_SERVER],
                              max_key_length=_max_key_length)
    try:
        ok_(conn.set(key1, 'value'))
        ok_(conn.set(key2, 'value'))
        ok_(conn.set_multi({key1: 'value', key2: 'value'}))

        assert_raises(MemcachedKeyLengthError, conn.set, key3, 'value')
        assert_raises(MemcachedKeyLengthError, conn.set_multi, {key3: 'value'})

        eq_('value', conn.get(key1))
        eq_('value', conn.get(key2))
        eq_({key1: 'value', key2: 'value'}, conn.get_multi([key1, key2]))
        assert_raises(MemcachedKeyLengthError, conn.get, key3)
        assert_raises(MemcachedKeyLengthError, conn.get_multi, [key3])
    finally:
        # Close even when an assertion fails so the connection is not leaked.
        conn.close()


def test_with_exceeded_value_length():
    """Values up to ``max_value_length`` are stored; one longer is rejected.

    A pass-through serializer is used so the configured limit applies to the
    raw string length rather than to a serialized form.
    """
    class NoopSerializer(BaseSerializer):
        def serialize(self, value):
            return value

        def deserialize(self, value):
            return value

    # Unique keys, like the other tests, so the ~1 MB values don't clobber
    # (or get clobbered by) fixed names such as 'key1' used elsewhere.
    key1 = uuid.uuid1().hex
    key2 = uuid.uuid1().hex
    key3 = uuid.uuid1().hex

    value1 = '*' * (_max_value_length - 1)
    value2 = '*' * _max_value_length
    value3 = '*' * (_max_value_length + 1)

    conn = MemcacheConnection([MEMCACHED_SERVER],
                              serializer=NoopSerializer,
                              max_value_length=_max_value_length)
    try:
        ok_(conn.set(key1, value1))
        ok_(conn.set(key2, value2))
        ok_(conn.set_multi({key1: value1, key2: value2}))

        assert_raises(MemcachedValueLengthError, conn.set, key3, value3)
        assert_raises(MemcachedValueLengthError, conn.set_multi, {key3: value3})
    finally:
        # Close even when an assertion fails so the connection is not leaked.
        conn.close()


@with_setup(setup=_setup_connection, teardown=_drop_connection)
def test_flush_all():
    """``flush_all`` succeeds and removes items that existed before it."""
    key1 = uuid.uuid1().hex
    key2 = uuid.uuid1().hex
    # No lifetime: with a short lifetime the items could expire on their own,
    # and the empty result below would pass without flush_all doing anything.
    _conn.set_multi({key1: 'value1', key2: 'value2'})
    # Prove the items are present first; otherwise a silently failed
    # set_multi would also make the final check pass vacuously.
    eq_({key1: 'value1', key2: 'value2'}, _conn.get_multi([key1, key2]))

    ok_(_conn.flush_all())
    eq_({}, _conn.get_multi([key1, key2]))
