# -----------------------------------------
# Sextant
# Copyright 2014, Ensoft Ltd.
# Author: Patrick Stevens, using work from Patrick Stevens and James Harkin
# -----------------------------------------
# API to interact with a Neo4J server: upload, query and delete programs in a DB

# Public API of this module: the names exported by `from ... import *`.
__all__ = (
    "Validator",
    "AddToDatabase",
    "FunctionQueryResult",
    "Function",
    "SextantConnection",
)

import re  # for validation of function/program names
import logging
from datetime import datetime
import os
import getpass
from collections import namedtuple

from neo4jrestclient.client import GraphDatabase
import neo4jrestclient.client as client

COMMON_CUTOFF = 10
# a function is deemed 'common' if it has more than this
# many connections


class Validator():
    """ Sanitises/checks strings, to prevent Cypher injection attacks"""

    # Whole-string whitelist used by validate(). Quotes, backslashes and
    # whitespace are deliberately absent, so a validated string cannot close
    # the quoted Cypher literal it is later interpolated into.
    # Anchored with \Z rather than $: with re.match, $ also matches just
    # before a trailing newline, which used to let e.g. "main\n" through.
    _VALID_PATTERN = re.compile(r'^[A-Za-z0-9\-:\.\$_@\*\(\)%\+,]+\Z')

    # Runs of characters that sanitise() strips out.
    _SANITISE_PATTERN = re.compile(r'[^\.\-_a-zA-Z0-9]+')

    @staticmethod
    def validate(input_):
        """
        Checks whether we can allow a string to be passed into a Cypher query.
        The whole string must consist of at least one permitted character
        (letters, digits and -:.$_@*()%+,); anything else, including a
        trailing newline, makes it invalid.
        :param input_: the string we wish to validate
        :return: bool(the string is allowed)
        """
        return Validator._VALID_PATTERN.match(input_) is not None

    @staticmethod
    def sanitise(input_):
        """
        Strips harmful characters from the given string.
        Only letters, digits, '.', '-' and '_' are kept.
        :param input_: string to sanitise
        :return: the sanitised string
        """
        return Validator._SANITISE_PATTERN.sub('', input_)


