"""
Implements the core :class:`Graph <kenji.graph.Graph>` interface.
"""

from contextlib import closing, contextmanager
from functools import wraps
from sqlite3 import connect

from kenji.transaction import Transaction, AbortSignal
from kenji.query import Query, V
import kenji.sql as sql

__all__ = ['Graph']


class Graph(object):
    """
    The main class that encompasses storage and deletion of the
    relations between nodes, and gives access to querying.

    One SQLite connection is opened when the graph is constructed
    and is shared by every method, as well as by the query and
    transaction objects created from it.
    """

    def __init__(self, uri=':memory:', graphs=(), unsafe=False, mapping=None):
        """
        Create a new ``Graph`` instance.

        :param uri: The URI of the SQLite db
        :param graphs: Iterable of relation names to create. Each
            name is interpolated directly into the table and index
            statements, so it must come from a trusted source.
        :param unsafe: Whether to turn off ``PRAGMA synchronous``
        :param mapping: A dictionary of string to iterables which
            dictates the mapping that should be respected when we
            do relation searches.
        """
        self.uri = uri
        self.mapping = {} if mapping is None else mapping
        self.db = connect(uri, check_same_thread=False, isolation_level=None)

        try:
            if unsafe:
                self.db.execute('PRAGMA synchronous = OFF;')

            with closing(self.db.cursor()) as cursor:
                for graph in graphs:
                    # Identifiers cannot be bound as SQL parameters,
                    # hence the string formatting of the statements.
                    cursor.execute(sql.CREATE_TABLE % (graph))
                    for index in sql.INDEXES:
                        cursor.execute(index % (graph))
                self.db.commit()
        except Exception:
            # Do not leave a half-initialised connection open when
            # the schema setup fails.
            self.db.close()
            raise

    def store(self, edge):
        """
        Store a relation to the SQLite backend. The relation must
        already be specified during the creation of the graph.

        :param edge: An edge to store; its ``src``, ``relation``
            and ``dst`` attributes are written.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                *sql.write_relation(edge.src, edge.relation, edge.dst)
            )
            self.db.commit()

    def delete(self, edge):
        """
        Delete (possibly multiple) edges from the database. The
        edge passed in must at the very least have a specified
        (non None) relation.

        :param edge: The edge(s) to delete.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                *sql.delete_relation(edge.src, edge.relation, edge.dst)
            )
            self.db.commit()

    def close(self):
        """
        Close the SQLite backend.
        """
        self.db.close()

    def select(self, *args, **kwargs):
        """
        Returns a new :class:`Query <kenji.query.Query>` object
        that you can iterate or perform filters on. Arguments
        are forwarded to the ``__call__`` method of the query
        object.
        """
        return Query(self.db)(*args, **kwargs)

    def exists(self, edge):
        """
        Checks if an edge exists. The edge passed in must
        specify the source and destination node as well
        as the relation.

        :param edge: The edge to check.
        :returns: ``True`` if the lookup yields a row, ``False``
            otherwise.
        """
        src, dst = edge.src, edge.dst
        with closing(self.db.cursor()) as cursor:
            cursor.execute(*sql.get_one_relation(src, edge.relation, dst))
            return bool(cursor.fetchone())

    def relation_between(self, typeof, edge):
        """
        Query the relation between two different nodes,
        with the restriction on which tables to query
        depending on the type of the nodes as defined
        in the `mapping` argument to the constructor.

        :param typeof: The type restriction; it must be a key
            of the mapping, otherwise the lookup raises (a
            ``KeyError`` for a plain dictionary).
        :param edge: An edge to query the relation for. Only its
            ``src`` and ``dst`` are used and the object itself is
            left untouched.
        :returns: A list of the tables, in mapping order, in
            which an edge from ``src`` to ``dst`` exists.
        """
        # Probe with a private copy so the caller's edge is never
        # mutated while the relation is swapped per table.
        edge = V(src=edge.src, dst=edge.dst)
        relations = []
        for table in self.mapping[typeof]:
            edge.relation = table
            if self.exists(edge):
                relations.append(table)
        return relations

    @contextmanager
    def transaction(self):
        """
        Creates a new transaction object and returns
        it as a value available to the caller. Note
        that the Graph object is not affected in any
        way and all atomic calls must be made to the
        returned object.

        When the ``with`` body finishes normally the transaction
        is committed, provided its ``defined`` attribute is truthy.
        An ``AbortSignal`` raised inside the body (or by the commit
        itself) is swallowed and no commit takes place. Any other
        exception propagates to the caller without a commit.
        """
        transaction = Transaction(self.db)
        try:
            yield transaction
            if transaction.defined:
                transaction.commit()
        except AbortSignal:
            # Deliberate: aborting is the documented way to leave
            # the block early without committing.
            pass
