"""
# -*- coding: utf-8 -*-
#====================================================================================================================
#
# Copyright (C) 2013/2014 Laurent Champagnac
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#====================================================================================================================
"""

# Logger
from _ssl import PROTOCOL_TLSv1
import logging
import os
from threading import RLock, Lock
import signal
import traceback

import gevent
from gevent.event import Event
from gevent.server import StreamServer
from gevent.timeout import Timeout
from springpython.context import DisposableObject
from pythonsol.AtomicInt import AtomicInt
from pythonsol.Meter.MeterManager import MeterManager
from pythonsol.SolBase import SolBase
from pythonsol.TcpBase.TcpSocketManager import TcpSocketManager
from pythonsol.TcpServer.TcpServerConfig import TcpServerConfig
from pythonsol.TcpServer.TcpServerStat import TcpServerStat

# Bootstrap the logging subsystem through SolBase first, then fetch this module's named logger.
SolBase.loggingInit()
logger = logging.getLogger(name="TcpServer")


class TcpServer(DisposableObject):
    """
    Fast tcp server, gevent (epoll) based.
    Option via TcpServerConfig to enable SSL.
    If SSL is enabled, the gevent StreamServer is created with do_handshake_on_connect=True, and each accepted
    client is then configured via setSslHandshakeAsynch (timeout from sslHandshakeTimeOutMs).
    NOTE(review): a previous revision of this docstring stated do_handshake_on_connect=False (handshake done
    AFTER accept) - confirm which behaviour is intended.
    If childProcessCount > 0, the listening socket is created once and childProcessCount sub-processes are
    forked; parent and children all accept on it.
    """

    def __init__(self, tcpServerConfig):
        """
        Constructor.
        :param tcpServerConfig: The configuration.
        :type tcpServerConfig: TcpServerConfig
        :return Nothing.
        Raises Exception if tcpServerConfig is None or not a TcpServerConfig.
        Starts the server immediately if tcpServerConfig.autoStart is True.
        """

        # Check
        if tcpServerConfig is None:
            logger.error("__init__ : tcpServerConfig is None")
            raise Exception("__init__ : tcpServerConfig is None")
        elif not isinstance(tcpServerConfig, TcpServerConfig):
            logger.error(
                "__init__ : tcpServerConfig is not a TcpServerConfig, class=%s",
                SolBase.getClassName(tcpServerConfig))
            raise Exception("__init__ : tcpServerConfig is not a TcpServerConfig")

        # Store
        self._tcpServerConfig = tcpServerConfig

        # Init =>

        # _isStarted :
        # - True if server is started, False is server is stopped.
        # - Start : try to start, if ok set to True
        # - Stop : stop server, stop client, and set to False
        # - Means : During stop ongoing, will be True
        self._isStarted = False

        # _isRunning :
        # - True if server is running, False is server is no more running
        # - Start : Set to true, try to start, if failed, set to False
        # - Stop : Set to False, stop server, stop client
        # - Means : During stop ongoing, will be False
        self._isRunning = False

        # Gevent StreamServer
        self._server = None

        # Pids of the sub-processes forked by THIS process (empty in children and in non-forked mode)
        self._forkPidList = list()

        # Client management (re-entrant lock)
        self._clientConnectedAtomicInt = AtomicInt()
        self._clientConnectedHash = dict()
        self._clientConnectedHashLock = RLock()

        # Check the TcpServerStats
        if MeterManager.get(TcpServerStat) is None:
            logger.warning("__init__ : TcpServerStat not registered in MeterManager, adding it now")
            MeterManager.put(TcpServerStat())

        # Lock for start/stop
        self.__stopStartLock = Lock()

        # Control variables
        self._effectiveControlIntervalMs = 0

        # Control init
        self.__setEffectiveControlIntervalMs()

        # Auto start
        if self._tcpServerConfig.autoStart is True:
            logger.info("__init__ : Auto-starting ON, starting now")
            self.startServer()
        else:
            logger.info("__init__ : Auto-starting OFF")

    #=====================================================
    # START / STOP : HIGH LEVEL
    #=====================================================

    def startServer(self):
        """
        Start the server (serialized with stopServer through a lock).
        :return: True if the server has been started, False if it was already started.
        :rtype: bool
        Re-raises any exception from the low level start (in that case _isRunning is reset to False).
        """

        with self.__stopStartLock:
            logger.info("startServer : starting")
            try:
                if self._isStarted:
                    logger.warning("startServer : already started, doing nothing")
                    return False

                # Running
                self._isRunning = True

                # Low level start
                self._startServer()

                # Start
                self._isStarted = True

                # Done
                logger.info("startServer : started, %s", SolBase.getCurrentPidsAsString())

                # Exit
                return True
            except Exception as e:
                logger.error("startServer : exception, ex=%s", SolBase.exToStr(e))
                # Failed, not more running
                self._isRunning = False

                # Raise
                raise

    def destroy(self):
        """
        For spring python. Just call stopServer().
        :return: Nothing.
        """
        logger.info("destroy : Entering, calling self.stopServer(), %s", SolBase.getCurrentPidsAsString())
        self.stopServer()

    def stopServer(self):
        """
        Stop the server (serialized with startServer through a lock).
        The low level stop is bounded by tcpServerConfig.stopServerTimeoutMs (milliseconds).
        :return Nothing.
        Re-raises any exception other than a low level stop timeout.
        """

        with self.__stopStartLock:
            logger.info("stopServer : stopping, %s", SolBase.getCurrentPidsAsString())
            try:
                # No more running (read by _removeClient_stopBusiness to decide whether stopSynch is called)
                self._isRunning = False

                # Check
                if not self._isStarted:
                    logger.warning("stopServer : already stopped, doing nothing, %s",
                                   SolBase.getCurrentPidsAsString())
                    return

                # Low level stop
                # gevent.with_timeout expects SECONDS, the configuration value is in MILLISECONDS
                try:
                    gevent.with_timeout(self._tcpServerConfig.stopServerTimeoutMs / 1000.0, self._stopServer)
                except Timeout:
                    logger.warning("Timeout while calling low level _stopServer, some socket may be stucked")
                finally:
                    # _stopServer always resets self._server to None (in its finally clause).
                    # So whatever happened, we are no more started. Otherwise a later start would be refused
                    # and a later stop would crash on a None server.
                    self._isStarted = False

                # Done
                logger.info("stopServer : stopped, %s", SolBase.getCurrentPidsAsString())
            except Exception as e:
                logger.error("stopServer : exception, ex=%s, %s", SolBase.exToStr(e),
                             SolBase.getCurrentPidsAsString())
                raise

    #=====================================================
    # CONFIG ACCESS
    #=====================================================

    def getTcpServerConfig(self):
        """
        Return the configuration associated to the TcpServer.
        :return: A TcpServerConfig instance.
        """

        return self._tcpServerConfig

    #=====================================================
    # START / STOP : LOW LEVEL
    #=====================================================

    def _startServer(self):
        """
        Low level start: allocate the gevent StreamServer (clear or SSL) and start it.
        If childProcessCount > 0, the listening socket is initialized first. Then childProcessCount processes
        are forked, and all processes (parent and children) start accepting.
        :return Nothing.
        """

        # Allocate a server, and provide a connection callback
        logger.info("_startServer : listenAddr=%s", self._tcpServerConfig.listenAddr)
        logger.info("_startServer : listenPort=%s", self._tcpServerConfig.listenPort)
        logger.info("_startServer : sslEnable=%s", self._tcpServerConfig.sslEnable)
        logger.info("_startServer : sslKeyFile=%s", self._tcpServerConfig.sslKeyFile)
        logger.info("_startServer : sslCertificateFile=%s", self._tcpServerConfig.sslCertificateFile)

        logger.info("_startServer : childProcessCount=%s", self._tcpServerConfig.childProcessCount)

        logger.info("_startServer : sslHandshakeTimeOutMs=%s", self._tcpServerConfig.sslHandshakeTimeOutMs)

        logger.info("_startServer : onStopCallClientStopSynch=%s", self._tcpServerConfig.onStopCallClientStopSynch)

        logger.info("_startServer : socketAbsoluteTimeOutMs=%s", self._tcpServerConfig.socketAbsoluteTimeOutMs)
        logger.info("_startServer : socketRelativeTimeOutMs=%s", self._tcpServerConfig.socketRelativeTimeOutMs)
        logger.info("_startServer : socketMinCheckIntervalMs=%s", self._tcpServerConfig.socketMinCheckIntervalMs)

        logger.info("_startServer : _effectiveControlIntervalMs=%s", self._effectiveControlIntervalMs)

        logger.info("_startServer : stopClientTimeoutMs=%s", self._tcpServerConfig.stopClientTimeoutMs)
        logger.info("_startServer : stopServerTimeoutMs=%s", self._tcpServerConfig.stopServerTimeoutMs)

        logger.info("_startServer : clientFactory=%s", self._tcpServerConfig.clientFactory)

        listenTuple = (self._tcpServerConfig.listenAddr, self._tcpServerConfig.listenPort)

        if self._tcpServerConfig.sslEnable is False:
            # No SSL
            logger.info("_startServer : Starting in TCP/CLEAR mode")
            self._server = StreamServer(listenTuple, self._onConnection)
        else:
            # SSL ON
            logger.info("_startServer : Starting in TCP/SSL mode")
            self._server = StreamServer(
                listenTuple,
                self._onConnection,
                # SSL enabling
                keyfile=self._tcpServerConfig.sslKeyFile, certfile=self._tcpServerConfig.sslCertificateFile,
                # Handshake on connect (see NOTE(review) in the class docstring)
                do_handshake_on_connect=True,
                # TLS
                ssl_version=PROTOCOL_TLSv1,
            )

            # Accept loop delays (only tuned in SSL mode)
            self._server.min_delay = 0.0
            self._server.max_delay = 0.0

        # Startup
        if self._tcpServerConfig.childProcessCount <= 0:
            # Normal start-up
            logger.info("_startServer : Starting in NON-FORKED mode")
            self._server.start()
            return

        # Child process startup : prestart
        logger.info("_startServer : Pre-starting in FORKED mode, subprocess=%s",
                    self._tcpServerConfig.childProcessCount)
        # GEVENT_RC1 fix : pre_start => init_socket
        self._server.init_socket()

        # Let's rock
        logger.info("_startServer : Forking gevent")
        for idx in range(self._tcpServerConfig.childProcessCount):
            # Fork gevent hub
            # GEVENT_RC1 fix : hub.fork => os.fork
            forkPid = gevent.os.fork()

            logger.info("_startServer : Forking gevent hub done, idx=%s, forkPid=%s, %s", idx, forkPid,
                        SolBase.getCurrentPidsAsString())

            # Check it
            if forkPid == 0:
                # We are in a child => exit this loop
                SolBase.setMasterProcess(False)

                # The child inherited the pids of the siblings forked before it. It does not own them:
                # it cannot waitpid them and must not SIGTERM them on stop => forget them.
                self._forkPidList = list()
                break
            else:
                # Master on
                SolBase.setMasterProcess(True)

                # Store child pid
                logger.info("_startServer : Storing child _forkPid=%s", forkPid)
                self._forkPidList.append(int(forkPid))

        # Start accepting now (parent and all sub-processes)
        logger.info("_startServer : Accepting now, %s", SolBase.getCurrentPidsAsString())
        self._server.start_accepting()

    def _stopServer(self):
        """
        Low level stop: SIGTERM and reap the forked children (if any), stop the StreamServer, remove all clients.
        Whatever happens, self._server is reset to None and the child pid list is cleared.
        :return Nothing.
        """

        try:
            # If we have child, signal them now
            for pid in self._forkPidList:
                logger.info("_stopServer : Sending SIGTERM to pid=%s", pid)
                try:
                    os.kill(pid, signal.SIGTERM)
                except OSError as e:
                    # Child already gone : do not abort the whole stop for it
                    logger.warning("_stopServer : kill failed, pid=%s, ex=%s", pid, SolBase.exToStr(e))

            # Wait for exit
            for pid in self._forkPidList:
                logger.info("_stopServer : Waiting for pid=%s", pid)

                # Wait
                # NOTE(review): blocking os.waitpid. If gevent child watchers are active, they may already have
                # reaped the child (ECHILD) - gevent.os.waitpid may be preferable, confirm against gevent version.
                try:
                    waitPid, waitStatus = os.waitpid(pid, 0)
                except OSError as e:
                    logger.warning("_stopServer : waitpid failed, pid=%s, ex=%s", pid, SolBase.exToStr(e))
                    continue

                # Decode status : either killed by a signal, or exited with a code
                if os.WIFSIGNALED(waitStatus):
                    waitSignal = os.WTERMSIG(waitStatus)
                    waitCode = 0
                else:
                    waitSignal = 0
                    waitCode = os.WEXITSTATUS(waitStatus)

                # Info
                logger.info(
                    "_stopServer : Waiting for pid=%s ok, waitStatus=%s, waitSignal=%s, waitCode=%s, waitPid=%s", pid,
                    waitStatus, waitSignal, waitCode, waitPid)

            # Stop (timeout = 5 seconds)
            self._server.stop(5)

            # Clear client
            self._removeAllClient()
        finally:
            # Reset
            self._server = None

            # Children are signaled / reaped (or already gone) : forget them so a restart never targets stale pids
            self._forkPidList = list()

    #=====================================================
    # INTERNAL CALLBACKS
    #=====================================================

    def _onConnection(self, socket, address):
        """
        Callback called upon client connection.
        :param socket:  The client socket.
        :param address: The client remote address.
        :return Nothing.
        """
        logger.debug("_onConnection : address=%s %s", address, SolBase.getCurrentPidsAsString())

        # Register a new session
        # This will start the read/write loop on client.
        localClient = self._registerClient(socket, address)

        # Check
        if localClient is None:
            logger.error("_onConnection : _registerClient returned none")

    #=====================================================
    # CONTROL INTERVAL HELPER
    #=====================================================

    def __setEffectiveControlIntervalMs(self):
        """
        Set the effective control interval in ms, derived from the socket absolute / relative timeouts.
        :return: Nothing.
        """

        # Get values
        valAbs = self._tcpServerConfig.socketAbsoluteTimeOutMs
        valRel = self._tcpServerConfig.socketRelativeTimeOutMs

        # If absolute and relative both lower then zero : nothing to do, socket has no limit
        if valAbs <= 0 and valRel <= 0:
            self._effectiveControlIntervalMs = 0
            return

        # Here, one of those is greater than zero.
        # If both are greater than zero : we keep the minimum
        # Else, one of them is zero or lower : we keep the maximum (the positive one)
        if valAbs > 0 and valRel > 0:
            valSch = min(valAbs, valRel)
        else:
            valSch = max(valAbs, valRel)

        # To avoid too low values (and too high check frequency), we floor it with the minimal check interval
        self._effectiveControlIntervalMs = max(valSch, self._tcpServerConfig.socketMinCheckIntervalMs)

    def getEffectiveControlIntervalMs(self):
        """
        Return the effective control interval in ms to apply to socket
        :return: An integer (millis)
        """

        return self._effectiveControlIntervalMs

    #=====================================================
    # CLIENT MANAGEMENT : REGISTER
    #=====================================================

    def _registerClient(self, socket, address):
        """
        Register a new client : allocate it via the client factory, hash it, update stats and start it.
        :param socket:  The client socket.
        :param address: The client remote address.
        :return Return a TcpServerClientContext upon success, None upon failure (socket is then closed).
        """

        try:
            logger.debug("_registerClient : entering")

            # Must be started
            if not self._isStarted:
                logger.debug("_registerClient : not started, cannot process")
                return None

            # Allocate a new client context
            logger.debug("_registerClient : allocating newClient using factory")
            newClient = self._tcpServerConfig.clientFactory.getNewClientContext(
                self, self._clientConnectedAtomicInt.increment(), socket, address)

            # Hash id
            logger.debug("_registerClient : hashing newClient")
            with self._clientConnectedHashLock:
                logger.debug("_registerClient : hashing newClient (in lock)")
                self._clientConnectedHash[newClient.getClientId()] = newClient

            # Statistics
            logger.debug("_registerClient : populating statistics")
            MeterManager.get(TcpServerStat).clientConnected.increment()
            MeterManager.get(TcpServerStat).clientRegisterCount.increment()

            # Enable SSL if required and set handshake timeout
            if self._tcpServerConfig.sslEnable is True:
                newClient.setSslHandshakeAsynch(True, self._tcpServerConfig.sslHandshakeTimeOutMs,
                                                self._tcpServerConfig.debugWaitInSslMs)

            # Start the client
            logger.debug("_registerClient : starting client")
            newClient.start()

            # Log
            logger.debug("_registerClient : client started and hashed, id=%s, addr=%s", newClient.getClientId(),
                         newClient.getClientAddr())
            return newClient
        except Exception as e:
            # Error
            logger.warning("_registerClient : exToStr, ex=%s", SolBase.exToStr(e))

            # Statistics
            MeterManager.get(TcpServerStat).clientRegisterException.increment()

            # Close the socket in this case (should not cover mantis 1173)
            TcpSocketManager.safeCloseSocket(socket)
            return None

    #=====================================================
    # CLIENT MANAGEMENT : REMOVE (ASYNCH and SYNCH) and REMOVE ALL
    #=====================================================

    def _removeClientAsynch(self, clientId):
        """
        Remove a client in a dedicated greenlet, and wait for it to complete.
        :param clientId: The client id.
        :return Nothing.
        """

        # Spawn
        logger.debug("_removeClientAsynch : entering, clientId=%s", clientId)

        # Signal event (mantis 1280)
        evt = Event()

        # Spawn
        gevent.spawn(self._removeClient, clientId, evt)

        # Switch (mantis 1280)
        SolBase.sleep(0)

        # And wait
        # Note : remove this wait do not impact unittest...
        logger.debug("_removeClientAsynch : waiting, clientId=%s", clientId)
        evt.wait()

        # Over
        logger.debug("_removeClientAsynch : done, clientId=%s", clientId)

    def _removeClient_stopInternal(self, oldClient, evt):
        """
        Internal stop of a client (r/w loops and socket) through stopSynchInternal. Exceptions are logged.
        :param oldClient: old client
        :param evt: gevent.Event, always set on exit
        :type evt: gevent.Event
        :type oldClient: TcpServerClientContext
        """

        try:
            # Get
            clientId = oldClient.getClientId()

            # Stop the client r/w loops and close the sock
            logger.debug("_removeClient_stopInternal call, clientId=%s", clientId)
            oldClient.stopSynchInternal()
        except Exception as e:
            logger.warning("Ex=%s", SolBase.exToStr(e))
        finally:
            evt.set()

    def _removeClient_stopBusiness(self, oldClient, evt):
        """
        Business stop of a client through stopSynch. Exceptions are logged.
        stopSynch is NOT called if the server is stopping AND onStopCallClientStopSynch is not True.
        :param oldClient: old client
        :param evt: gevent.Event, always set on exit
        :type evt: gevent.Event
        :type oldClient: TcpServerClientContext
        """

        try:
            clientId = oldClient.getClientId()

            logger.debug("_removeClient_stopBusiness call, clientId=%s", clientId)

            if self._isRunning:
                # Running, call
                logger.debug("stopSynch call (_isRunning==%s), clientId=%s", self._isRunning, clientId)
                oldClient.stopSynch()
            elif self._tcpServerConfig.onStopCallClientStopSynch == True:
                # Not running + call ON : call
                logger.debug("stopSynch call (_isRunning==%s + onStopCallClientStopSynch==%s), clientId=%s",
                             self._isRunning, self._tcpServerConfig.onStopCallClientStopSynch, clientId)
                oldClient.stopSynch()
            else:
                # No call
                logger.debug("stopSynch NOT CALLED (_isRunning==%s + onStopCallClientStopSynch==%s), clientId=%s",
                             self._isRunning, self._tcpServerConfig.onStopCallClientStopSynch, clientId)
        except Exception as e:
            logger.warning("Ex=%s", SolBase.exToStr(e))
        finally:
            evt.set()

    def _removeClient_runStopStep(self, stepFunc, stepName, timeoutStatName, oldClient, clientId):
        """
        Run one client stop step in a dedicated greenlet, waiting at most stopClientTimeoutMs.
        On timeout : log the greenlet stack, kill the greenlet and increment the timeout statistic.
        Exceptions are logged, never raised.
        :param stepFunc: Callable(oldClient, evt) which sets evt when done.
        :param stepName: Step name, used in logs.
        :param timeoutStatName: Name of the TcpServerStat counter to increment on timeout.
        :param oldClient: The client being removed.
        :param clientId: The client id (logs).
        :return Nothing.
        """

        try:
            localEvt = Event()
            g = gevent.spawn(stepFunc, oldClient, localEvt)
            SolBase.sleep(0)

            localEvt.wait(self._tcpServerConfig.stopClientTimeoutMs / 1000.0)

            if not localEvt.is_set():
                # Flush out warning (gr_frame is None if the greenlet finished meanwhile)
                frame = g.gr_frame
                frameStr = ''.join(traceback.format_stack(frame)) if frame is not None else "n/a"
                s = "Greenlet dump, g={0}, frame={1}".format(g, frameStr)

                # Cleanup : single line, collapsed spaces
                s = s.replace("\n", " # ")
                while s.find("  ") >= 0:
                    s = s.replace("  ", " ")

                # Error logs
                logger.error("Timeout in %s clientId=%s, stack=%s", stepName, clientId, s)

                # Kill
                g.kill(block=True)

                # Stat
                getattr(MeterManager.get(TcpServerStat), timeoutStatName).increment(1)
        except Exception as e:
            logger.warning("Exception in %s clientId=%s, ex=%s", stepName, clientId, SolBase.exToStr(e))

    def _removeClient(self, clientId, evt):
        """
        Remove a client : un-hash it, then run its business stop and its internal stop,
        each in a dedicated greenlet bounded by stopClientTimeoutMs.
        :param clientId: The client id.
        :param evt: Event to signal when done (always signaled, even on failure), may be None.
        :type evt: gevent.Event, None
        :return The removed TcpServerClientContext, or None upon failure or if the client is not hashed.
        """

        logger.debug("entering, clientId=%s", clientId)

        try:
            with self._clientConnectedHashLock:
                # Check
                if clientId not in self._clientConnectedHash:
                    # Note : This may occurs in some conditions (mantis 1246)
                    logger.debug("clientId not hashed, id=%s", clientId)
                    MeterManager.get(TcpServerStat).clientRemoveNotHashed.increment()
                    return None

                # Get (direct, we are already in lock)
                oldClient = self._clientConnectedHash[clientId]

                # Remove from hashmap
                logger.debug("un-hashing, clientId=%s", clientId)
                del self._clientConnectedHash[clientId]

            # Out of lock : BUSINESS stop, then INTERNAL stop
            self._removeClient_runStopStep(self._removeClient_stopBusiness, "_removeClient_stopBusiness",
                                           "clientRemoveTimeOutBusiness", oldClient, clientId)
            self._removeClient_runStopStep(self._removeClient_stopInternal, "_removeClient_stopInternal",
                                           "clientRemoveTimeOutInternal", oldClient, clientId)

            # Statistics
            MeterManager.get(TcpServerStat).clientConnected.increment(-1)
            MeterManager.get(TcpServerStat).clientRemoveCount.increment()

            # Log
            logger.debug("client removed, id=%s, addr=%s", oldClient.getClientId(), oldClient.getClientAddr())

            return oldClient
        except Exception as e:
            # Error
            logger.warning("Exception, ex=%s", SolBase.exToStr(e))

            # Statistics
            MeterManager.get(TcpServerStat).clientRemoveException.increment()

            return None
        finally:
            if evt:
                evt.set()

    def _removeAllClient(self):
        """
        Remove all clients (synchronously, one after the other).
        :return nothing.
        """
        try:
            # Snapshot the ids under lock : _removeClient mutates the hash while we iterate
            with self._clientConnectedHashLock:
                clientIdList = list(self._clientConnectedHash.keys())

            # Pass through all client and remove all
            for clientId in clientIdList:
                self._removeClient(clientId, None)
        except Exception as e:
            # Error
            logger.warning("_removeAllClient : Exception, ex=%s", SolBase.exToStr(e))

    #=====================================================
    # CLIENT MANAGEMENT : GET FROM HASHMAP
    #=====================================================

    def _getClientFromId(self, clientId):
        """
        Get a client from id. Return None if not found.
        :param clientId: The client Id
        :return A TcpServerClientContext or None if not found.
        """
        with self._clientConnectedHashLock:
            return self._clientConnectedHash.get(clientId)