class AddToDatabase():
    """Updates the database, adding functions/calls to a given program"""

    def __init__(self, program_name='', sextant_connection=None,
                 uploader='', uploader_id='', date=None):
        """
        Object which can be used to add functions and calls to a new program
        :param program_name: the name of the new program to be created
          (must already be validated against Validator)
        :param sextant_connection: the SextantConnection to use for connections;
          if None, functions and calls are only recorded locally
        :param uploader: string identifier of user who is uploading
        :param uploader_id: Unix user-id of logged-in user
          (SextantConnection.new_program passes os.getuid())
        :param date: string date of today
        """
        # program_name must be alphanumeric, to avoid injection attacks easily.
        # NOTE(review): on failure the object is left without any attributes,
        # exactly as before; callers should validate first.
        if not Validator.validate(program_name):
            return

        self.program_name = program_name
        self.parent_database_connection = sextant_connection
        # Maps each function name as given by the caller to the name under
        # which its node is stored (plt stubs are stored without "@plt").
        self._functions = {}
        # (calling, called) name pairs already queued; a set for O(1) lookup.
        self._connections = set()
        self._new_tx = None
        self._parent_id = None

        if self.parent_database_connection:
            # we'll locally use db for short
            db = self.parent_database_connection._db

            parent_function = db.nodes.create(name=program_name,
                                              type='program',
                                              uploader=uploader,
                                              uploader_id=uploader_id,
                                              date=date)
            self._parent_id = parent_function.id

            self._new_tx = db.transaction(using_globals=False, for_query=True)

    @staticmethod
    def _classify(function_name):
        """
        Work out how a function is stored in the database.
        :param function_name: the raw (already validated) function name
        :return: (display_name, function_group) tuple
        """
        if function_name.endswith("@plt"):
            return function_name[:-4], "plt_stub"
        if function_name.startswith("_._function_pointer_"):
            return function_name, "function_pointer"
        return function_name, "normal"

    def add_function(self, function_name):
        """
        Adds a function to the program, ready to be sent to the remote database.
        If the function name is already in use, this method effectively does
          nothing and returns True.

        :param function_name: a string which must pass Validator.validate
        :return: True if the request succeeded, False otherwise
        """
        if not Validator.validate(function_name):
            return False
        if self.class_contains_function(function_name):
            return True

        display_name, function_group = self._classify(function_name)

        if self._new_tx is not None:
            query = ('START n = node({}) '
                     'CREATE (n)-[:subject]->(m:func {{type: "{}", name: "{}"}})')
            query = query.format(self._parent_id, function_group, display_name)
            self._new_tx.append(query)

        self._functions[function_name] = display_name

        return True

    def class_contains_function(self, function_to_find):
        """
        Checks whether we contain a function with a given name.
        :param function_to_find: string name of the function we wish to look up
        :return: bool(the function exists in this AddToDatabase)
        """
        return function_to_find in self._functions

    def class_contains_call(self, function_calling, function_called):
        """
        Checks whether we contain a call between the two named functions.
        :param function_calling: string name of the calling-function
        :param function_called: string name of the called function
        :return: bool(function_calling calls function_called in us)
        """
        return (function_calling, function_called) in self._connections

    def add_function_call(self, fn_calling, fn_called):
        """
        Adds a function call to the program, ready to be sent to the database.
        Effectively does nothing if there is already a function call between
          these two functions.
        Function names must pass Validator.validate for security purposes;
          returns False if they fail validation.
        :param fn_calling: the name of the calling-function as a string.
          It should already exist in the AddToDatabase; if it does not,
          this method will create a stub for it.
        :param fn_called: name of the function called by fn_calling.
          If it does not exist, we create a stub representation for it.
        :return: True if successful, False otherwise
        """
        if not all((Validator.validate(fn_calling),
                    Validator.validate(fn_called))):
            return False

        if not self.class_contains_function(fn_called):
            self.add_function(fn_called)
        if not self.class_contains_function(fn_calling):
            self.add_function(fn_calling)

        if not self.class_contains_call(fn_calling, fn_called):
            if self._new_tx is not None:
                query = ('START p = node({}) '
                         'MATCH (p)-[:subject]->(n) WHERE n.name = "{}" '
                         'MATCH (p)-[:subject]->(m) WHERE m.name = "{}" '
                         'CREATE (n)-[:calls]->(m)')
                # Match on the names the nodes were stored under: plt stubs
                # lose their "@plt" suffix in add_function, so matching on
                # the raw name found no node and the call was dropped.
                query = query.format(self._parent_id,
                                     self._functions[fn_calling],
                                     self._functions[fn_called])
                self._new_tx.append(query)

            self._connections.add((fn_calling, fn_called))

        return True

    def commit(self):
        """
        Call this when you are finished with the object.
        Changes are not synced to the remote database until this is called.
        Does nothing if the object was created without a connection.
        """
        if self._new_tx is not None:
            self._new_tx.commit()


