import os
import urllib
import Orange
import orange
from Orange.utils import deprecated_keywords, deprecated_members
from Orange.feature import Descriptor

def _parseURI(uri):
    """
    Split a connection URI into its components (lifted from sqlobject).

    The accepted shape is
    ``scheme://[user[:password]@]host[:port]/database[?name=value&...]``;
    ``scheme:/path`` and ``scheme:///path`` are accepted too and yield no host.

    :param uri: Connection string.
    :type uri: str

    :return: ``(schema, user, password, host, port, path, args)``. Missing
        parts are ``None``; ``port`` is an ``int``; ``path`` always starts
        with ``/`` (except for rewritten Windows drive paths); ``args`` maps
        query parameter names to their unquoted values.
    :rtype: tuple

    :raises ValueError: if the URI has no ``:`` separator, or the port is not
        an integer in the range 1-65535.
    :raises AssertionError: if the part after ``scheme:`` does not start
        with ``/``.
    """
    schema, rest = uri.split(':', 1)
    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
    if rest.startswith('/') and not rest.startswith('//'):
        # scheme:/path -- no host part
        host = None
        rest = rest[1:]
    elif rest.startswith('///'):
        # scheme:///path -- empty host part
        host = None
        rest = rest[3:]
    else:
        # scheme://host[/path]
        rest = rest[2:]
        if rest.find('/') == -1:
            host = rest
            rest = ''
        else:
            host, rest = rest.split('/', 1)

    # Optional "user[:password]@" prefix of the host part.
    if host and host.find('@') != -1:
        user, host = host.split('@', 1)
        if user.find(':') != -1:
            user, password = user.split(':', 1)
        else:
            password = None
    else:
        user = password = None

    # Optional ":port" suffix of the host part.
    if host and host.find(':') != -1:
        _host, port = host.split(':')
        try:
            port = int(port)
        except ValueError:
            raise ValueError("port must be integer, got '%s' instead" % port)
        if not (1 <= port <= 65535):
            raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
        host = _host
    else:
        port = None

    path = '/' + rest
    if os.name == 'nt':
        # On Windows a "C|/dir" style path is rewritten to "C:/dir".
        if (len(rest) > 1) and (rest[1] == '|'):
            path = "%s:%s" % (rest[0], rest[2:])

    args = {}
    if path.find('?') != -1:
        path, arglist = path.split('?', 1)
        # unquote() lives in urllib on Python 2 and in urllib.parse on Python 3.
        unquote = getattr(urllib, 'unquote', None)
        if unquote is None:
            from urllib.parse import unquote
        for single in arglist.split('&'):
            if not single:
                # Tolerate "?", a trailing "&" or "&&" instead of crashing.
                continue
            # A parameter without "=" is kept as a flag with an empty value.
            argname, _, argvalue = single.partition('=')
            args[argname] = unquote(argvalue)
    return schema, user, password, host, port, path, args

class __MySQLQuirkFix(object):
    """
    MySQL-specific settings: the column types used for each kind of feature
    and the statement issued on a cursor before writing or creating.
    """

    # Issued before every write/create. The SQL built elsewhere in this module
    # wraps identifiers in double quotes; presumably this mode is what makes
    # MySQL accept them -- see the MySQL sql_mode documentation.
    _ANSI_QUOTES_SQL = "SET sql_mode='ANSI_QUOTES';"

    def __init__(self, dbmod):
        """
        :param dbmod: The DB-API module used for the connection.
        """
        self.dbmod = dbmod
        # Maps a feature kind to the SQL column type used in CREATE TABLE.
        column_types = {}
        column_types[Descriptor.Continuous] = 'DOUBLE'
        column_types[Descriptor.Discrete] = 'VARCHAR(250)'
        column_types[Descriptor.String] = 'VARCHAR(250)'
        self.typeDict = column_types

    def beforeWrite(self, cursor):
        """Prepare ``cursor`` for INSERT statements."""
        cursor.execute(self._ANSI_QUOTES_SQL)

    def beforeCreate(self, cursor):
        """Prepare ``cursor`` for a CREATE TABLE statement."""
        cursor.execute(self._ANSI_QUOTES_SQL)

