
import re

from sqlalchemy import sql, exc, util
from sqlalchemy.engine import default, reflection, ResultProxy
from sqlalchemy.sql import compiler, expression
from sqlalchemy import types as sqltypes
from pynuodb import datatype
from sqlalchemy.ext.compiler import compiles
import collections
from sqlalchemy import processors

# Reserved words handed to NuoDBIdentifierPreparer; identifiers found in this
# set are quoted by SQLAlchemy when rendered.
# NOTE(review): this list appears to mirror SQLAlchemy's PostgreSQL reserved
# words; confirm it against NuoDB's own reserved-word list.
RESERVED_WORDS = set("""
    all analyse analyze and any array as asc
    asymmetric both case cast check collate column
    constraint create current_catalog current_date
    current_role current_time current_timestamp current_user
    default deferrable desc distinct do else end
    except false fetch for foreign from grant group
    having in initially intersect into leading limit
    localtime localtimestamp new not null off offset
    old on only or order placing primary references
    select session_user some symmetric table
    then to trailing true union unique user using
    variadic when where window with authorization
    between binary cross current_schema freeze full
    ilike inner is isnull join left like natural
    notnull outer over overlaps right similar verbose
""".split())

class NestedResult(sqltypes.TypeEngine):
    """Result type assigned to ``nested`` sub-select expressions."""

    def nuodb_result_processor(self, gen_nested_context):
        """Return a processor wrapping ``gen_nested_context(value)`` in a ResultProxy.

        :param gen_nested_context: callable turning a raw column value into an
            execution context suitable for ``ResultProxy``.
        """
        def wrap_in_result_proxy(value):
            nested_context = gen_nested_context(value)
            return ResultProxy(nested_context)
        return wrap_in_result_proxy

class nested(expression.ScalarSelect):
    """Scalar sub-select compiled under the ``nuodb_nested`` visit name.

    ``stmt`` may be a ScalarSelect (its inner element is reused), any
    SelectBase (used directly), or anything else, which is wrapped as
    ``select(util.to_list(stmt))``.  The expression's type is NestedResult.

    NOTE(review): NuoDBCompiler in this module defines no
    ``visit_nuodb_nested``; confirm how this construct gets compiled.
    """
    __visit_name__ = 'nuodb_nested'

    def __init__(self, stmt):
        if not isinstance(stmt, expression.ScalarSelect):
            if not isinstance(stmt, expression.SelectBase):
                stmt = expression.select(util.to_list(stmt))
        else:
            stmt = stmt.element

        super(nested, self).__init__(stmt)
        self.type = NestedResult()

# ``colspecs`` (the dialect's type-adaptation map) is defined once, just
# before the NuoDBDialect class.


class NuoInteger(sqltypes.Integer):
    """NuoDB INTEGER column type."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        col_spec = "INTEGER"
        return col_spec
 
class NuoNumeric(sqltypes.Numeric):
    """The FIXED (also NUMERIC, DECIMAL) data type.

    ``asdecimal`` defaults to True unless the caller passes it explicitly,
    and no client-side bind conversion is applied.
    """

    def __init__(self, precision=None, scale=None, **kw):
        if 'asdecimal' not in kw:
            kw['asdecimal'] = True
        super(NuoNumeric, self).__init__(precision=precision, scale=scale,
                                         **kw)

    def bind_processor(self, dialect):
        # Bound values are handed to the driver unchanged.
        return None
 
class NuoFloat(sqltypes.Float):
    """NuoDB FLOAT column type; bound values are sent as strings."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        col_spec = "FLOAT"
        return col_spec

    def bind_processor(self, dialect):
        """By converting to string, we can use Decimal types round-trip."""
        stringify = processors.to_str
        return stringify
 
class NuoDate(sqltypes.Date):
    """NuoDB DATE column type."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        col_spec = "DATE"
        return col_spec
 
class NuoTime(sqltypes.Time):
    """NuoDB column type for ``sqltypes.Time`` values."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        # NOTE(review): a Time type reporting DATETIME (same as NuoTimestamp)
        # looks like a copy-paste; confirm whether TIME was intended.
        col_spec = "DATETIME"
        return col_spec
 
class NuoString(sqltypes.String):
    """NuoDB STRING column type."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        col_spec = "STRING"
        return col_spec
 
class NuoBlob(sqltypes.LargeBinary):
    """NuoDB BLOB column type."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        col_spec = "BLOB"
        return col_spec
 
