# -*- coding: utf-8 -*-

"""
tcp:
    传输的数据包格式:raw_data\n
udp:
    传输的数据包格式:raw_data
    因为并不通过末尾的符号做识别

解析之后的json:
{
    endpoint: 'get_name'        str
    **custom data
}

view_func 格式:
def func(request):
    pass

"""

import asynchat
import asyncore
import functools
import json
import logging
import socket
import time
import traceback

import constants


def _reject_request(request, error):
    """
    Log ``error``, answer it with a RET_SYSTEM response and close ``request``.

    Used for every input-validation failure in handle_input_data.
    """
    request.server.logger.error(error)
    request.push(request.server.make_rsp(ret=constants.RET_SYSTEM, error=error))
    request.close()


def handle_input_data(request):
    """
    Parse one raw packet, dispatch it to its view function and run the hooks.

    TCP and UDP handling is identical, so both request handlers call this.

    Steps:
      1. ``request._raw_data`` is decoded with json.loads into
         ``request._json_data``. The result must be a dict with a truthy
         ``endpoint`` registered on the server. Otherwise the error is logged,
         answered with a RET_SYSTEM response and the request is closed.
      2. The server's before_request hooks run with ``(request)``; their
         exceptions are logged and ignored.
      3. The view function runs with ``(request)``. If it raises, a RET_SYSTEM
         response is pushed (the full error text only when ``server.debug``
         is true, constants.ERROR_INTERNAL otherwise). The exception object
         becomes the result handed to the after_request hooks.
      4. The after_request hooks run with ``(request, view_func_result)``;
         their exceptions are logged and ignored.
      5. If the view function raised, the request is closed last.

    :param request: TcpRequestHandler/UdpRequestHandler-like object exposing
        ``server``, ``json_data``, ``push`` and ``close``, with ``_raw_data``
        already set.
    """
    request.server.logger.debug('raw_data: %s', request._raw_data)

    try:
        request._json_data = json.loads(request._raw_data)
    except Exception as e:
        _reject_request(request, 'parse raw_data fail. e: %s, raw_data: %s' % (e, request._raw_data))
        return

    if not isinstance(request.json_data, dict):
        _reject_request(request, 'json_data is not dict. json_data: %s' % request.json_data)
        return

    if not request.json_data.get('endpoint'):
        _reject_request(request, 'endpoint is not in json_data. json_data: %s' % request.json_data)
        return

    endpoint = request.json_data['endpoint']
    view_func = request.server.get_route_view_func(endpoint)
    if not view_func:
        _reject_request(request, 'endpoint is not valid. endpoint: %s' % endpoint)
        return

    # before_request_func_list
    for func in request.server.before_request_func_list:
        try:
            func(request)
        except Exception as e:
            request.server.logger.error('before_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                        func.__name__, e, traceback.format_exc())

    view_func_exc = False
    try:
        view_func_result = view_func(request)
    except Exception as e:
        # after_request hooks receive the exception as the "result".
        view_func_result = e
        error = 'view_func raise exception. endpoint: %s, view_func: %s, e: %s, traceback: %s' % (
            endpoint, view_func.__name__, e, traceback.format_exc())
        request.server.logger.error(error)
        # Use the same make_rsp envelope as every other error reply (and as
        # catch_exc): a bare string has no terminator, so a TCP client waiting
        # for one would never see the end of the response.
        request.push(request.server.make_rsp(
            ret=constants.RET_SYSTEM,
            error=error if request.server.debug else constants.ERROR_INTERNAL))
        view_func_exc = True

    # after_request_func_list
    for func in request.server.after_request_func_list:
        try:
            func(request, view_func_result)
        except Exception as e:
            request.server.logger.error('after_request_func_list raise exception. func: %s, view_func_result: %s, e: %s, traceback: %s',
                                        func.__name__, view_func_result, e, traceback.format_exc())

    # The socket is only closed after the after_request hooks have seen the error.
    if view_func_exc:
        request.close()


def catch_exc(func):
    """
    Decorator for request-handler methods: catch any exception, log it, reply
    with a RET_SYSTEM response and close the request.

    The decorated callable's first argument must be a request object exposing
    ``server`` (with ``logger``, ``debug`` and ``make_rsp``), ``push`` and
    ``close``, such as TcpRequestHandler. A server object does not qualify,
    because it has no ``server`` attribute.

    The full error text is sent to the client only when ``server.debug`` is
    true; otherwise constants.ERROR_INTERNAL is sent.

    :return: the wrapped call's return value, or None if an exception was
        handled.
    """
    @functools.wraps(func)
    def func_wrapper(request, *args, **kwargs):
        try:
            return func(request, *args, **kwargs)
        except Exception as e:
            error = 'exception occur. e: %s, tbk: %s' % (e, traceback.format_exc())
            request.server.logger.error(error)
            request.push(request.server.make_rsp(ret=constants.RET_SYSTEM,
                                                 error=error if request.server.debug else constants.ERROR_INTERNAL))
            request.close()

    return func_wrapper


class CallBackMgr(object):
    """
    Route table plus request-lifecycle hook registry, mixed into the servers.

    ``rule_map`` maps endpoint name -> view function. Each hook decorator
    appends the function to its list and returns it unchanged, so it can be
    used as ``@mgr.before_request``.
    """

    rule_map = None
    create_request_func_list = None
    before_request_func_list = None
    after_request_func_list = None
    teardown_request_func_list = None

    def __init__(self):
        # Per-instance containers, so separate servers never share routes/hooks.
        self.rule_map = dict()
        self.create_request_func_list = []
        self.before_request_func_list = []
        self.after_request_func_list = []
        self.teardown_request_func_list = []

    def add_route_rule(self, endpoint=None, view_func=None, **options):
        """
        Register ``view_func`` under ``endpoint`` (default: view_func.__name__).

        Registering the same function twice is a no-op. Registering a
        different function under an existing endpoint raises Exception.
        Extra ``options`` are accepted and ignored.
        """
        assert view_func is not None, 'expected view_func to be provided.'

        endpoint = endpoint or view_func.__name__
        if endpoint in self.rule_map and view_func != self.rule_map[endpoint]:
            raise Exception('repeat view_func for endpoint: %(endpoint)s, old_view_func:%(old_view_func)s, new_view_func: %(new_view_func)s' % dict(
                endpoint=endpoint,
                old_view_func=self.rule_map[endpoint].__name__,
                new_view_func=view_func.__name__,
            ))

        self.rule_map[endpoint] = view_func

    def route(self, **options):
        """
        Decorator form of add_route_rule: ``@mgr.route(endpoint='name')``.
        """
        # Read the endpoint once, when route() is called, rather than inside
        # the decorator. Otherwise, reusing the returned decorator would find
        # it already popped and silently fall back to the function name.
        endpoint = options.pop('endpoint', None)

        def decorator(f):
            self.add_route_rule(endpoint, f, **options)
            return f
        return decorator

    def get_route_view_func(self, endpoint):
        """
        Return the view function for ``endpoint``, or None if unregistered.
        """
        return self.rule_map.get(endpoint)

    def create_request(self, f):
        """
        Register ``f(request)``, called once a request object has been created.
        """
        self.create_request_func_list.append(f)
        return f

    def before_request(self, f):
        """
        Register ``f(request)``, called after the request parsed into valid JSON.
        """
        self.before_request_func_list.append(f)
        return f

    def after_request(self, f):
        """
        Register ``f(request, view_func_result)``, called after the route's view_func ran.
        """
        self.after_request_func_list.append(f)
        return f

    def teardown_request(self, f):
        """
        Register ``f(request)``, called after the request is closed.
        """
        self.teardown_request_func_list.append(f)
        return f


class GourdTcp(asyncore.dispatcher, CallBackMgr):
    """
    TCP server: listens on (host, port) and wraps every accepted connection in
    ``request_handler_class`` (TcpRequestHandler by default).

    Messages are delimited by ``terminator`` in both directions: the request
    handler is given it to split incoming data, and make_rsp appends it to
    every reply.
    """
    backlog = None
    terminator = None
    logger = None
    request_handler_class = None

    def __init__(self, log_name=None, terminator=None, backlog=None, request_handler_class=None):
        # Private socket map: run() loops only over this server's listening
        # socket and the connections it accepted.
        asyncore.dispatcher.__init__(self, map=dict())
        CallBackMgr.__init__(self)
        self.logger = logging.getLogger(log_name or constants.LOG_NAME)
        self.terminator = terminator or constants.TERMINATOR
        self.backlog = backlog or constants.BACKLOG
        self.request_handler_class = request_handler_class or TcpRequestHandler

    def handle_accept(self):
        """
        Accept one pending connection and create its request handler.
        """
        # dispatcher.accept() returns None instead of a (sock, address) pair
        # when the connection could not be taken (e.g. the client aborted it
        # first). Unpacking that directly would raise TypeError.
        pair = self.accept()
        if pair is None:
            return
        sock, address = pair
        self.logger.debug('accept. socket:%s, address:%s', sock, address)
        # Passing self._map puts the new connection in the same private map
        # that run() loops over.
        self.request_handler_class(self, address, self.terminator, sock, self._map)

    def run(self, host, port):
        """
        Bind to (host, port), start listening and block in asyncore.loop over
        the private socket map.
        """
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind((host, port))
        self.listen(self.backlog)

        asyncore.loop(map=self._map)

    def make_rsp(self, *args, **kwargs):
        """
        Serialize ``dict(*args, **kwargs)`` as JSON and append the terminator.
        """
        return json.dumps(dict(*args, **kwargs)) + self.terminator


class TcpRequestHandler(asynchat.async_chat):
    """
    One accepted TCP connection.

    Incoming bytes are buffered until the terminator is seen. Each complete
    message is then handed to handle_input_data. The server's create_request
    hooks run on construction; the teardown_request hooks run on close().
    """
    _server = None
    _address = None
    _raw_data = None
    _json_data = None
    # Set by the first close(), so the teardown hooks run exactly once. This
    # also stops a hook (or view function) that calls close() again from
    # recursing into close() without end.
    _teardown_done = False

    def __init__(self, server, address, terminator, sock, map=None):
        asynchat.async_chat.__init__(self, sock, map)
        self._server = server
        self._address = address
        # Messages are JSON text; the terminator marks the end of each one.
        self.set_terminator(terminator)

        # create_request_func_list: hook errors are logged, never raised.
        for func in self.server.create_request_func_list:
            try:
                func(self)
            except Exception as e:
                self.server.logger.error('create_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                         func.__name__, e, traceback.format_exc())

    def collect_incoming_data(self, data):
        # Buffer the chunk with async_chat's built-in helper.
        self._collect_incoming_data(data)

    @catch_exc
    def found_terminator(self):
        # _get_data() returns the buffered message and clears the buffer.
        self._raw_data = self._get_data()
        handle_input_data(self)

    def close(self):
        """
        Close the socket and run the teardown_request hooks (only once).

        Used both when this side closes the connection and when the peer does
        (async_chat's handle_close calls close()).
        """
        if self._teardown_done:
            return
        self._teardown_done = True

        asynchat.async_chat.close(self)
        self.server.logger.debug('socket closed')

        # teardown_request_func_list: hook errors are logged, never raised.
        for func in self.server.teardown_request_func_list:
            try:
                func(self)
            except Exception as e:
                self.server.logger.error('teardown_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                         func.__name__, e, traceback.format_exc())

    @property
    def server(self):
        """
        The server (GourdTcp) that accepted this connection.
        """
        return self._server

    @property
    def address(self):
        """
        Peer address as returned by accept().
        """
        return self._address

    @property
    def raw_data(self):
        """
        The last complete message received, without the terminator.
        """
        return self._raw_data

    @property
    def json_data(self):
        """
        The last message parsed as JSON, or None if parsing has not succeeded.
        """
        return self._json_data


class GourdUdp(asyncore.dispatcher, CallBackMgr):
    """
    UDP server: every datagram is one request, processed synchronously by
    ``request_handler_class`` (UdpRequestHandler by default).

    Replies go out directly with sendto, so the dispatcher never asks for
    write events.
    """
    logger = None
    request_handler_class = None

    def __init__(self, log_name=None, request_handler_class=None):
        # Private socket map: run() loops only over this server's socket.
        asyncore.dispatcher.__init__(self, map=dict())
        CallBackMgr.__init__(self)
        self.logger = logging.getLogger(log_name or constants.LOG_NAME)
        self.request_handler_class = request_handler_class or UdpRequestHandler

    def handle_read(self):
        """
        Receive one datagram and process it as a request.

        Errors are handled here instead of with catch_exc. catch_exc expects a
        request object (``request.server``/``push``/``close``), but ``self`` is
        the server, so its handler would itself fail. The resulting escaped
        exception reaches asyncore's default handle_error, which closes this
        dispatcher and takes the whole server down for one bad datagram.
        """
        request = None
        try:
            raw_data, address = self.socket.recvfrom(constants.PACKAGE_SIZE)
            request = self.request_handler_class(self, address, raw_data)
            request.process()
        except Exception as e:
            error = 'exception occur. e: %s, tbk: %s' % (e, traceback.format_exc())
            self.logger.error(error)
            # Reply only if we know who sent the datagram, using the same
            # response envelope that catch_exc uses for TCP.
            if request is not None:
                request.push(self.make_rsp(ret=constants.RET_SYSTEM,
                                           error=error if self.debug else constants.ERROR_INTERNAL))
                request.close()

    def writable(self):
        """
        Always False: replies are sent directly with sendto, which saves many
        write-event iterations of the loop.
        """
        return False

    def run(self, host, port):
        """
        Bind a UDP socket to (host, port) and block in asyncore.loop over the
        private socket map.
        """
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.bind((host, port))

        asyncore.loop(map=self._map)

    def make_rsp(self, *args, **kwargs):
        """
        Serialize ``dict(*args, **kwargs)`` as JSON (no terminator for UDP).
        """
        return json.dumps(dict(*args, **kwargs))


class UdpRequestHandler(object):
    """
    One UDP datagram handled as a request.

    Mirrors the TcpRequestHandler interface (server, address, raw_data,
    json_data, push, close), so handle_input_data and view functions work
    with either transport.
    """
    _server = None
    _address = None
    _raw_data = None
    _json_data = None

    def __init__(self, server, address, raw_data):
        super(UdpRequestHandler, self).__init__()
        self._server = server
        self._address = address
        self._raw_data = raw_data

    def process(self):
        """
        Run the create_request hooks, handle the datagram, then run the
        teardown_request hooks.

        The teardown hooks run in a ``finally`` block, so they still fire when
        handle_input_data raises. The exception then propagates to the caller
        as before.
        """
        # create_request_func_list: hook errors are logged, never raised.
        for func in self.server.create_request_func_list:
            try:
                func(self)
            except Exception as e:
                self.server.logger.error('create_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                         func.__name__, e, traceback.format_exc())
        try:
            handle_input_data(self)
        finally:
            # teardown_request_func_list: hook errors are logged, never raised.
            for func in self.server.teardown_request_func_list:
                try:
                    func(self)
                except Exception as e:
                    self.server.logger.error('teardown_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                             func.__name__, e, traceback.format_exc())

    def push(self, data):
        """
        Send ``data`` back to the sender with sendto (same name as the TCP
        handler's push).

        Errors are logged, not raised.

        :return: the number of bytes sent, or None on error.
        """
        try:
            return self.socket.sendto(data, self._address)
        except socket.error as e:
            self.server.logger.error('socket.error. e:%s', e)
        except Exception as e:
            self.server.logger.error('exception occur. e: %s, tbk: %s', e, traceback.format_exc())

    def close(self):
        # No-op: a datagram has no connection to close. This exists only for
        # interface parity with TcpRequestHandler.
        pass

    @property
    def socket(self):
        # The server's UDP socket, used by push. Interface-compatibility shim;
        # callers outside this class should not touch it.
        return self.server.socket

    @property
    def server(self):
        """
        The server (GourdUdp) that received this datagram.
        """
        return self._server

    @property
    def address(self):
        """
        Sender address as returned by recvfrom.
        """
        return self._address

    @property
    def raw_data(self):
        """
        The datagram payload as received.
        """
        return self._raw_data

    @property
    def json_data(self):
        """
        The payload parsed as JSON, or None if parsing has not succeeded.
        """
        return self._json_data
