import logging
import itertools
import random
from twisted.conch.client.knownhosts import KnownHostsFile

from twisted.internet import protocol, defer, reactor
from twisted.python import failure
from twisted.python.filepath import FilePath

from tron import ssh, eventloop
from tron.utils import twistedutils, collections


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# TODO: make these configurable
# Upper bound on how long we wait for a connection to a node to be
# established (presumably seconds; passed to twistedutils.defer_timeout).
CONNECT_TIMEOUT = 30

# How long (in seconds) a node may have no tracked runs before its idle
# connection is closed.
IDLE_CONNECTION_TIMEOUT = 3600

# We also only wait a limited amount of time for a new channel to be
# established when we already have an open connection.  This timeout will
# usually get triggered prior to even a TCP timeout, so essentially it's our
# shortcut to discovering the connection died.
RUN_START_TIMEOUT = 20

# Initial state of every run: it has been accepted, but no channel has been
# requested for it yet (we may still need to finish connecting to the node).
RUN_STATE_CONNECTING = 0

# We are connected and trying to open a channel to exec the process
RUN_STATE_STARTING = 5

# Process has been exec'ed, just waiting for it to exit
RUN_STATE_RUNNING = 10

# Process has exited
RUN_STATE_COMPLETE = 100


class Error(Exception):
    """Base class for the errors raised by this module."""


class ConnectError(Error):
    """Connecting failed, so the run was never started."""


class ResultError(Error):
    """The result of a run could not be retrieved.

    The command was submitted for execution, but whether it succeeded or
    failed is unknown.
    """


class NodePoolRepository(object):
    """Singleton registry of Node and NodePool objects.

    Nodes and pools are kept in two separate MappingCollections.  Every node
    added through add_node() is also registered as a single-node pool under
    the node's own name, so the pools collection can be queried with either
    a node name or a pool name.
    """

    # The one shared instance; created lazily by get_instance().
    _instance = None

    def __init__(self):
        """Create the empty collections.

        Raises ValueError when the shared instance already exists.
        """
        already_created = self._instance is not None
        if already_created:
            raise ValueError("NodePoolRepository is already instantiated.")
        super(NodePoolRepository, self).__init__()
        self.nodes = collections.MappingCollection('nodes')
        self.pools = collections.MappingCollection('pools')

    @classmethod
    def get_instance(cls):
        """Return the shared repository, creating it on first use."""
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = cls()
        return cls._instance

    def filter_by_name(self, node_configs, node_pool_configs):
        """Filter both collections by the names present in the configs.

        The pools collection is filtered with the node names as well as the
        pool names, because each node also has a pool of its own.
        """
        self.nodes.filter_by_name(node_configs)
        pool_names = node_configs.keys() + node_pool_configs.keys()
        self.pools.filter_by_name(pool_names)

    @classmethod
    def update_from_config(cls, node_configs, node_pool_configs, ssh_config):
        """Bring the shared repository in line with the given configuration.

        node_configs and node_pool_configs are mappings whose values are the
        individual node / pool configs; ssh_config supplies the SSH auth
        options and the path of the known hosts file.
        """
        repo = cls.get_instance()
        auth_options = ssh.SSHAuthOptions.from_config(ssh_config)
        host_keys = KnownHosts.from_path(ssh_config.known_hosts_file)
        repo.filter_by_name(node_configs, node_pool_configs)

        for node_config in node_configs.itervalues():
            host_key = host_keys.get_public_key(node_config.hostname)
            node = Node.from_config(node_config, auth_options, host_key)
            repo.add_node(node)

        for pool_config in node_pool_configs.itervalues():
            members = repo._get_nodes_by_name(pool_config.nodes)
            repo.pools.replace(NodePool.from_config(pool_config, members))

    def add_node(self, node):
        """Register a node, and a single-node pool wrapping it."""
        self.nodes.replace(node)
        single_node_pool = NodePool.from_node(node)
        self.pools.replace(single_node_pool)

    def get_node(self, node_name, default=None):
        """Return the node called node_name, or default."""
        return self.nodes.get(node_name, default)

    def __contains__(self, node):
        """True when a pool is registered under the name of `node`."""
        name = node.get_name()
        return name in self.pools

    def get_by_name(self, name, default=None):
        """Return the pool called name, or default."""
        return self.pools.get(name, default)

    def _get_nodes_by_name(self, names):
        """Return the nodes for `names`, in order, as a list."""
        found = []
        for name in names:
            found.append(self.nodes[name])
        return found

    def clear(self):
        """Empty both the node and the pool collection."""
        self.nodes.clear()
        self.pools.clear()