class __PostgresQuirkFix(object):
    """
    PostgreSQL settings: the column types used for each kind of feature.
    No statement needs to be issued before writing or creating.
    """

    def __init__(self, dbmod):
        """
        :param dbmod: The DB-API module used for the connection.
        """
        self.dbmod = dbmod
        # Maps a feature kind to the SQL column type used in CREATE TABLE.
        column_types = {}
        column_types[Descriptor.Continuous] = 'FLOAT'
        column_types[Descriptor.Discrete] = 'VARCHAR'
        column_types[Descriptor.String] = 'VARCHAR'
        self.typeDict = column_types

    def beforeWrite(self, cursor):
        """Nothing to prepare before INSERT statements."""
        return None

    def beforeCreate(self, cursor):
        """Nothing to prepare before a CREATE TABLE statement."""
        return None

class __ODBCQuirkFix(object):
    """
    ODBC settings: the column types used for each kind of feature.
    No statement needs to be issued before writing or creating.
    """

    def __init__(self, dbmod):
        """
        :param dbmod: The DB-API module used for the connection.
        """
        self.dbmod = dbmod
        # Maps a feature kind to the SQL column type used in CREATE TABLE.
        column_types = {}
        column_types[Descriptor.Continuous] = 'FLOAT'
        column_types[Descriptor.Discrete] = 'VARCHAR'
        column_types[Descriptor.String] = 'VARCHAR'
        self.typeDict = column_types

    def beforeWrite(self, cursor):
        """Nothing to prepare before INSERT statements."""
        return None

    def beforeCreate(self, cursor):
        """Nothing to prepare before a CREATE TABLE statement."""
        return None


def _connection(uri):
    """
    Open a database connection described by ``uri``.

    The scheme selects the driver module, which is imported on demand:
    ``postgres`` (psycopg2), ``mysql`` (MySQLdb), ``sqlite`` (sqlite3) or
    ``odbc`` (pyodbc).

    :param uri: Connection string
        (scheme://[user[:password]@]host[:port]/database[?parameters]).
    :type uri: str

    :return: ``(quirks, connection)`` where ``quirks`` is the matching quirk
        object (with its ``parameter`` attribute set to the driver's
        placeholder) and ``connection`` is the opened connection.
    :rtype: tuple

    :raises ValueError: if the scheme is not supported, or an ``odbc`` URI
        has neither a ``DSN`` nor a ``Driver`` parameter.
    """
    (schema, user, password, host, port, path, args) = _parseURI(uri)

    # Keyword names passed to the driver's connect() for each URI component.
    # NOTE(review): postgres uses the same names as mysql here ('passwd',
    # 'db'); confirm that psycopg2.connect() accepts them.
    argTrans = {
        'host': 'host',
        'port': 'port',
        'user': 'user',
        'password': 'passwd',
        'database': 'db',
    }
    if schema == 'odbc':
        argTrans["host"] = "server"
        argTrans["user"] = "uid"
        argTrans["password"] = "pwd"
        argTrans['database'] = 'database'

    dbArgDict = {}
    if user:
        dbArgDict[argTrans['user']] = user
    if password:
        dbArgDict[argTrans['password']] = password
    if host:
        dbArgDict[argTrans['host']] = host
    if port:
        dbArgDict[argTrans['port']] = port
    if path:
        # Drop the leading "/" that the URI parser puts in front of the path.
        dbArgDict[argTrans['database']] = path[1:]

    if schema == 'postgres':
        import psycopg2 as dbmod
        quirks = __PostgresQuirkFix(dbmod)
        quirks.parameter = "%s"
        return (quirks, dbmod.connect(**dbArgDict))
    elif schema == 'mysql':
        import MySQLdb as dbmod
        quirks = __MySQLQuirkFix(dbmod)
        quirks.parameter = "%s"
        return (quirks, dbmod.connect(**dbArgDict))
    elif schema == "sqlite":
        import sqlite3 as dbmod
        quirks = __PostgresQuirkFix(dbmod)
        quirks.parameter = "?"
        # "sqlite://file.db" puts the file name in the host part, while
        # "sqlite:///abs/path.db" leaves the host empty; use the path then
        # instead of handing None to connect().
        return (quirks, dbmod.connect(host or path))
    elif schema == "odbc":
        import pyodbc as dbmod
        quirks = __ODBCQuirkFix(dbmod)
        quirks.parameter = "?"
        if 'DSN' in args:
            connectionString = 'DSN=%s' % (args['DSN'])
        elif 'Driver' in args:
            connectionString = 'Driver=%s' % (args['Driver'])
        else:
            raise ValueError("ODBC url schema must have DSN or Driver parameter")
        # Every other query parameter is appended verbatim.
        for k in args:
            if k not in ['DSN', 'Driver']:
                connectionString += ';%s=%s' % (k, args[k])
        return (quirks, dbmod.connect(connectionString, **dbArgDict))

    # Previously an unknown scheme fell through and returned None, which
    # surfaced as an unrelated unpacking error in the caller.
    raise ValueError("Unsupported database scheme %r "
                     "(expected postgres, mysql, sqlite or odbc)" % schema)

