# -*- coding: utf-8 -*-

import json
import logging
import socket
import SocketServer
import threading
import time
import traceback

import constants
from base import CallBackMgr


def _reply_error(request, error, hide_detail=False):
    """
    Log ``error`` and build an error response with ``ret=RET_SYSTEM``.

    :param request: the request handler being served; its ``server`` must
        expose ``logger``, ``debug`` and ``make_rsp``.
    :param error: detailed error message; it is always written to the log.
    :param hide_detail: when true and the server is not in debug mode, the
        client receives ``constants.ERROR_INTERNAL`` instead of ``error`` so
        tracebacks / internal reprs are not leaked.
    :return: whatever ``server.make_rsp`` returns.
    """
    server = request.server
    server.logger.error(error)
    if hide_detail and not server.debug:
        error = constants.ERROR_INTERNAL
    return server.make_rsp(ret=constants.RET_SYSTEM, error=error)


def handle_input_data(request):
    """
    Decode, dispatch and post-process a single request.

    UDP and TCP handling is identical, so it lives here and both request
    handlers call it.

    Steps:
      1. decode ``request.raw_data`` as JSON into ``request.json_data``;
      2. require a dict with a truthy ``endpoint`` that
         ``server.get_route_view_func`` resolves to a view function;
      3. call every ``before_request_func_list`` hook, then the view, then
         every ``after_request_func_list`` hook (hook exceptions are logged
         and otherwise ignored);
      4. normalize the view's result.

    :param request: the request handler (TCP or UDP) being served.
    :return: ``None`` when the view returned ``None`` (meaning "send nothing
        back"); otherwise the response to send. A dict result is passed to
        ``server.make_rsp`` as keyword arguments. Every failure path returns
        an error response built via ``server.make_rsp``.
    """
    server = request.server
    server.logger.debug('raw_data: %s', request.raw_data)

    try:
        request.json_data = json.loads(request.raw_data)
    except Exception as e:
        return _reply_error(request, 'parse raw_data fail. e: %s, raw_data: %s' % (e, request.raw_data))

    if not isinstance(request.json_data, dict):
        return _reply_error(request, 'json_data is not dict. json_data: %s' % request.json_data)

    if not request.json_data.get('endpoint'):
        return _reply_error(request, 'endpoint is not in json_data. json_data: %s' % request.json_data)

    endpoint = request.json_data['endpoint']
    view_func = server.get_route_view_func(endpoint)
    if not view_func:
        return _reply_error(request, 'endpoint is not valid. endpoint: %s' % endpoint)

    # before_request_func_list: failures are logged, never abort the request
    for func in server.before_request_func_list:
        try:
            func(request)
        except Exception as e:
            server.logger.error('before_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                func.__name__, e, traceback.format_exc())

    view_func_exc = None
    try:
        view_func_result = view_func(request)
    except Exception as e:
        error = 'view_func raise exception. endpoint: %s, view_func: %s, e: %s, traceback: %s' % (
            endpoint, view_func.__name__, e, traceback.format_exc())
        view_func_exc = e
        view_func_result = _reply_error(request, error, hide_detail=True)

    # after_request_func_list: failures are logged, never abort the request
    for func in server.after_request_func_list:
        try:
            # Pass the exception in preference to the result: when the view
            # raised, it produced no real return value of its own.
            func(request, view_func_exc or view_func_result)
        except Exception as e:
            server.logger.error('after_request_func_list raise exception. func: %s, view_func_result: %s, e: %s, traceback: %s',
                                func.__name__, view_func_result, e, traceback.format_exc())

    if view_func_result is None:
        # None means the view deliberately does not want to reply
        return None

    if isinstance(view_func_result, dict):
        view_func_result = server.make_rsp(**view_func_result)

    if not isinstance(view_func_result, basestring):
        error = 'invalid result type. endpoint: %s, view_func: %s, result type: %s, result: %s' \
                % (endpoint, view_func, type(view_func_result), view_func_result)
        return _reply_error(request, error, hide_detail=True)

    return view_func_result


class GourdSyncTcp(SocketServer.ThreadingTCPServer, CallBackMgr):
    """
    Thread-per-connection TCP server that dispatches JSON requests.

    The socket is created by the base-class constructor but is only bound
    and put into listening state by run(). Connections are served by
    ``request_handler_class`` (GourdSyncTcpRequestHandler by default).
    """

    # Set SO_REUSEADDR before binding so a restarted server can re-bind the port
    allow_reuse_address = True
    # Appended to every response built by make_rsp()
    terminator = None
    logger = None
    # Read by handle_input_data: when true, detailed error text (including
    # tracebacks) is returned to clients instead of constants.ERROR_INTERNAL
    debug = False

    def __init__(self, log_name=None, backlog=None, terminator=None, request_handler_class=None):
        """
        :param log_name: logger name; defaults to constants.LOG_NAME.
        :param backlog: listen() backlog; defaults to constants.BACKLOG.
        :param terminator: string appended to every response; defaults to
            constants.TERMINATOR.
        :param request_handler_class: handler class; defaults to
            GourdSyncTcpRequestHandler.
        """
        # The real address is supplied later by run(), so don't bind here.
        SocketServer.ThreadingTCPServer.__init__(self,
                                                 (None, None),
                                                 request_handler_class or GourdSyncTcpRequestHandler,
                                                 bind_and_activate=False)
        CallBackMgr.__init__(self)

        self.terminator = terminator or constants.TERMINATOR
        self.logger = logging.getLogger(log_name or constants.LOG_NAME)
        # request_queue_size is the backlog server_activate() passes to
        # listen(); the stdlib default is 5
        self.request_queue_size = backlog or constants.BACKLOG

    def run(self, host, port):
        """
        Bind to (host, port), serve in a background thread and block forever.

        If binding or listening fails, the socket is closed before the error
        propagates. TCPServer.__init__ does this when bind_and_activate=True,
        but we bypass it. When the wait loop is interrupted (e.g. by
        KeyboardInterrupt), the serve loop is stopped and the socket closed
        before the exception propagates.
        """
        self.server_address = (host, port)
        try:
            self.server_bind()
            self.server_activate()
        except BaseException:
            self.server_close()
            raise

        server_thread = threading.Thread(target=self.serve_forever)
        # Daemon so the serving thread does not keep the process alive once
        # the main thread ends.
        # NOTE(review): per-connection handler threads follow
        # ThreadingMixIn.daemon_threads (False by default), so an in-flight
        # handler can still delay process exit.
        server_thread.daemon = True
        server_thread.start()

        # The serving thread is a daemon, so the main thread must stay alive
        # or the process would exit immediately.
        try:
            while True:
                time.sleep(1)
        finally:
            self.shutdown()
            self.server_close()

    def make_rsp(self, *args, **kwargs):
        """
        Build a response: ``dict(*args, **kwargs)`` as JSON, followed by
        ``self.terminator``.
        """
        return json.dumps(dict(*args, **kwargs)) + self.terminator


class GourdSyncTcpRequestHandler(SocketServer.StreamRequestHandler):
    """
    Serve exactly one request per TCP connection.

    handle() reads a single line, runs the create_request hooks, dispatches
    through handle_input_data and writes the response (if any) back.
    finish() runs the teardown_request hooks.
    """
    # The line read from the client, including its newline; '' on EOF
    raw_data = None
    # Decoded JSON payload, filled in by handle_input_data
    json_data = None

    def handle(self):
        self.raw_data = self.rfile.readline()

        # An empty read means the client closed the connection without
        # sending anything; in that case we skip straight to finish().
        if self.raw_data:
            for hook in self.server.create_request_func_list:
                try:
                    hook(self)
                except Exception as exc:
                    self.server.logger.error('create_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                             hook.__name__, exc, __import__('traceback').format_exc())

            response = handle_input_data(self)
            if response is not None:
                # wfile is unbuffered (StreamRequestHandler.wbufsize == 0),
                # so the write goes out without an explicit flush().
                self.wfile.write(response)
        # Once handle() returns, finish() runs and the server then shuts
        # down and closes the connection.

    def finish(self):
        """
        Run the teardown_request hooks, then let StreamRequestHandler flush
        and close the stream files. Hook errors are logged, never raised.
        """
        teardown_hooks = self.server.teardown_request_func_list
        for hook in teardown_hooks:
            try:
                hook(self)
            except Exception as exc:
                self.server.logger.error('teardown_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                         hook.__name__, exc, __import__('traceback').format_exc())

        SocketServer.StreamRequestHandler.finish(self)


class GourdSyncUdp(SocketServer.ThreadingUDPServer, CallBackMgr):
    """
    Thread-per-datagram UDP server that dispatches JSON requests.

    The socket is created by the base-class constructor but is only bound by
    run(). Datagrams are served by ``request_handler_class``
    (GourdSyncUdpRequestHandler by default).
    """
    logger = None
    # Read by handle_input_data: when true, detailed error text (including
    # tracebacks) is returned to clients instead of constants.ERROR_INTERNAL
    debug = False

    def __init__(self, log_name=None, request_handler_class=None):
        """
        :param log_name: logger name; defaults to constants.LOG_NAME.
        :param request_handler_class: handler class; defaults to
            GourdSyncUdpRequestHandler.
        """
        # The real address is supplied later by run(), so don't bind here.
        SocketServer.ThreadingUDPServer.__init__(self,
                                                 (None, None),
                                                 request_handler_class or GourdSyncUdpRequestHandler,
                                                 bind_and_activate=False)
        CallBackMgr.__init__(self)
        self.logger = logging.getLogger(log_name or constants.LOG_NAME)

    def run(self, host, port):
        """
        Bind to (host, port), serve in a background thread and block forever.

        If binding fails, the socket is closed before the error propagates.
        UDPServer.__init__ does this when bind_and_activate=True, but we
        bypass it. When the wait loop is interrupted (e.g. by
        KeyboardInterrupt), the serve loop is stopped and the socket closed
        before the exception propagates.
        """
        self.server_address = (host, port)
        try:
            self.server_bind()
            self.server_activate()
        except BaseException:
            self.server_close()
            raise

        server_thread = threading.Thread(target=self.serve_forever)
        # Daemon so the serving thread does not keep the process alive once
        # the main thread ends.
        # NOTE(review): per-request handler threads follow
        # ThreadingMixIn.daemon_threads (False by default), so an in-flight
        # handler can still delay process exit.
        server_thread.daemon = True
        server_thread.start()

        # The serving thread is a daemon, so the main thread must stay alive
        # or the process would exit immediately.
        try:
            while True:
                time.sleep(1)
        finally:
            self.shutdown()
            self.server_close()

    def make_rsp(self, *args, **kwargs):
        """
        Build a response: ``dict(*args, **kwargs)`` as JSON. No terminator is
        appended, unlike the TCP server.
        """
        return json.dumps(dict(*args, **kwargs))


class GourdSyncUdpRequestHandler(SocketServer.BaseRequestHandler):
    """
    Serve one UDP datagram.

    handle() runs the create_request hooks, dispatches through
    handle_input_data and sends the response (if any) back to the sender.
    finish() runs the teardown_request hooks.
    """
    # Server socket the datagram arrived on; used to send the reply
    socket = None
    # Datagram payload as received
    raw_data = None
    # Decoded JSON payload, filled in by handle_input_data
    json_data = None

    def handle(self):
        # For UDP servers, self.request is the (data, socket) pair
        payload, sock = self.request
        self.raw_data = payload
        self.socket = sock

        for hook in self.server.create_request_func_list:
            try:
                hook(self)
            except Exception as exc:
                self.server.logger.error('create_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                         hook.__name__, exc, __import__('traceback').format_exc())

        response = handle_input_data(self)
        if response is None:
            # The view chose not to reply
            return
        self.socket.sendto(response, self.client_address)

    def finish(self):
        """
        Run the teardown_request hooks, then delegate to
        BaseRequestHandler.finish. Hook errors are logged, never raised.
        """
        teardown_hooks = self.server.teardown_request_func_list
        for hook in teardown_hooks:
            try:
                hook(self)
            except Exception as exc:
                self.server.logger.error('teardown_request_func_list raise exception. func: %s, e: %s, traceback: %s',
                                         hook.__name__, exc, __import__('traceback').format_exc())

        SocketServer.BaseRequestHandler.finish(self)