class NodePool(object):
    """A named pool of Node objects from which a node can be picked."""

    def __init__(self, nodes, name):
        """Create a pool.

        nodes -- sequence of the pool's members
        name  -- pool name; when falsy, the member names joined with '_'
                 are used instead
        """
        self.nodes = nodes
        self.disabled = False
        self.name = name or '_'.join(n.get_name() for n in nodes)
        # Endless iterator over the members, used by next_round_robin().
        self.iter = itertools.cycle(self.nodes)

    @classmethod
    def from_config(cls, node_pool_config, nodes):
        """Build a pool named after node_pool_config.name."""
        return cls(nodes, node_pool_config.name)

    @classmethod
    def from_node(cls, node):
        """Build a single-member pool that shares the node's name."""
        return cls([node], node.get_name())

    def __eq__(self, other):
        # Pools are compared by their members only; the name is ignored.
        return isinstance(other, NodePool) and self.nodes == other.nodes

    def __ne__(self, other):
        return not self == other

    def get_name(self):
        """Return the pool's name."""
        return self.name

    def next(self):
        """Return a random node from the pool."""
        return random.choice(self.nodes)

    def next_round_robin(self):
        """Return the next node cycling in a consistent order."""
        # The builtin next() works on Python 2.6+ and Python 3, unlike the
        # Python-2-only iterator method `.next()`.
        return next(self.iter)

    def disable(self):
        """Required for MappingCollection.Item interface."""
        self.disabled = True

    def get_by_hostname(self, hostname):
        """Return the first member whose hostname matches, or None."""
        for node in self.nodes:
            if node.hostname == hostname:
                return node
        return None

    def __str__(self):
        return "NodePool:%s" % self.name


class KnownHosts(KnownHostsFile):
    """Lookup host key for a hostname."""

    @classmethod
    def from_path(cls, file_path):
        """Build a KnownHosts from the file at file_path.

        A falsy file_path yields an instance constructed with no backing
        path instead of reading a file.
        """
        if not file_path:
            return cls(None)
        return cls.fromPath(FilePath(file_path))

    def get_public_key(self, hostname):
        """Return the public key of the first entry matching hostname.

        Logs a warning and returns None when no entry matches.
        """
        # NOTE(review): relies on the private `_entries` attribute of
        # KnownHostsFile -- confirm it exists in the Twisted version in use.
        for entry in self._entries:
            if entry.matchesHost(hostname):
                return entry.publicKey
        # `warning` rather than the deprecated alias `warn`, matching the
        # rest of this module.
        log.warning("Missing host key for: %s", hostname)
        return None


def determine_fudge_factor(count, min_count=4):
    """Return a random value in [0, count - min_count).

    The result is always 0.0 when count does not exceed min_count.  One
    pseudo-random number is drawn on every call.
    """
    excess = float(max(0.0, count - min_count))
    return excess * random.random()


class RunState(object):
    """Book-keeping record a Node keeps for one submitted run."""

    def __init__(self, action_run):
        # No channel is attached yet; the node assigns one later.
        self.channel = None
        # Deferred handed back to whoever submitted the run.
        self.deferred = defer.Deferred()
        # Every run starts out in the "connecting" state.
        self.state = RUN_STATE_CONNECTING
        self.run = action_run


