#!/usr/bin/env python
# coding=utf-8
"""
Implements the SimpleDB interface for MySQLdb
"""

from __future__ import generators

import MySQLdb
import MySQLdb.cursors
import _mysql_exceptions
from simpledb import SimpleDB


class SimpleMySQLdb(SimpleDB):
    """
    SimpleDB implementation backed by a MySQLdb connection.

    Connection parameters live in self.authorization and are passed to
    MySQLdb.Connect. The connection is opened lazily by the first query and,
    by default, dropped and re-opened once if a query fails with
    OperationalError.

    Every query_* method takes the SQL string as its first positional
    argument; any further positional arguments are passed to
    cursor.execute as the parameter tuple (placeholder style is
    MySQLdb.paramstyle, exposed as self.paramstyle).
    """

    def __init__(self, **kwargs):
        super(SimpleMySQLdb, self).__init__(**kwargs)
        # NOTE(review): assumes SimpleDB.__init__ creates self.authorization
        # (a dict); it is not defined in this class.
        self.auth(**kwargs)
        self.paramstyle = MySQLdb.paramstyle

    def __del__(self):
        """
        Close the connection when the object is destroyed.
        """
        self._disconnect()

    def auth(self, **kwargs):
        """
        Set/update the connection parameters.

        The keyword arguments are merged into self.authorization, which is
        later passed verbatim to MySQLdb.Connect. Any open connection is
        dropped first so the next query connects with the new parameters.
        """
        self._disconnect()
        self.authorization.update(kwargs)

    def _connect(self):
        """
        Open a connection unless one is already held in self.dbh.
        """
        if not self.dbh:
            self.dbh = MySQLdb.Connect(**self.authorization)

    def _disconnect(self):
        """
        Close the connection as best we can and forget it.

        Errors from close() are deliberately ignored (there may be no
        connection, or it may already be dead); self.dbh is always reset
        to None afterwards.
        """
        try:
            self.dbh.close()
        except Exception:
            pass
        self.dbh = None

    def _query(self, *args, **kwargs):
        """
        Execute a query and return its cursor, reconnecting once if needed.

        args[0] is the SQL string; args[1:] is passed to cursor.execute as
        the parameter tuple.

        Keyword arguments:
            cursorclass -- MySQLdb cursor class to use
                           (default MySQLdb.cursors.Cursor)
            reconnect   -- if true (the default), an OperationalError drops
                           the connection and the query is retried once

        Raises the original OperationalError when reconnect is false or the
        single retry also fails.
        """
        cursorclass = kwargs.get('cursorclass', MySQLdb.cursors.Cursor)
        reconnect = kwargs.get('reconnect', True)

        try:
            self._connect()
            cursor = self.dbh.cursor(cursorclass)
            cursor.execute(args[0], args[1:])
            return cursor
        except _mysql_exceptions.OperationalError:
            if not reconnect:
                # Bare raise keeps the server's error code/message and the
                # traceback instead of raising a new, empty exception.
                raise
            # NOTE(review): every OperationalError (not only a lost
            # connection) triggers this retry, so the statement may run a
            # second time -- confirm that is acceptable for writes.
            self._disconnect()
            kwargs['reconnect'] = False
            return self._query(*args, **kwargs)

    def _iter_rows(self, cursorclass, args, kwargs):
        """
        Generator: run the query with `cursorclass` and yield its rows one
        at a time, closing the cursor after the last row has been read.

        cursor.close() only runs once the rows are exhausted; if the
        consumer stops iterating early, the cursor is not closed here.
        """
        kwargs['cursorclass'] = cursorclass
        cursor = self._query(*args, **kwargs)
        row = cursor.fetchone()
        while row is not None:
            yield row
            row = cursor.fetchone()
        cursor.close()

    def _fetch_single_row(self, cursorclass, args, kwargs):
        """
        Run the query with `cursorclass` and return its only row, or None if
        it produced no rows.

        Raises OverflowError if the query returns more than one row. The
        cursor is closed on both the success and the error path.
        """
        kwargs['cursorclass'] = cursorclass
        cursor = self._query(*args, **kwargs)
        try:
            result = cursor.fetchone()
            # A second row means the caller's "single row" assumption broke.
            if cursor.fetchone() is not None:
                raise OverflowError('Query returns more then one row')
        finally:
            cursor.close()
        return result

