# -----------------------------------------
# Sextant
# Copyright 2014, Ensoft Ltd.
# Author: James Harkin, using work from Patrick Stevens and James Harkin
# -----------------------------------------
# Command-line entry point: parses arguments with argparse and invokes Sextant

from __future__ import absolute_import, print_function

import io
import sys
import random
import socket
import logging
import logging.config
import argparse
import requests
import contextlib
import subprocess

try:
    from urllib import parse
except ImportError:  # fall back to Python 2
    import urlparse as parse

from . import query
from . import db_api
from . import update_db
from . import environment

# Loaded once, at import time. parse_arguments() reads its attributes
# (common_cutoff, web_port, remote_neo4j, use_ssh_tunnel, ssh_user) to set
# the cutoff and the command-line option defaults.
config = environment.load_config()


def _displayable_url(args):
    """
    Return the URL specified by the user for Sextant to look at.

    This is needed because we may be using SSH (that is, Sextant sees a
    localhost port for the Neo4j server, which ssh is forwarding somewhere
    else); in this case, we want to display the URL that the user thinks
    Sextant is using, rather than the URL it's actually using.

    :param args: the argparse object Sextant was invoked with
    :return: the URL the Sextant invoker expects Sextant to be using

    """
    try:
        if args.display_neo4j:
            return args.display_neo4j
    except AttributeError:
        return args.remote_neo4j

    return args.remote_neo4j


# Beginning of functions which handle the actual invocation of Sextant

def _start_web(args):
    # Don't import at top level - makes twisted dependency semi-optional,
    # allowing non-web functionality to work with Python 3.
    if sys.version_info[0] == 2:
        from .web import server
    else:  # twisted won't be available - Python 2 required
        logging.error('Web server must be run on Python 2.')
        return
    logging.info("Serving site on port {}".format(args.port))

    # server is .web.server, imported a couple of lines ago
    server.serve_site(input_database_url=args.remote_neo4j, port=args.port)


def _audit(args):
    """
    Print one line for each program stored in the Neo4j database.

    A connection failure is logged and nothing is printed. An empty database
    produces a warning instead of output.

    :param args: the argparse namespace; uses args.remote_neo4j, plus
        args.display_neo4j (if set) for the URL shown in messages

    """
    location = _displayable_url(args)
    try:
        audited = query.audit(args.remote_neo4j)
    except requests.exceptions.ConnectionError as e:
        msg = 'Connection error to server {url}: {exception}'
        # Both placeholders must be supplied to format(); logging.error does
        # not accept extra keyword arguments such as exception=.
        logging.error(msg.format(url=location, exception=e))
        return  # `audited` was never assigned, so there is nothing to report

    if not audited:
        logging.warning('No programs on database at {}.'.format(location))
        return

    st = ('Program {progname} with {numfuncs} functions '
          'uploaded by {uploader} (id {uploaderid}) on {date}.')
    for program in audited:
        print(st.format(progname=program.program_name,
                        numfuncs=program.number_of_funcs,
                        uploader=program.uploader,
                        uploaderid=program.uploader_id,
                        date=program.date))


def _add_program(args):
    """
    Upload the file named on the command line into the Neo4j database.

    Connection errors, I/O errors and ValueErrors raised by the upload are
    logged rather than propagated.

    :param args: the argparse namespace for the add-program subcommand; uses
        input_file, name_in_db, not_object_file and remote_neo4j

    """
    # --name-in-db uses nargs=1: it is None when unsupplied, otherwise a
    # one-element list.
    try:
        alternative_name = args.name_in_db[0]
    except TypeError:
        alternative_name = None

    # store_true flag: the default (False) means "this is an object file".
    not_object_file = args.not_object_file

    try:
        update_db.upload_program(file_path=args.input_file,
                                 db_url=args.remote_neo4j,
                                 alternative_name=alternative_name,
                                 not_object_file=not_object_file,
                                 display_url=_displayable_url(args))
    except requests.exceptions.ConnectionError as e:
        msg = 'Connection error to server {}: {}'
        logging.error(msg.format(_displayable_url(args), e))

    except IOError as e:
        # IOError rather than FileNotFoundError, which Python 2 lacks.
        # requests' ConnectionError subclasses IOError, so it must be caught
        # first (above).
        # input_file is a plain string (no nargs), so it is not indexed here;
        # indexing would report only its first character.
        logging.error('Input file {} was not found.'.format(args.input_file))
        logging.error(e)
        logging.debug(e, exc_info=True)
    except ValueError as e:
        logging.error(e)