class Node(object):
    """A node is tron's interface to communicating with an actual machine.

    A Node keeps at most one SSH connection to its host and opens one exec
    channel on it per run.  Every run handed to run() is tracked in
    `run_states` (run id -> RunState) until it completes or fails.
    """

    def __init__(self, hostname, ssh_options, username=None, name=None, pub_key=None):
        # Host we are to connect to
        self.hostname = hostname

        # Username to connect to the remote machine under
        self.username = username

        # Identifier for UI
        self.name = name or hostname

        # SSH Options
        self.conch_options = ssh_options

        # The SSH connection we use to open channels on. If present, means we
        # are connected.
        self.connection = None

        # If present, means we are trying to connect
        self.connection_defer = None

        # Map of run id to instance of RunState
        self.run_states = {}

        # Pending "close the idle connection" call; a null object until the
        # first run has been cleaned up.
        self.idle_timer = eventloop.NullCallback
        self.disabled = False

        # Host key passed to the SSH transport when connecting.
        self.pub_key = pub_key

    @classmethod
    def from_config(cls, node_config, ssh_options, pub_key):
        """Build a Node from a node config, SSH options and a host key."""
        return cls(
            node_config.hostname,
            ssh_options,
            username=node_config.username,
            name=node_config.name,
            pub_key=pub_key)

    def get_name(self):
        """Return the identifier of this node."""
        return self.name

    def disable(self):
        """Required for MappingCollection.Item interface."""
        self.disabled = True

    # TODO: wrap conch_options for equality
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return (self.hostname == other.hostname and
                self.username == other.username and
                self.name == other.name and
                self.conch_options == other.conch_options and
                self.pub_key == other.pub_key)

    def __ne__(self, other):
        return not self == other

    # TODO: Test
    def submit_command(self, command):
        """Submit an ActionCommand to be run on this node.

        The command's own `handle_errback` is attached as errback of the
        returned deferred.
        """
        deferred = self.run(command)
        deferred.addErrback(command.handle_errback)
        return deferred

    def run(self, run):
        """Execute the specified run

        A run consists of a very specific set of interfaces which allow us to
        execute a command on this remote machine and return results.

        Returns the deferred of the run's RunState.  Raises Error when a run
        with the same id is already tracked by this node.
        """
        log.info("Running %s for %s on %s", run.command, run.id, self.hostname)

        # When this run completes, for good or bad, the run itself is informed
        # through its own callbacks (started / exited / done ...), so callers
        # do not have to care about twisted specific stuff at all.

        if run.id in self.run_states:
            # Exceptions do not interpolate extra positional arguments the
            # way logging calls do, so format the message here.
            raise Error("Run %s already running !?!" % run.id)

        if self.idle_timer.active():
            self.idle_timer.cancel()

        self.run_states[run.id] = RunState(run)

        # TODO: have this return a runner instead of number
        fudge_factor = determine_fudge_factor(len(self.run_states))
        if fudge_factor == 0.0:
            self._do_run(run)
        else:
            log.info("Delaying execution of %s for %.2f secs", run.id, fudge_factor)
            eventloop.call_later(fudge_factor, self._do_run, run)

        # We return the deferred here, but really we're trying to keep the rest
        # of the world from getting too involved with twisted.
        return self.run_states[run.id].deferred

    def _do_run(self, run):
        """Finish starting to execute a run

        This step may have been delayed.
        """
        # Now let's see if we need to start this off by establishing a
        # connection or if we are already connected
        if self.connection is None:
            self._connect_then_run(run)
        else:
            self._open_channel(run)

    def _cleanup(self, run):
        """Stop tracking `run`; arm the idle timer if no runs are left."""
        # TODO: why set to None before deleting it?
        self.run_states[run.id].channel = None
        del self.run_states[run.id]

        if not self.run_states:
            self.idle_timer = eventloop.call_later(IDLE_CONNECTION_TIMEOUT,
                                                   self._connection_idle_timeout)

    def _connection_idle_timeout(self):
        """Close the connection, if there is one, after it has been idle."""
        if self.connection:
            log.info("Connection to %s idle for %d secs. Closing.",
                     self.hostname, IDLE_CONNECTION_TIMEOUT)
            self.connection.transport.loseConnection()

    def _connection_lost_failure(self):
        """Return a Failure for a run whose connection went away."""
        return failure.Failure(
            exc_value=ResultError("Connection to %s lost" % self.hostname))

    def _fail_run(self, run, result):
        """Indicate the run has failed, and cleanup state

        `result` is the Failure handed to the errback of the run's deferred.
        """
        log.debug("Run %s has failed", run.id)
        if run.id not in self.run_states:
            log.warning("Run %s no longer tracked (_fail_run)", run.id)
            return

        # Add a dummy errback handler to prevent Unhandled error messages.
        # Unless someone is explicitly caring about this defer the error will
        # have been reported elsewhere.
        self.run_states[run.id].deferred.addErrback(lambda _ignored: None)

        # Grab the errback before _cleanup() drops the RunState.
        cb = self.run_states[run.id].deferred.errback

        self._cleanup(run)

        log.info("Calling fail_run callbacks")
        run.exited(None)
        cb(result)

    def _connect_then_run(self, run):
        """Open a channel for `run` once the node is connected.

        Starts connecting unless a connection attempt is already pending; the
        run is failed with a ConnectError when the attempt fails.
        """
        # Have we started the connection process ?
        if self.connection_defer is None:
            self.connection_defer = self._connect()

        def call_open_channel(arg):
            self._open_channel(run)
            return arg

        def connect_fail(result):
            log.warning("Cannot run %s, Failed to connect to %s",
                        run.id, self.hostname)
            self.connection_defer = None
            self._fail_run(run, failure.Failure(
                exc_value=ConnectError("Connection to %s failed" %
                                       self.hostname)))

        self.connection_defer.addCallback(call_open_channel)
        self.connection_defer.addErrback(connect_fail)

    def _service_stopped(self, connection):
        """Called when the SSH service has disconnected fully.

        We should be in a state where we know there are no runs in progress
        because all the SSH channels should have disconnected them.  Runs
        that are still tracked are reconnected or failed, depending on their
        state.  Raises Error for a tracked run in an unexpected state.
        """
        assert self.connection is connection
        self.connection = None

        log.info("Service to %s stopped", self.hostname)

        # Iterate over a snapshot: _fail_run() removes entries from
        # self.run_states, which must not happen while iterating the dict.
        for run_id, run_state in list(self.run_states.items()):
            # The helpers below expect the run itself (it carries `id`,
            # `exited` ...), not the RunState wrapper.
            run = run_state.run

            if run_state.state == RUN_STATE_CONNECTING:
                # Now we can trigger a reconnect and re-start any waiting runs.
                self._connect_then_run(run)
            elif run_state.state == RUN_STATE_RUNNING:
                # An explicit Failure is required: errback(None) only works
                # while an exception is being handled.
                self._fail_run(run, self._connection_lost_failure())
            elif run_state.state == RUN_STATE_STARTING:
                channel = run_state.channel
                if channel and channel.start_defer is not None:
                    # This means our run IS still waiting to start. There
                    # should be an outstanding timeout sitting on this guy as
                    # well. We'll just short circuit it.
                    twistedutils.defer_timeout(channel.start_defer, 0)
                else:
                    # Doesn't seem like this should ever happen.
                    log.warning("Run %r caught in starting state, but"
                                " start_defer is over.", run_id)
                    self._fail_run(run, self._connection_lost_failure())
            else:
                # Service ended. The open channels should know how to handle
                # this (and cleanup) themselves, so there should not be any
                # runs except those waiting to connect.
                raise Error("Run %s in state %s when service stopped" %
                            (run_id, run_state.state))

    def _connect(self):
        """Start connecting to the host.

        Returns a deferred that is called back with this node once the
        connection service has started; it is subject to CONNECT_TIMEOUT.
        """
        # This is complicated because we have to deal with a few different
        # steps before our connection is really available for us:
        #  1. Transport is created (our client creator does this)
        #  2. Our transport is secure, and we can create our connection
        #  3. The connection service is started, so we can use it

        client_creator = protocol.ClientCreator(reactor,
            ssh.ClientTransport, self.username, self.conch_options, self.pub_key)
        create_defer = client_creator.connectTCP(self.hostname, 22)

        # We're going to create a deferred, returned to the caller, that will
        # be called back when we have an established, secure connection ready
        # for opening channels. The value will be this instance of node.
        connect_defer = defer.Deferred()
        twistedutils.defer_timeout(connect_defer, CONNECT_TIMEOUT)

        def on_service_started(connection):
            # Booyah, time to start doing stuff
            self.connection = connection
            self.connection_defer = None

            connect_defer.callback(self)
            return connection

        def on_connection_secure(connection):
            # We have a connection, but it might not be fully ready....
            # NOTE(review): these deferreds are presumably fired by the ssh
            # connection object when its service starts / stops -- confirm.
            connection.service_start_defer = defer.Deferred()
            connection.service_stop_defer = defer.Deferred()

            connection.service_start_defer.addCallback(on_service_started)
            connection.service_stop_defer.addCallback(self._service_stopped)
            return connection

        def on_transport_create(transport):
            transport.connection_defer = defer.Deferred()
            transport.connection_defer.addCallback(on_connection_secure)
            return transport

        def on_transport_fail(fail):
            # NOTE(review): the failure is only logged; connect_defer is not
            # fired here, so waiting runs appear to learn about it only when
            # the CONNECT_TIMEOUT set above expires -- confirm.
            log.warning("Cannot connect to %s", self.hostname)

        create_defer.addCallback(on_transport_create)
        create_defer.addErrback(on_transport_fail)

        return connect_defer

    def _open_channel(self, run):
        """Open an exec channel for `run` on the current connection."""
        assert self.connection
        assert self.run_states[run.id].state < RUN_STATE_RUNNING

        self.run_states[run.id].state = RUN_STATE_STARTING

        chan = ssh.ExecChannel(conn=self.connection)

        chan.addOutputCallback(run.write_stdout)
        chan.addErrorCallback(run.write_stderr)
        chan.addEndCallback(run.done)

        chan.command = run.command
        chan.start_defer = defer.Deferred()
        chan.start_defer.addCallback(self._run_started, run)
        chan.start_defer.addErrback(self._run_start_error, run)

        chan.exit_defer = defer.Deferred()
        chan.exit_defer.addCallback(self._channel_complete, run)
        chan.exit_defer.addErrback(self._channel_complete_unknown, run)

        twistedutils.defer_timeout(chan.start_defer, RUN_START_TIMEOUT)

        self.run_states[run.id].channel = chan
        self.connection.openChannel(chan)

    def _channel_complete(self, channel, run):
        """Callback once our channel has completed its operation

        This is how we let our run know that we succeeded or failed.
        """
        log.info("Run %s has completed with %r", run.id, channel.exit_status)
        if run.id not in self.run_states:
            log.warning("Run %s no longer tracked", run.id)
            return

        assert self.run_states[run.id].state < RUN_STATE_COMPLETE

        self.run_states[run.id].state = RUN_STATE_COMPLETE
        # Grab the callback before _cleanup() drops the RunState.
        cb = self.run_states[run.id].deferred.callback
        self._cleanup(run)

        run.exited(channel.exit_status)
        cb(channel.exit_status)

    def _channel_complete_unknown(self, result, run):
        """Channel has closed on a running process without a proper exit

        We don't actually know if the run succeeded
        """
        log.error("Failure waiting on channel completion: %s", str(result))
        self._fail_run(run, failure.Failure(exc_value=ResultError()))

    def _run_started(self, channel, run):
        """Our run is actually a running process now, update the state"""
        log.info("Run %s started for %s", run.id, self.hostname)
        channel.start_defer = None
        assert self.run_states[run.id].state == RUN_STATE_STARTING
        self.run_states[run.id].state = RUN_STATE_RUNNING

        run.started()

    def _run_start_error(self, result, run):
        """We failed to even run the command due to communication difficulties

        Once all the runs have closed out we can try to reconnect.
        """
        log.error("Error running %s, disconnecting from %s: %s",
                  run.id, self.hostname, str(result))

        # We clear out the deferred that likely called us because there are
        # actually more than one error paths because of user timeouts.  The
        # run may already have been cleaned up through one of those paths, so
        # look it up defensively; _fail_run() reports an untracked run.
        run_state = self.run_states.get(run.id)
        if run_state is not None and run_state.channel is not None:
            run_state.channel.start_defer = None

        self._fail_run(run, failure.Failure(
            exc_value=ConnectError("Connection to %s failed" % self.hostname)))

        # We want to hard hangup on this connection. It could theoretically
        # come back thanks to the magic of TCP, but something is up, best to
        # fail right now than limp along for an unknown amount of time.
        # (Currently disabled.)
        #self.connection.transport.connectionLost(failure.Failure())

    def __str__(self):
        return "Node:%s@%s" % (self.username or "<default>", self.hostname)

    def __repr__(self):
        return self.__str__()