class SQLReader(object):
    """
    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
    to fetch the data from the database into Orange.
    """
    @deprecated_keywords({"domainDepot":"domain_depot"})
    def __init__(self, addr = None, domain_depot = None):
        """
        :param addr: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
            When ``None``, no connection is opened; call :obj:`connect` later.
        :type addr: str

        :param domain_depot: Domain depot
        :type domain_depot: :class:`orange.DomainDepot`
        """
        if addr is not None:
            self.connect(addr)
        if domain_depot is not None:
            self.domainDepot = domain_depot
        else:
            self.domainDepot = orange.DomainDepot()
        self.exampleTable = None
        # True whenever the query has to be (re-)executed by update().
        self._dirty = True

    def connect(self, uri):
        """
        Connect to the database.

        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
        :type uri: str
        """
        self._dirty = True
        self.del_domain()
        (self.quirks, self.conn) = _connection(uri)

    def disconnect(self):
        """
        Disconnect from the database.

        Calls the connection's ``disconnect()`` when it has one and falls
        back to the DB-API ``close()`` otherwise. Does nothing when no
        connection was ever opened.
        """
        conn = getattr(self, "conn", None)
        if conn is None:
            return
        for method_name in ("disconnect", "close"):
            func = getattr(conn, method_name, None)
            if callable(func):
                func()
                break

    def get_class_name(self):
        """Run the pending query and return the name of the domain's class variable."""
        self.update()
        return self.domain.class_var.name

    def set_class_name(self, class_name):
        """Set the class column name and drop the cached domain."""
        self._className = class_name
        self.del_domain()

    def del_class_name(self):
        del self._className

    class_name = property(get_class_name, set_class_name, del_class_name, "Name of class variable.")

    def get_metas_name(self):
        """Run the pending query and return the values of the domain's ``get_metas()`` mapping."""
        self.update()
        return self.domain.get_metas().values()

    def set_metas_name(self, meta_names):
        """Set the meta column names and drop the cached domain."""
        self._metaNames = meta_names
        self.del_domain()

    def del_metas_name(self):
        del self._metaNames

    meta_names = property(get_metas_name, set_metas_name, del_metas_name, "Names of meta attributes.")

    def set_discrete_names(self, discrete_names):
        """Set the names of columns to be treated as discrete and drop the cached domain."""
        self._discreteNames = discrete_names
        self.del_domain()

    def get_discrete_names(self):
        """Run the pending query and return the configured discrete column names."""
        self.update()
        return self._discreteNames

    def del_discrete_names(self):
        del self._discreteNames

    discrete_names = property(get_discrete_names, set_discrete_names, del_discrete_names, "Names of discrete attributes.")

    def set_query(self, query, domain = None):
        """
        Store the query without executing it and mark the reader as dirty.

        :param query: SQL query to run on the next :obj:`update`.
        :type query: str

        :param domain: Domain to use for the result; when ``None`` the cached
            domain is dropped and rebuilt from the query result.
        """
        self._query = query
        self._dirty = True
        if domain is not None:
            self._domain = domain
        else:
            self.del_domain()

    def get_query(self):
        return self._query

    def del_query(self):
        del self._query

    query = property(get_query, set_query, del_query, "Query to be executed on the next execute().")

    def generateDomain(self):
        # Intentionally a no-op.
        pass

    def set_domain(self, domain):
        """Replace the domain and mark the reader as dirty."""
        self._domain = domain
        self._dirty = True

    def get_domain(self):
        """Return the domain, building it first when none is cached."""
        if not hasattr(self, '_domain') or self._domain is None:
            self._createDomain()
        return self._domain

    def del_domain(self):
        """Drop the cached domain, if any."""
        if hasattr(self, '_domain'):
            del self._domain

    domain = property(get_domain, set_domain, del_domain, "Orange domain.")

    def execute(self, query, domain = None):
        """
        Executes an sql query.

        :param query: SQL query.
        :type query: str

        :param domain: Optional domain to use for the result.
        """
        self.set_query(query, domain)
        self.update()

    def _createDomain(self):
        """
        Build the domain from the cursor description stored in ``self.desc``.

        Each column name gets a type prefix -- ``D#`` when it is listed in the
        discrete names, ``S#`` when its type code is missing or is one of
        ``unicode``, the driver's ``STRING`` or ``DATETIME``, ``C#``
        otherwise -- plus a leading ``c`` for the class column (the configured
        class name, or a column called ``class`` when none is configured) or
        ``m`` for a meta column. The prefixed names are handed to the domain
        depot's ``prepareDomain``. Does nothing when a domain is already set.
        """
        if hasattr(self, '_domain') and self._domain is not None:
            return
        if not hasattr(self, '_discreteNames'):
            self._discreteNames = []
        discreteNames = self._discreteNames
        if not hasattr(self, '_metaNames'):
            self._metaNames = []
        metaNames = self._metaNames
        className = getattr(self, '_className', None)

        attrNames = []
        for column in self.desc:
            name = column[0]
            typ = column[1]
            if name in discreteNames:
                attrName = 'D#' + name
            elif typ is None or typ in [unicode, self.quirks.dbmod.STRING, self.quirks.dbmod.DATETIME]:
                attrName = 'S#' + name
            else:
                attrName = 'C#' + name

            if name == className:
                attrName = "c" + attrName
            elif name in metaNames:
                attrName = "m" + attrName
            elif not className and name == 'class':
                attrName = "c" + attrName
            attrNames.append(attrName)
        # The third element of the result is not needed.
        (self._domain, self._metaIDs, _) = self.domainDepot.prepareDomain(attrNames)

    def update(self):
        """
        Execute a pending SQL query.

        Returns the cached table right away when nothing changed since the
        last run. Otherwise runs the query, builds the domain from the cursor
        description when needed and fills ``self.exampleTable`` row by row;
        every value is converted with ``str()`` before it is stored.

        On any error the domain is reset and the exception is re-raised; a
        failing ``execute`` additionally rolls the connection back.
        """
        if not self._dirty and hasattr(self, '_domain') and self._domain is not None:
            return self.exampleTable
        self.exampleTable = None
        try:
            curs = self.conn.cursor()
            try:
                try:
                    curs.execute(self.query)
                except Exception:
                    self.conn.rollback()
                    # Bare raise keeps the original traceback.
                    raise
                self.desc = curs.description
                self._createDomain()
                attrNames = [column[0] for column in self.desc]
                self.exampleTable = Orange.data.Table(self.domain)
                row = curs.fetchone()
                while row:
                    example = Orange.data.Instance(self.domain)
                    for name, raw in zip(attrNames, row):
                        val = str(raw)
                        var = example[name].variable
                        # The rest of this module compares ``var_type`` with
                        # Descriptor.Discrete; the original type() comparison
                        # is kept as well so neither reading is lost.
                        is_discrete = (type(var) == Descriptor.Discrete
                                       or var.var_type == Descriptor.Discrete)
                        if is_discrete and val not in var.values:
                            # Register values not seen before.
                            var.values.append(val)
                        example[name] = val
                    self.exampleTable.append(example)
                    row = curs.fetchone()
            finally:
                # Release the cursor whether or not reading succeeded.
                curs.close()
            self._dirty = False
        except Exception:
            self.domain = None
            raise

    def data(self):
        """
        Return :class:`Orange.data.Table` produced by the last executed query.

        Returns ``None`` when the table is missing or evaluates as false.
        """
        self.update()
        if self.exampleTable:
            return self.exampleTable
        return None