class NuoTimestamp(sqltypes.DateTime):
    """NuoDB DATETIME column type."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        col_spec = "DATETIME"
        return col_spec

# NuoDB SQLAlchemy Compilers

class NuoDBCompiler(compiler.SQLCompiler):
    """Statement compiler for NuoDB; all behavior is inherited from SQLCompiler."""

class NuoDBDDLCompiler(compiler.DDLCompiler):
    """DDL compiler that renders NuoDB column definitions."""

    def get_column_specification(self, column, **kwargs):
        """Build the column clause used inside CREATE TABLE.

        Layout: ``<name> <type> [NULL | NOT NULL] [identity | DEFAULT <expr>]``.
        The table's autoincrement column is rendered as
        ``GENERATED BY DEFAULT AS IDENTITY`` and never receives a DEFAULT.
        """
        pieces = [
            self.preparer.format_column(column),
            self.dialect.type_compiler.process(column.type),
        ]

        # nullable=None means "unspecified": emit neither NULL nor NOT NULL.
        if column.nullable is not None:
            pieces.append("NULL" if column.nullable else "NOT NULL")

        if column is not column.table._autoincrement_column:
            default_sql = self.get_column_default_string(column)
            if default_sql is not None:
                pieces.append("DEFAULT " + default_sql)
        else:
            # TODO: can do start with/increment by here
            pieces.append("GENERATED BY DEFAULT AS IDENTITY")

        return " ".join(pieces)

class NuoDBTypeCompiler(compiler.GenericTypeCompiler):
    """Render SQLAlchemy types as NuoDB DDL type names."""

    def visit_text(self, type_):
        return "TEXT"

    def visit_char(self, type_):
        return "CHAR"

    def visit_date(self, type_):
        return "DATE"

    def visit_datetime(self, type_):
        return "DATETIME"

    def visit_string(self, type_):
        # Length is intentionally not rendered.
        return "STRING"

    def visit_large_binary(self, type_):
        return "BLOB"

    def visit_numeric(self, type_):
        """Render ``NUMERIC(p, s)`` when both are given, otherwise ``NUMBER``.

        Precision and scale are tested with ``is not None`` so that an
        explicit scale of 0 (e.g. ``Numeric(10, 0)``) is honored instead of
        collapsing to the unbounded ``NUMBER`` type.
        """
        if type_.precision is not None and type_.scale is not None:
            return 'NUMERIC(%s, %s)' % (type_.precision, type_.scale)
        return 'NUMBER'

    def visit_BOOLEAN(self, type_):
        return "BOOLEAN"

class NuoDBIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer using the module-level RESERVED_WORDS set."""

    reserved_words = RESERVED_WORDS

class NuoDBInspector(reflection.Inspector):
    """Schema inspector for NuoDB; all behavior is inherited from Inspector."""

class NuoDBExecutionContext(default.DefaultExecutionContext):
    """Execution context for statements run through the NuoDB dialect."""

    def _translate_colname(self, colname):
        # Result column names are upper-cased; no alternate name is reported.
        translated = colname.upper()
        return translated, None

    def fire_sequence(self, seq, type_):
        """Fetch the next value of ``seq`` via ``nextval('<schema>', '<name>')``.

        The sequence's own schema is used when set, otherwise the dialect's
        default schema name.
        """
        schema_name = seq.schema or self.dialect.default_schema_name
        sequence_name = self.dialect.identifier_preparer.format_sequence(seq)
        statement = "select nextval('%s', '%s')" % (schema_name, sequence_name)
        return self._execute_scalar(statement, type_)

# Type-adaptation map installed on NuoDBDialect.colspecs.
# NOTE(review): SQLAlchemy presumably looks colspecs up by generic type
# *class* (walking the type's MRO), so these string keys may never match;
# confirm before relying on this mapping.
colspecs = dict([
    ('char', NuoString),
    ('character', NuoString),
    ('date', NuoDate),
    ('fixed', NuoNumeric),
    ('float', NuoFloat),
    ('int', NuoInteger),
    ('integer', NuoInteger),
    ('long binary', NuoBlob),
    ('long unicode', NuoString),
    ('long', NuoString),
    ('smallint', NuoInteger),
    ('time', NuoTime),
    ('timestamp', NuoTimestamp),
    ('varchar', NuoString),
    ('String', NuoString),
])