def _delete_program(namespace):
    """
    Delete a program from the Neo4j database.

    A connection failure is logged (naming the user-facing server URL)
    rather than propagated, matching the other subcommands.

    :param namespace: the argparse namespace for the delete-program
        subcommand; uses program_name and remote_neo4j

    """
    try:
        update_db.delete_program(namespace.program_name,
                                 namespace.remote_neo4j)
    except requests.exceptions.ConnectionError as e:
        msg = 'Connection error to server {}: {}'
        logging.error(msg.format(_displayable_url(namespace), e))


def _make_query(namespace):
    """
    Forward the query subcommand's options to query.query.

    Options that argparse left as None become defaults: no function
    arguments, no program name and suppress_common=False. Only the first two
    names given to --funcs are passed on.

    :param namespace: the argparse namespace for the query subcommand

    """
    def first_or(values, fallback):
        # nargs=1 options hold a one-element list, or None when omitted.
        try:
            return values[0]
        except TypeError:
            return fallback

    argument_1 = argument_2 = None
    try:
        argument_1 = namespace.funcs[0]
        argument_2 = namespace.funcs[1]
    except (TypeError, IndexError):
        # --funcs omitted (None), or fewer than two names were supplied
        pass

    program_name = first_or(namespace.program, None)
    # NOTE(review): --suppress-common is parsed with type=str, so a supplied
    # value is the string 'True' or 'False', while the fallback is the bool
    # False. Confirm that query.query interprets the string: a plain
    # truthiness test would treat 'False' as true.
    suppress_common = first_or(namespace.suppress_common, False)

    query.query(remote_neo4j=namespace.remote_neo4j,
                display_neo4j=_displayable_url(namespace),
                input_query=namespace.query,
                program_name=program_name,
                argument_1=argument_1, argument_2=argument_2,
                suppress_common=suppress_common)

# End of functions which invoke Sextant

def parse_arguments():
    """
    Parses the command-line arguments to Sextant.

    The resulting object, result, has a .func property, which is a method to be
    called with result as its only parameter. This .func method runs whichever
    of Sextant's functionality is appropriate.

    A subcommand is mandatory. Without one, argparse exits with a usage
    error instead of returning a namespace that lacks .func (Python 3 would
    otherwise treat the subparsers as optional). The chosen subcommand's name
    is also stored as result.subcommand.

    Side effect: sets db_api's common-function cutoff from the config.

    :return: namespace summarising the arguments

    """

    argumentparser = argparse.ArgumentParser(prog='sextant', usage='sextant', description="Invoke part of the SEXTANT program")
    # dest gives the "required" error message a name to report; setting the
    # attribute (rather than passing required=) also works on Python 2.7.
    subparsers = argumentparser.add_subparsers(title="subcommands",
                                               dest='subcommand')
    subparsers.required = True

    # set what will be defined as a "common function"
    db_api.set_common_cutoff(config.common_cutoff)

    parsers = dict()

    # add each subparser in turn to the parsers dictionary

    parsers['add'] = subparsers.add_parser('add-program',
                                           help="add a program to the database")
    parsers['add'].add_argument('input_file', metavar="FILE_NAME",
                                help="name of file to be put into database",
                                type=str)
    parsers['add'].add_argument('--name-in-db', metavar="PROGRAM_NAME",
                                help="string to store this program under", type=str,
                                nargs=1)
    parsers['add'].add_argument('--not-object-file',
                                help='the input file is not an object file to '
                                     'be disassembled (by default it is '
                                     'treated as one)',
                                action='store_true')

    parsers['delete'] = subparsers.add_parser('delete-program',
                                              help="delete a program from the database")
    parsers['delete'].add_argument('program_name', metavar="PROG_NAME",
                                   help="name of program as stored in the database",
                                   type=str)

    parsers['query'] = subparsers.add_parser('query',
                                             help="make a query of the database")
    parsers['query'].add_argument('query', metavar="QUERY",
                                  help="functions-calling, functions-called-by, "
                                       "all-call-paths, whole-program, "
                                       "shortest-call-path, programs or "
                                       "functions; if the latter, "
                                       "supply argument --program",
                                  type=str)
    parsers['query'].add_argument('--program', metavar="PROG_NAME",
                                  help="name of program as stored in the database; "
                                       "required for all queries except 'programs'",
                                  type=str, nargs=1)
    parsers['query'].add_argument('--funcs', metavar='FUNCS',
                                  help='functions to pass to the query',
                                  type=str, nargs='+')
    parsers['query'].add_argument('--suppress-common', metavar='BOOL',
                                  help='suppress commonly called functions (True or False)',
                                  type=str, nargs=1)

    parsers['web'] = subparsers.add_parser('web', help="start the web server")
    parsers['web'].add_argument('--port', metavar='PORT', type=int,
                                help='port on which to serve Sextant Web',
                                default=config.web_port)

    parsers['audit'] = subparsers.add_parser('audit', help='view usage of Sextant')

    # connection options shared by every subcommand
    for key in parsers:
        parsers[key].add_argument('--remote-neo4j', metavar="URL",
                                  help="URL of neo4j server", type=str,
                                  default=config.remote_neo4j)
        parsers[key].add_argument('--use-ssh-tunnel', metavar="BOOL", type=str,
                                  help="whether to SSH into the remote server, "
                                       "True/False",
                                  default=str(config.use_ssh_tunnel))
        parsers[key].add_argument('--ssh-user', metavar="NAME", type=str,
                                  help="username to use as remote SSH name",
                                  default=str(config.ssh_user))

    parsers['audit'].set_defaults(func=_audit)
    parsers['web'].set_defaults(func=_start_web)
    parsers['add'].set_defaults(func=_add_program)
    parsers['delete'].set_defaults(func=_delete_program)
    parsers['query'].set_defaults(func=_make_query)

    # parse the arguments

    return argumentparser.parse_args()