# Keep the old camelCase attribute names working. "metaNames" is the spelling
# that parallels "discreteNames" and "className"; the singular "metaName" is
# retained so existing users of that spelling are not broken.
SQLReader = deprecated_members({
    "discreteNames": "discrete_names",
    "metaName": "meta_names",
    "metaNames": "meta_names",
    "className": "class_name",
})(SQLReader)

class SQLWriter(object):
    """
    Establishes a connection with a database and provides the methods needed to create
    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
    into the database.
    """
    def __init__(self, uri = None):
        """
        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
            When ``None``, no connection is opened; call :obj:`connect` later.
        :type uri: str
        """
        if uri is not None:
            self.connect(uri)

    def connect(self, uri):
        """
        Connect to the database.

        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
        :type uri: str
        """
        (self.quirks, self.connection) = _connection(uri)

    def __attrVal2sql(self, d):
        """
        Convert a value of a data instance into a statement parameter.

        Continuous values are passed through unchanged, everything else is
        converted with ``str()``. The values are bound as parameters, so the
        driver does the quoting; wrapping strings in quotes here would store
        the quote characters as part of the data.
        """
        if d.var_type == Descriptor.Continuous:
            return d.value
        return str(d.value)

    def __attrName2sql(self, d):
        return d.name

    def __attrType2sql(self, d):
        """Return the SQL column type the current database uses for feature kind ``d``."""
        return self.quirks.typeDict[d]

    @deprecated_keywords({"renameDict":"rename_dict"})
    def write(self, table, instances, rename_dict = None):
        """
        Writes the data into the table.

        Attributes, meta attributes and the class variable (when present) are
        written. On failure the traceback is printed and the transaction is
        rolled back; the exception is not propagated.

        :param table: Table name.
        :type table: str

        :param instances: Data to be written into the database.
        :type instances: :class:`Orange.data.Table`

        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
        :type rename_dict: dict

        """
        l = [i.name for i in instances.domain.attributes]
        l += [i.name for i in instances.domain.get_metas().values()]
        if instances.domain.class_var:
            l.append(instances.domain.class_var.name)
        if rename_dict is None:
            rename_dict = {}
        colList = [rename_dict.get(str(i), str(i)) for i in l]
        try:
            cursor = self.connection.cursor()
            self.quirks.beforeWrite(cursor)
            # Columns and placeholders are the same for every row, so the
            # statement is built once.
            columns = ", ".join(['"%s"' % name for name in colList])
            placeholders = ", ".join([self.quirks.parameter] * len(colList))
            query = 'INSERT INTO "%s" (%s) VALUES (%s);' % (table, columns, placeholders)
            for d in instances:
                valList = [self.__attrVal2sql(d[name]) for name in l]
                cursor.execute(query, tuple(valList))
            cursor.close()
            self.connection.commit()
        except Exception:
            import traceback
            traceback.print_exc()
            self.connection.rollback()

    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
    def create(self, table, instances, rename_dict = None, type_dict = None):
        """
        Create the required SQL table, then write the data into it.

        On failure the traceback is printed and the transaction is rolled
        back; the exception is not propagated.

        :param table: Table name
        :type table: str

        :param instances: Data to be written into the database.
        :type instances: :class:`Orange.data.Table`

        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
        :type rename_dict: dict

        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
            The orange attribute "X" is written into a database column of type type_dict["X"].
        :type type_dict: dict

        """
        if rename_dict is None:
            rename_dict = {}
        if type_dict is None:
            type_dict = {}
        l = [(i.name, i.var_type) for i in instances.domain.attributes]
        l += [(i.name, i.var_type) for i in instances.domain.get_metas().values()]
        if instances.domain.class_var:
            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))

        colSList = []
        for (name, var_type) in l:
            key = str(name)
            # Look up the default type only when there is no override, so an
            # override also works for kinds missing from the type table.
            if key in type_dict:
                colType = type_dict[key]
            else:
                colType = self.__attrType2sql(var_type)
            colSList.append('"%s" %s' % (rename_dict.get(key, key), colType))
        try:
            cursor = self.connection.cursor()
            colStr = ", ".join(colSList)
            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
            self.quirks.beforeCreate(cursor)
            cursor.execute(query)
            self.write(table, instances, rename_dict)
            self.connection.commit()
        except Exception:
            # Report the failure the same way write() does instead of
            # discarding it silently.
            import traceback
            traceback.print_exc()
            self.connection.rollback()

    def disconnect(self):
        """
        Disconnect from the database.

        Calls the connection's ``disconnect()`` when it has one and falls
        back to the DB-API ``close()`` otherwise. Does nothing when no
        connection was ever opened.
        """
        # The connection is stored as ``self.connection`` by connect().
        conn = getattr(self, "connection", None)
        if conn is None:
            return
        for method_name in ("disconnect", "close"):
            func = getattr(conn, method_name, None)
            if callable(func):
                func()
                break