class FunctionQueryResult:
    """A graph of function calls arising as the result of a Neo4J query."""

    def __init__(self, parent_db, program_name='', rest_output=None):
        """
        Build the call graph described by a query result.
        :param parent_db: database connection used to look up the calls made
          by each returned node, or None to read the calls from the returned
          nodes' own relationships (useful for unit testing)
        :param program_name: name of the program the result reflects
        :param rest_output: output of the REST query (a Neo4j QuerySequence
          whose rows are lists of nodes), or None for an empty result
        """
        self.program_name = program_name
        self._parent_db_connection = parent_db
        self.functions = self._rest_node_output_to_graph(rest_output)
        self._update_common_functions()

    def __eq__(self, other):
        """
        Results are equal when they share a program name and contain equal
        functions (matched up by name, compared with Function.__eq__).
        """
        if not isinstance(other, FunctionQueryResult):
            # Let Python fall back to its default comparison rather than
            # failing with AttributeError on other.functions.
            return NotImplemented
        # we make a dictionary so that we can perform easy comparison
        selfdict = {func.name: func for func in self.functions}
        otherdict = {func.name: func for func in other.functions}

        return (self.program_name == other.program_name
                and selfdict == otherdict)

    def _update_common_functions(self):
        """
        Loop over all functions: increment the called-by count of their callees.
        """
        for func in self.functions:
            for called in func.functions_i_call:
                called.number_calling_me += 1

    def _rest_node_output_to_graph(self, rest_output):
        """
        Convert the output of a REST API query into our internal representation.
        :param rest_output: output of the REST call as a Neo4j QuerySequence
        :return: list of <Function>s ready to initialise self.functions.
        """
        if rest_output is None or not rest_output.elements:
            return []

        # result maps 'functionname' to
        #   [the Function object we will use,
        #    the set of names of functions it calls,
        #    numeric ID of its node in the Neo4J database]
        result = self._entries_from_nodes(rest_output)

        if self._parent_db_connection is not None:
            # This is the normal case, of extracting results from a server.
            self._add_calls_from_server(result)
        else:
            # We don't have a parent database connection; this has probably
            # arisen from a test suite, so read the calls off the nodes.
            self._add_calls_from_nodes(rest_output, result)

        logging.debug('Relationships complete.')

        for function, calls, _node_id in result.values():
            # Only link to callees that are themselves part of this result.
            function.functions_i_call = [result[name][0]
                                         for name in calls
                                         if name in result]

        return [entry[0] for entry in result.values()]

    def _entries_from_nodes(self, rest_output):
        """
        Initial pass: create the name -> [Function, set(), node_id] mapping
        for every node in the query output.
        """
        # If the following check fails, we've probably called db.query in a
        # way that does not return client.Node objects, which is wrong.
        # We attempt to handle this below; this should never arise, but we
        # can cope with it happening in some cases, like the test suite.
        if type(rest_output.elements) is not list:
            logging.warning('Not a list: {}'.format(type(rest_output.elements)))

        result = {}
        for node_list in rest_output.elements:
            assert isinstance(node_list, list)
            for node in node_list:
                if isinstance(node, client.Node):
                    name = node.properties['name']
                    node_id = node.id
                    node_type = node.properties['type']
                else:
                    # We are a dictionary rather than a Node, as for some
                    # reason Raw rather than Node data came back.
                    # We should never reach this code, but just in case.
                    name = node['data']['name']
                    # hacky workaround: the id is the last component of the
                    # node's 'self' URL
                    node_id = node['self'].split('/')[-1]
                    node_type = node['data']['type']

                result[name] = [Function(self.program_name,
                                         function_name=name,
                                         function_type=node_type),
                                set(),
                                node_id]
        return result

    def _add_calls_from_server(self, result):
        """
        Fill in the callee-name sets of `result` by querying the server,
        batching one query per node into a single transaction for speed.
        """
        new_tx = self._parent_db_connection.transaction(using_globals=False,
                                                        for_query=True)
        for entry in result.values():
            q = ("START n=node({}) "
                 "MATCH n-[calls:calls]->(m) "
                 "RETURN n.name, m.name").format(entry[2])
            new_tx.append(q)

        logging.debug('exec')
        results = new_tx.execute()

        # results is a list of query results, each of those being a list of
        # (caller name, callee name) rows for one node.
        for call_list in results:
            if call_list:
                # Every row shares the caller's name in column 0.
                orig = call_list[0][0]
                result[orig][1] |= {row[1] for row in call_list.elements}

    @staticmethod
    def _add_calls_from_nodes(rest_output, result):
        """
        Fill in the callee-name sets of `result` from the outgoing
        relationships of the first node in each row of `rest_output`.
        """
        for node in rest_output.elements:
            node_name = node[0].properties['name']
            result[node_name][1] |= {
                relationship.end.properties['name']
                for relationship in node[0].relationships.outgoing()}

    def get_functions(self):
        """
        :return: a list of Function objects present in the query result
        """
        return self.functions

    def get_function(self, name):
        """
        Given a function name, returns the Function object which has that name.
        If no function with that name exists, returns None.
        """
        return next((func for func in self.functions if func.name == name),
                    None)