def _start_tunnel(local_port, remote_host, remote_port, ssh_user=''):
    """
    Creates an SSH port-forward.

    This will result in localhost:local_port appearing to be
    remote_host:remote_port.

    :param local_port: integer port number to open at localhost
    :param remote_host: string address of remote host (no port number)
    :param remote_port: port to 'open' on the remote host
    :param ssh_user: user to log in on the remote_host as

    """

    if not (isinstance(local_port, int) and local_port > 0):
        raise ValueError(
            'Local port {} must be a positive integer.'.format(local_port))
    if not (isinstance(remote_port, int) and remote_port > 0):
        raise ValueError(
            'Remote port {} must be a positive integer.'.format(remote_port))

    logging.debug('Starting SSH tunnel...')

    # this cmd string will be .format()ed in a few lines' time
    cmd = ['ssh']

    if ssh_user:
        # ssh -l {user} ... sets the remote login username
        cmd += ['-l', ssh_user]

    # -L localport:localhost:remoteport forwards the port
    # -M makes SSH able to accept slave connections
    # -S sets the location of a control socket (in this case, sextant-controller
    #    with a unique identifier appended, just in case we run sextant twice
    #    simultaneously), so we know how to close the port again
    # -f goes into background; -N does not execute a remote command;
    # -T says to remote host that we don't want a text shell.
    cmd += ['-M',
            '-S', 'sextantcontroller{tunnel_id}'.format(tunnel_id=local_port),
            '-fNT',
            '-L', '{0}:localhost:{1}'.format(local_port, remote_port),
            remote_host]

    logging.debug('Running {}'.format(' '.join(cmd)))

    exit_code = subprocess.call(cmd)
    if exit_code:
        raise OSError('SSH setup failed with error {}'.format(exit_code))

    logging.debug('SSH tunnel created.')