def loadSQL(filename, dontCheckStored = False, domain = None):
    """
    Load data by running the SQL query stored in a file.

    Lines starting with ``--orng`` carry directives of the form
    ``--orng <command> <python expression>``, where the command is one of
    ``uri`` (connection string), ``discrete`` (names of discrete columns),
    ``meta`` (names of meta columns) or ``class`` (name of the class column).
    All other lines, including unrecognised or incomplete directives, make up
    the query.

    :param filename: Path of the query file.
    :type filename: str

    :param dontCheckStored: Not used by this function.
    :param domain: Not used by this function.
        NOTE(review): possibly meant to be passed on to the reader -- confirm.

    :return: The table produced by the query, as returned by
        :obj:`SQLReader.data`.
    """
    with open(filename) as f:
        lines = f.readlines()
    queryLines = []
    discreteNames = None
    uri = None
    metaNames = None
    className = None
    for i in lines:
        command = None
        if i.startswith("--orng"):
            parts = i.split(None, 2)
            # A directive without a value is kept as an ordinary query line
            # (it is an SQL comment) instead of failing on unpacking.
            if len(parts) == 3:
                command, line = parts[1], parts[2]
        # SECURITY: directive values are evaluated with eval(), so the query
        # file can run arbitrary code. Only load files from trusted sources;
        # ast.literal_eval would be the safe choice if only literals are needed.
        if command == 'uri':
            uri = eval(line)
        elif command == 'discrete':
            discreteNames = eval(line)
        elif command == 'meta':
            metaNames = eval(line)
        elif command == 'class':
            className = eval(line)
        else:
            queryLines.append(i)
    query = "\n".join(queryLines)
    r = SQLReader(uri)
    # Use the reader's current property names; the old "metaNames" spelling is
    # not among its deprecated aliases and was silently ignored.
    if discreteNames:
        r.discrete_names = discreteNames
    if className:
        r.class_name = className
    if metaNames:
        r.meta_names = metaNames
    r.execute(query)
    return r.data()