def set_common_cutoff(common_def):
    """
    Change the threshold above which a function counts as 'common'.

    A Function is flagged common when its number_calling_me exceeds this
    value. The module starts with a threshold of 10. The new value only
    affects flags computed after this call: existing Function objects keep
    their is_common flag until their number_calling_me is set again.
    :param common_def: the new threshold (number of incoming connections)
    """
    global COMMON_CUTOFF
    COMMON_CUTOFF = common_def


class Function(object):
    """Represents a function which might appear in a FunctionQueryResult."""

    def __init__(self, program_name='', function_name='', function_type=''):
        """
        :param program_name: name of the program this function belongs to
        :param function_name: name of the function
        :param function_type: the function's group as stored in the database
          (AddToDatabase uses "normal", "plt_stub" or "function_pointer")
        """
        self.parent_program = program_name
        self.attributes = []
        self.type = function_type
        # Function objects this one calls; filled in by FunctionQueryResult.
        self.functions_i_call = []
        self.name = function_name
        self.is_common = False
        self._number_calling_me = 0
        # care: _number_calling_me is not automatically updated, except by
        # any invocation of FunctionQueryResult._update_common_functions.

    def __repr__(self):
        return '{}(program_name={!r}, function_name={!r}, function_type={!r})'.format(
            type(self).__name__, self.parent_program, self.name, self.type)

    def __eq__(self, other):
        """
        Functions are equal when they share program, name and attributes, and
        call functions with the same set of names. The type is not compared.
        """
        if not isinstance(other, Function):
            # Fall back to Python's default comparison instead of raising
            # AttributeError on other.functions_i_call.
            return NotImplemented
        funcs_i_call_list = {func.name for func in self.functions_i_call}
        funcs_other_calls_list = {func.name for func in other.functions_i_call}

        return (self.parent_program == other.parent_program
                and self.name == other.name
                and funcs_i_call_list == funcs_other_calls_list
                and self.attributes == other.attributes)

    @property
    def number_calling_me(self):
        """Number of functions recorded as calling this one."""
        return self._number_calling_me

    @number_calling_me.setter
    def number_calling_me(self, value):
        # Recompute the 'common' flag against the cutoff in force right now.
        self._number_calling_me = value
        self.is_common = (self._number_calling_me > COMMON_CUTOFF)