###############################################################################

    def query_dict(self, *args, **kwargs):
        """
        Return a generator function that yields each row as produced by
        MySQLdb.cursors.SSDictCursor (one dict per row).

        NOTE(review): this returns the generator *function*, not a
        generator -- callers must call the result to obtain the rows.
        Kept for compatibility; confirm it matches the SimpleDB contract.
        """
        def _query_dict_results():
            for row in self._iter_rows(MySQLdb.cursors.SSDictCursor,
                                       args, kwargs):
                yield row
        return _query_dict_results

    def query_dict_row(self, *args, **kwargs):
        """
        Return the query's single row via MySQLdb.cursors.DictCursor, or
        None if there is no row.

        Raises OverflowError if the query returns more than one row.
        """
        return self._fetch_single_row(MySQLdb.cursors.DictCursor,
                                      args, kwargs)

###############################################################################

    def query_tuple(self, *args, **kwargs):
        """
        Return a generator function that yields each row as produced by
        MySQLdb.cursors.SSCursor (one tuple per row).

        NOTE(review): returns the generator *function*; callers must call
        the result to obtain the rows.
        """
        def _query_tuple_results():
            for row in self._iter_rows(MySQLdb.cursors.SSCursor,
                                       args, kwargs):
                yield row
        return _query_tuple_results

    def query_tuple_row(self, *args, **kwargs):
        """
        Return the query's single row as produced by MySQLdb.cursors.Cursor,
        or None if there is no row.

        Raises OverflowError if the query returns more than one row.
        """
        return self._fetch_single_row(MySQLdb.cursors.Cursor, args, kwargs)

###############################################################################

    def query_list(self, *args, **kwargs):
        """
        Return a generator function that yields each row converted to a
        list.

        NOTE(review): returns the generator *function*; callers must call
        the result to obtain the rows.
        """
        def _query_list_results():
            for row in self._iter_rows(MySQLdb.cursors.SSCursor,
                                       args, kwargs):
                yield list(row)
        return _query_list_results

    def query_list_row(self, *args, **kwargs):
        """
        Return the query's single row converted to a list, or None if there
        is no row.

        Raises OverflowError if the query returns more than one row.
        """
        result = self._fetch_single_row(MySQLdb.cursors.Cursor, args, kwargs)
        if result is None:
            # No row: match the other *_row methods instead of failing on
            # list(None).
            return None
        return list(result)

###############################################################################

    def query_fields(self, *args, **kwargs):
        """
        Return a generator function that yields the first (and only) column
        of each row.

        The generator raises OverflowError if the result has more than one
        column. An empty result yields nothing.

        NOTE(review): returns the generator *function*; callers must call
        the result to obtain the values.
        """
        def _query_field_results():
            kwargs['cursorclass'] = MySQLdb.cursors.SSCursor
            cursor = self._query(*args, **kwargs)
            row = cursor.fetchone()
            # The column count is checked on the first row only; with no
            # rows there is nothing to inspect (previously len(None)).
            if row is not None and len(row) > 1:
                # Close before raising so the unread result set is not
                # left pending on the connection.
                cursor.close()
                raise OverflowError('Too many result fields')
            while row is not None:
                yield row[0]
                row = cursor.fetchone()
            cursor.close()
        return _query_field_results

    def query_field(self, *args, **kwargs):
        """
        Return the single value of a one-row, one-column query, or None if
        the query returns no row.

        Raises OverflowError if the query returns more than one row or more
        than one column.
        """
        result = self._fetch_single_row(MySQLdb.cursors.Cursor, args, kwargs)
        if result is None:
            # No row: return None like the other *_row methods instead of
            # failing on len(None).
            return None
        if len(result) > 1:
            raise OverflowError('Too many result fields')
        return result[0]

###############################################################################