def _stop_tunnel(local_port, remote_host):
    """
    Tear down an SSH port-forward which was previously set up with start_tunnel.

    We use local_port as an identifier: it names the control socket
    (sextantcontroller<local_port>) that the master ssh process listens on.
    Failures are logged as warnings, not raised.

    :param local_port: the port on localhost we are using as the entrypoint
    :param remote_host: remote host we tunnelled into

    """

    logging.debug('Shutting down SSH tunnel...')

    # ssh -O sends a command to the slave specified in -S
    cmd = ['ssh',
           '-S', 'sextantcontroller{}'.format(local_port),
           '-O', 'exit',
           '-q',  # for quiet
           remote_host]

    # SSH has a bug on some systems which causes it to ignore the -q flag,
    # meaning it prints "Exit request sent." to stderr. Capture stderr and
    # suppress exactly that message. universal_newlines=True makes the output
    # text on both Python 2 and 3; on Python 3 it would otherwise be bytes,
    # which never compares equal to the str below.
    pr = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          universal_newlines=True)
    _, stderr = pr.communicate()
    message = stderr.strip()
    if message and message != 'Exit request sent.':
        print(message, file=sys.stderr)

    if pr.returncode == 0:
        logging.debug('Shut down successfully.')
    else:
        # Any stderr output has already been passed through above.
        logging.warning(
            'SSH tunnel shutdown returned error code {}'.format(pr.returncode))


def _is_port_used(port):
    """
    Checks with the OS to see whether a port is open.

    Beware: port is passed directly to the shell. Make sure it is an integer.
    We raise ValueError if it is not.
    :param port: integer port to check for openness
    :return: bool(port is in use)

    """

    # we follow http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
    if not (isinstance(port, int) and port > 0):
        raise ValueError('port {} must be a positive integer.'.format(port))

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('127.0.0.1', port))
    except socket.error as e:
        if e.errno == 98:  # Address already in use
            return True
        raise

    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    return False  # that is, the port is not used


def _get_unused_port():
    """
    Pick a random TCP port in 10000-50000 (inclusive) that is currently free.

    Candidates are drawn with random.randint until _is_port_used reports one
    that is not in use. That port is returned.

    """
    while True:
        candidate = random.randint(10000, 50000)
        if not _is_port_used(candidate):
            return candidate


def _get_host_and_port(url):
    """Given a URL as http://host:port, returns (host, port)."""
    parsed = parse.urlparse(url)
    return (parsed.hostname, parsed.port)


def _is_localhost(host, port):
    """
    Checks whether a host is an alias to localhost.

    Raises socket.gaierror if the host was not found.

    """

    addr = socket.getaddrinfo(host, port)[0][4][0]

    return addr in ('127.0.0.1', '::1')


def main():
    """
    Command-line entry point: parse arguments and run the chosen subcommand.

    If --use-ssh-tunnel is 'true' (case-insensitive) and the Neo4j server does
    not resolve to this machine, the subcommand runs through a temporary SSH
    port-forward. The forward is always torn down afterwards. Tunnel set-up
    problems are logged and cause an early return rather than a traceback.

    """
    args = parse_arguments()

    if args.use_ssh_tunnel.lower() != 'true':
        # no need to set up the ssh, just run sextant
        args.func(args)
        return

    remotehost, remoteport = _get_host_and_port(args.remote_neo4j)

    try:
        is_loc = _is_localhost(remotehost, remoteport)
    except socket.gaierror:
        logging.error('Server {} not found.'.format(remotehost))
        return

    if is_loc:
        # We are attempting to connect to localhost anyway, so we won't
        # bother to SSH to it. There may be ways the user can trick us into
        # trying to SSH to localhost anyway, but this will do as a first pass.
        # SSHing to localhost is undesirable because on the author's test
        # computer it gives 'connection refused'.
        args.func(args)
        return

    # Only choose a local port once we know a tunnel is actually needed.
    localport = _get_unused_port()
    try:
        _start_tunnel(localport, remotehost, remoteport,
                      ssh_user=args.ssh_user)
    except (OSError, ValueError) as e:
        # ValueError covers e.g. a --remote-neo4j URL without a port, where
        # remoteport is None and _start_tunnel rejects it.
        logging.error(str(e))
        return
    except KeyboardInterrupt:
        logging.info('Halting because of user interrupt.')
        return

    try:
        # Report the user's URL in messages, but connect through the tunnel.
        args.display_neo4j = args.remote_neo4j
        args.remote_neo4j = 'http://localhost:{}'.format(localport)
        args.func(args)
    except KeyboardInterrupt:
        # this probably happened because we were running Sextant Web
        # and Ctrl-C'ed out of it
        logging.info('Keyboard interrupt detected. Halting.')
    finally:
        _stop_tunnel(localport, remotehost)


# Run the CLI when executed directly. The relative imports at the top of
# this file mean it must be run as part of its package (e.g. with
# ``python -m``), not as a standalone script.
if __name__ == '__main__':
    main()