class NuoDBDialect(default.DefaultDialect):
    """SQLAlchemy dialect for NuoDB using the ``pynuodb`` DB-API driver."""

    name                                = 'nuodb'
    supports_alter                      = True
    max_identifier_length               = 63
    supports_sane_rowcount              = True

    supports_native_enum                = True
    supports_native_boolean             = True

    supports_sequences                  = True
    sequences_optional                  = False
    preexecute_autoincrement_sequences  = False
    postfetch_lastrowid                 = False

    requires_name_normalize             = True
    supports_default_values             = True
    supports_empty_insert               = False
    default_paramstyle                  = 'pyformat'
    colspecs                            = colspecs

    statement_compiler                  = NuoDBCompiler
    ddl_compiler                        = NuoDBDDLCompiler
    type_compiler                       = NuoDBTypeCompiler
    preparer                            = NuoDBIdentifierPreparer
    execution_ctx_cls                   = NuoDBExecutionContext
    inspector                           = NuoDBInspector
    isolation_level                     = None
    supports_unicode_statements         = False
    supports_native_decimal             = False

    use_native_unicode                  = False
    driver                              = 'pynuodb'

    # TODO: need to inspect "standard_conforming_strings"
    _backslash_escapes                  = False

    # SYSTEM.FIELDS.DATATYPE code -> column type class used by get_columns().
    # Codes not listed here are reflected as sqltypes.NULLTYPE.
    _field_types = {
        1: NuoString, 2: NuoString, 3: NuoString,
        4: NuoInteger, 5: NuoInteger, 6: NuoInteger,
        7: NuoNumeric, 8: NuoNumeric,
        10: NuoTimestamp, 15: NuoTimestamp,
        13: NuoBlob,
    }

    def __init__(self, **kwargs):
        default.DefaultDialect.__init__(self, **kwargs)

    def initialize(self, connection):
        super(NuoDBDialect, self).initialize(connection)

    def on_connect(self):
        # No per-connection setup callable is needed.
        return None

    def _check_unicode_returns(self, connection):
        # Skip SQLAlchemy's unicode-returns probe and always report False.
        return False

    @classmethod
    def dbapi(cls):
        """Return the ``pynuodb`` module as the DB-API implementation.

        Side effect: binds ``pynuodb.Connection`` to the module-level global
        ``Connection``.
        """
        global Connection
        from pynuodb import Connection
        return __import__("pynuodb")

    def connect(self, *cargs, **cparams):
        """Open a pynuodb connection, adapting SQLAlchemy's connect arguments.

        - ``username`` is renamed to ``user``.
        - ``schema`` is moved into the ``options`` dict; when absent, the
          option defaults to ``'user'`` unless the caller already supplied an
          ``options['schema']``.
        - ``port`` is appended to ``host`` as ``host:port`` unless ``host``
          already contains a colon; ``port`` itself is never passed on.
        Caller-supplied ``options`` entries are preserved.
        """
        options = dict(cparams.pop('options', None) or {})
        if 'username' in cparams:
            cparams['user'] = cparams.pop('username')
        if 'schema' in cparams:
            options['schema'] = cparams.pop('schema')
        else:
            options.setdefault('schema', 'user')
        if 'port' in cparams:
            port = cparams.pop('port')
            host = cparams.get('host')
            # Only attach the port when the host does not already carry one.
            if host and ':' not in host:
                cparams['host'] = '%s:%s' % (host, port)
        cparams['options'] = options
        return self.dbapi.connect(*cargs, **cparams)

    def _query_has_rows(self, connection, query, **params):
        """Execute text ``query`` with bound ``params``; True if a row comes back."""
        result = connection.execute(sql.text(query), **params)
        return result.first() is not None

    def _query_column(self, connection, query, **params):
        """Execute text ``query`` with bound ``params``; return column 0 as a list."""
        result = connection.execute(sql.text(query), **params)
        return [row[0] for row in result.fetchall()]

    def _get_default_schema_name(self, connection):
        # NOTE(review): this returns the current *user*, while connect()
        # defaults the session schema option to 'user'; confirm these agree.
        return connection.scalar("SELECT current_user FROM dual")

    def has_schema(self, connection, schema):
        """Return True if ``schema`` is listed in SYSTEM.SCHEMAS."""
        return self._query_has_rows(
            connection,
            "SELECT schema FROM system.schemas WHERE schema = :schema",
            schema=schema)

    def has_table(self, connection, table_name, schema=None):
        """Return True if ``table_name`` exists (in ``schema`` when given).

        :raises NotImplementedError: if ``schema`` is given but does not exist.
        """
        # NOTE(review): table_name is compared as given (not denormalized);
        # confirm how the catalog stores identifier case.
        schema = self._verify_schema(connection, schema)
        query = "SELECT tablename FROM system.tables WHERE tablename = :table_name"
        params = {'table_name': table_name}
        if schema:
            query += " AND schema = :schema"
            params['schema'] = schema
        return self._query_has_rows(connection, query, **params)

    def has_sequence(self, connection, sequence_name, schema=None):
        """Return True if ``sequence_name`` exists (in ``schema`` when given).

        :raises NotImplementedError: if ``schema`` is given but does not exist.
        """
        schema = self._verify_schema(connection, schema)
        query = ("SELECT sequencename FROM system.sequences"
                 " WHERE sequencename = :sequence_name")
        params = {'sequence_name': sequence_name}
        if schema:
            query += " AND schema = :schema"
            params['schema'] = schema
        return self._query_has_rows(connection, query, **params)

    def normalize_name(self, name):
        """Convert a name read from NuoDB to SQLAlchemy's convention.

        All-uppercase names are lower-cased; any other name is returned
        unchanged.  Byte strings are decoded with ``self.encoding`` first.
        """
        if name is None:
            return None
        if isinstance(name, bytes):
            name = name.decode(self.encoding)
        if name.upper() == name:
            return name.lower()
        return name

    def denormalize_name(self, name):
        """Convert a SQLAlchemy name back to NuoDB's convention.

        All-lowercase names are upper-cased; the result is encoded with
        ``self.encoding`` when the dialect does not support unicode binds,
        otherwise coerced to ``unicode`` (Python 2 only).
        """
        if name is None:
            return None
        elif name.lower() == name:
            name = name.upper()
        # Py2K
        if not self.supports_unicode_binds:
            name = name.encode(self.encoding)
        else:
            name = unicode(name)
        # end Py2K
        return name

    # effective version
    def _get_server_version_info(self, connection):
        # NOTE(review): SQLAlchemy conventionally expects a tuple here; this
        # returns the raw scalar from geteffectiveplatformversion().
        return connection.scalar("SELECT geteffectiveplatformversion() FROM dual")

    @reflection.cache
    def get_schema_names(self, connection, **kw):
        """Return every schema name listed in SYSTEM.SCHEMAS."""
        return self._query_column(connection, "SELECT schema FROM system.schemas")

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """Return table names in ``schema`` (default schema when None).

        The schema column matches the one used by has_table()/get_columns().
        """
        schema = self._verify_schema(connection, schema) or self.default_schema_name
        return self._query_column(
            connection,
            "SELECT tablename FROM system.tables WHERE schema = :schema",
            schema=schema)

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        """Return view names in ``schema`` (default schema when None)."""
        schema = self._verify_schema(connection, schema) or self.default_schema_name
        # NOTE(review): SYSTEM.VIEW_TABLES is filtered on ``table_schema``;
        # confirm this column name against the NuoDB catalog.
        return self._query_column(
            connection,
            "SELECT viewname FROM system.view_tables WHERE table_schema = :schema",
            schema=schema)

    @reflection.cache
    def get_view_definition(self, connection, view_name, schema=None, **kw):
        """Not implemented: validates ``schema`` and returns None."""
        self._verify_schema(connection, schema)
        return None

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect columns of ``table_name`` from SYSTEM.FIELDS.

        Each entry is a dict with ``name`` (normalized), ``type`` (instance
        chosen from ``_field_types``; NULLTYPE for unknown codes) and
        ``default`` (the raw DEFAULTVALUE).

        :raises NotImplementedError: if ``schema`` is given but does not exist.
        """
        schema = self._verify_schema(connection, schema)
        query = ("SELECT field, datatype, defaultvalue FROM system.fields"
                 " WHERE tablename = :table_name")
        params = {'table_name': table_name}
        if schema:
            query += " AND schema = :schema"
            params['schema'] = schema

        columns = []
        for row in connection.execute(sql.text(query), **params).fetchall():
            field_name, type_code, default_value = row[0], row[1], row[2]
            type_cls = self._field_types.get(type_code)
            columns.append({
                'name': self.normalize_name(field_name),
                'type': type_cls() if type_cls is not None else sqltypes.NULLTYPE,
                'default': default_value,
            })
        return columns

    def _verify_schema(self, connection, schema):
        """Return ``schema`` unchanged after checking that it exists.

        A falsy ``schema`` is returned as-is without any lookup.

        :raises NotImplementedError: if ``schema`` is given but not found
            (exception type kept for compatibility with existing callers).
        """
        if schema and not self.has_schema(connection, schema):
            raise NotImplementedError("schema [%s] does not exist" % schema)
        return schema

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Primary-key reflection is not implemented."""
        raise NotImplementedError()

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Foreign-key reflection is not implemented."""
        raise NotImplementedError()

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Index reflection is not implemented."""
        raise NotImplementedError()