class SextantConnection:
    """
    RESTful connection to a remote database.
    It can be used to create/delete/query programs.

    Queries are built by string formatting, so every program or function
    name is checked with Validator.validate before it is interpolated.
    """

    ProgramWithMetadata = namedtuple('ProgramWithMetadata',
                                     ['uploader', 'uploader_id',
                                      'program_name', 'date',
                                      'number_of_funcs'])

    def __init__(self, url):
        """
        :param url: URL of the Neo4J server, passed to GraphDatabase
        """
        self.url = url
        self._db = GraphDatabase(url)

    def new_program(self, name_of_program):
        """
        Request that the remote database create a new program with the given name.
        This procedure will create a new program remotely; you can manipulate
          that program using the returned AddToDatabase object.
        The name can appear in the database already, but this is not recommended
          because delete_program would then delete every program with that
          name. Check first using self.check_program_exists.
        The name must pass Validator.validate(); this is a measure to prevent
          Cypher injection attacks.
        :param name_of_program: string program name
        :return: AddToDatabase instance for the new program
        :raises ValueError: if name_of_program fails validation
        """

        if not Validator.validate(name_of_program):
            raise ValueError(
                "{} is not a valid program name".format(name_of_program))

        uploader = getpass.getuser()
        uploader_id = os.getuid()

        return AddToDatabase(sextant_connection=self,
                             program_name=name_of_program,
                             uploader=uploader, uploader_id=uploader_id,
                             date=str(datetime.now()))

    def delete_program(self, name_of_program):
        """
        Request that the remote database delete a specified program.
        Every program node with this name is deleted, together with the
        nodes attached to it and their relationships.
        :param name_of_program: a string which must pass Validator.validate
        :return: False if the name fails validation, True otherwise (even if
          no such program existed)
        """
        if not Validator.validate(name_of_program):
            return False

        q = """MATCH (n) WHERE n.name= "{}" AND n.type="program"
        OPTIONAL MATCH (n)-[r]-(b) OPTIONAL MATCH (b)-[rel]-()
        DELETE  b,rel DELETE n, r""".format(name_of_program)

        self._db.query(q)

        return True

    def _execute_query(self, prog_name='', query=''):
        """
        Executes a Cypher query against the remote database.
        Note that this returns a FunctionQueryResult, so is unsuitable for any
          other expected outputs (such as lists of names). For those instances,
          it is better to run self._db.query explicitly.
        Intended only to be used for non-updating queries
          (such as "get functions" rather than "create").
        :param prog_name: name of the program the result object will reflect
        :param query: verbatim query we wish the server to execute
        :return: a FunctionQueryResult corresponding to the server's output
        """
        rest_output = self._db.query(query, returns=client.Node)

        return FunctionQueryResult(parent_db=self._db,
                                   program_name=prog_name,
                                   rest_output=rest_output)

    def get_program_names(self):
        """
        Execute query to retrieve the names of all programs in the database.
        Any name returned can be used verbatim in any SextantConnection
          method which requires a program-name input.
        :return: a set of program-name strings.
        """
        q = """MATCH (n) WHERE n.type = "program" RETURN n.name"""
        program_names = self._db.query(q, returns=str).elements

        return {el[0] for el in program_names}

    def programs_with_metadata(self):
        """
        Returns a set of namedtuples which represent the current database.

        The namedtuples have .uploader, .uploader_id, .program_name, .date,
        .number_of_funcs. Programs with no functions are not listed, because
        the second MATCH requires at least one subject node.
        :return: set of namedtuples
        """
        q = ("MATCH (base) WHERE base.type = 'program' "
             "MATCH (base)-[:subject]->(n) "
             "RETURN base.uploader, base.uploader_id, base.name, base.date, "
             "count(n)")
        result = self._db.query(q)
        return {self.ProgramWithMetadata(*res) for res in result}

    def check_program_exists(self, program_name):
        """
        Execute query to check whether a program with the given name exists.
        Returns False if the program_name fails validation against Validator.
        :return: bool(the program exists in the database).
        """

        if not Validator.validate(program_name):
            return False

        q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
             "RETURN count(base)").format(program_name)

        result = self._db.query(q, returns=int)
        return result.elements[0][0] > 0

    def check_function_exists(self, program_name, function_name):
        """
        Execute query to check whether a function with the given name exists.
        We only check for functions which are children of a program with the
          given program_name.
        :param program_name: string name of the program within which to check
        :param function_name: string name of the function to check for existence
        :return: bool(names validate correctly, and function exists in program)
        """
        # function_name is interpolated into the query below, so it must be
        # validated too (previously only program_name was checked).
        if not Validator.validate(function_name):
            return False

        # This also validates program_name.
        if not self.check_program_exists(program_name):
            return False

        q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
             "MATCH (base)-[r:subject]->(m) WHERE m.name = '{}' "
             "RETURN count(m)").format(program_name, function_name)

        result = self._db.query(q, returns=int)
        return result.elements[0][0] > 0

    def get_function_names(self, program_name):
        """
        Execute query to retrieve a list of all functions in the program.
        Any of the output names can be used verbatim in any SextantConnection
          method which requires a function-name input.
        :param program_name: name of the program whose functions to retrieve
        :return: None if program_name doesn't exist in the remote database,
          a set of function-name strings otherwise.
        """

        if not self.check_program_exists(program_name):
            return None

        q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
             "MATCH (base)-[r:subject]->(m) "
             "RETURN  m.name").format(program_name)
        return {func[0] for func in self._db.query(q)}

    def get_all_functions_called(self, program_name, function_calling):
        """
        Execute query to find all functions called by a function (indirectly).
        If the given function is not present in the program, returns None;
          likewise if the program_name does not exist.
        :param program_name: a string name of the program we wish to query under
        :param function_calling: string name of a function whose children to find
        :return: FunctionQueryResult, maximal subgraph rooted at function_calling
        """

        # Also checks (and validates) the program itself.
        if not self.check_function_exists(program_name, function_calling):
            return None

        q = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
            MATCH (base)-[:subject]->(m) WHERE m.name='{}'
            MATCH (m)-[:calls*]->(n)
            RETURN distinct n, m""".format(program_name, function_calling)

        return self._execute_query(program_name, q)

    def get_all_functions_calling(self, program_name, function_called):
        """
        Execute query to find all functions which call a function (indirectly).
        If the given function is not present in the program, returns None;
          likewise if the program_name does not exist.
        :param program_name: a string name of the program we wish to query
        :param function_called: string name of a function whose parents to find
        :return: FunctionQueryResult, maximal connected subgraph with leaf function_called
        """

        # Also checks (and validates) the program itself.
        if not self.check_function_exists(program_name, function_called):
            return None

        # NOTE(review): `n.name <> program_name` drops any caller that happens
        # to share the program's name; program nodes only get :subject edges
        # in this module, so confirm whether this filter is still wanted.
        q = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
            MATCH (base)-[r:subject]->(m) WHERE m.name='{}'
            MATCH (n)-[:calls*]->(m) WHERE n.name <> '{}'
            RETURN distinct n , m"""
        q = q.format(program_name, function_called, program_name)

        return self._execute_query(program_name, q)

    def get_call_paths(self, program_name, function_calling, function_called):
        """
        Execute query to find all possible routes between two specific nodes.
        If the given functions are not present in the program, returns None;
          ditto if the program_name does not exist.
        :param program_name: string program name
        :param function_calling: string
        :param function_called: string
        :return: FunctionQueryResult, the union of all subgraphs reachable by
          adding a source at function_calling and a sink at function_called.
        """

        # Each check also checks (and validates) the program itself.
        if not self.check_function_exists(program_name, function_called):
            return None

        if not self.check_function_exists(program_name, function_calling):
            return None

        q = r"""MATCH (pr) WHERE pr.name = '{}' AND pr.type = 'program'
                MATCH p=(start {{name: "{}" }})-[:calls*]->(end {{name:"{}"}})
                  WHERE (pr)-[:subject]->(start)
                WITH DISTINCT nodes(p) AS result
                UNWIND result AS answer
                RETURN answer"""
        q = q.format(program_name, function_calling, function_called)

        return self._execute_query(program_name, q)

    def get_whole_program(self, program_name):
        """Execute query to find the entire program with a given name.
        If the program is not present in the remote database, returns None.
        :param: program_name: a string name of the program we wish to return.
        :return: a FunctionQueryResult consisting of the program graph.
        """

        if not self.check_program_exists(program_name):
            return None

        query = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
                MATCH (base)-[subject:subject]->(m)
                RETURN DISTINCT (m)""".format(program_name)

        return self._execute_query(program_name, query)

    def get_shortest_path_between_functions(self, program_name, func1, func2):
        """
        Execute query to get a single, shortest, path between two functions.
        Both endpoints are looked up within the named program only.
        :param program_name: string name of the program we wish to search under
        :param func1: the name of the originating function of our shortest path
        :param func2: the name of the function at which to terminate the path
        :return: FunctionQueryResult shortest path between func1 and func2,
          or None if the program or either function does not exist.
        """
        # Each check also checks (and validates) the program itself.
        if not self.check_function_exists(program_name, func1):
            return None

        if not self.check_function_exists(program_name, func2):
            return None

        # Scope both endpoints to this program: matching on name alone could
        # pick up same-named functions (or program nodes) from other programs.
        q = """MATCH (pr) WHERE pr.name = '{}' AND pr.type = 'program'
            MATCH (pr)-[:subject]->(func1) WHERE func1.name = "{}"
            MATCH (pr)-[:subject]->(func2) WHERE func2.name = "{}"
            MATCH p = shortestPath((func1)-[:calls*]->(func2))
            UNWIND nodes(p) AS ans
            RETURN ans""".format(program_name, func1, func2)

        return self._execute_query(program_name, q)
