"""Parsers and writers that convert DataFrames to and from various file formats (CSV, ARFF, Excel, SQLite).

"""
import csv
import gzip
import zlib
import itertools
import math
import os
import numpy
import sqlite3

from dataframe import DataFrame
from factors import Factor

# Let sqlite3 bind numpy scalars as query parameters: numpy ints and float64
# are passed on as plain Python numbers, and any value whose exact type is
# numpy.ma.core.MaskedArray is stored as NaN.
for _numpy_type, _adapter in ((numpy.int32, int),
                              (numpy.int64, int),
                              (numpy.float64, float),
                              (numpy.ma.core.MaskedArray, lambda x: numpy.nan)):
    sqlite3.register_adapter(_numpy_type, _adapter)
del _numpy_type, _adapter

class ParserError(ValueError):
    """Error raised by the converters in this module.

    Subclasses ValueError so existing ``except ValueError`` handlers still
    catch it.
    """


def _intify(na_text):
    """Return a function that converts x to int or nan if x == naText"""
    def convert(value):
        """inner converter"""
        if value == na_text:
            return numpy.nan
        else:
            return int(value)
    return convert


def _floatify(na_text):
    """Return a function that converts x to float or nan if x == naText"""
    def convert(value):
        """inner converter"""
        if value == na_text:
            return numpy.nan
        else:
            return float(value)
    return convert


def _unquote(quote_char, row):
    """Remove quote characters from all value in row if present."""
    for i in xrange(0, len(row)):
        if row[i] and row[i][0] == quote_char and row[i][-1] == quote_char:
            row[i] = row[i][1:-1]


def _unquote_single_value(quote_char, value):
    """Remove quote characters at the beginning and end if present."""
    if value[0] == quote_char and value[-1] == quote_char:
        return value[1:-1]
    else:
        return value


def _open_file(filename_or_file_object, mode = 'rb'):
    """Smartly open a file.

        Takes either an open file (returning it),
        a filename without .gz (opens it)
        or a filename with .gz (opens it as gzip file).

    """
    if not hasattr(filename_or_file_object, 'readlines'):
        # inspired by R's read.csv
        if filename_or_file_object.endswith('.gz'):
            filehandle = gzip.GzipFile(filename_or_file_object, mode)
        else:
            filehandle = open(filename_or_file_object, mode)
    else:
        filehandle = filename_or_file_object
    return filehandle


class DF2CSV:
    """A converter between comma-separated-value files and DataFrames."""

    def read(self, filename_or_file_object, header=True, dialect=None, skip_rows = 0, columns_to_include = None, columns_to_execlude = None, na_text = "NA", type_hints = None, handle_quotes = False, has_row_names = False, r_skipped_first_column_name = False, comment_character = None):
        """Read a csv file and turn it into a DataFrame.

        filename_or_file_object = a path (names ending in '.gz' are opened as
            gzip) or an open file object. The handle is closed afterwards.
        header = True to take the column names from the first non-comment
            line, a list of names to use instead of that line, or None/False
            if the file has no header line (names become column0, column1...).
        dialect = a dialect (object or registered name) valid for csv.reader;
            defaults to 'excel'.
        skip_rows = skip n rows at the start of the file
        columns_to_include = column names if you want just a few
        columns_to_execlude = column names if you want to exclude some
        na_text = entries matching will be converted to NaN
        type_hints = {column_name: numpy.dtype} if you want to enforce
            some types a priori. The dict passed in is not modified.
        handle_quotes = if your csv is quoted (outside of the header);
            if it isn't it can be read faster.
        has_row_names = True to use the first kept column as row names, or
            the name of the column to use.
        r_skipped_first_column_name = if you export from R with
            write.table(row.names=TRUE), R skips the first column name.
            This handles this by naming that column "Row".
        comment_character = lines starting with this are ignored.

        Raises ValueError if the header does not match the number of columns
        or a selected column has no data.
        """
        options = (header, dialect, skip_rows, columns_to_include, columns_to_execlude,
                   na_text, type_hints, handle_quotes, has_row_names,
                   r_skipped_first_column_name, comment_character)
        try:
            return self.__encapsulated_read(_open_file(filename_or_file_object, 'rU'), *options)
        except zlib.error:
            if hasattr(filename_or_file_object, 'readlines'):
                # A caller-supplied handle has already been consumed (and
                # closed); it cannot be re-read from the start.
                raise
            # Retry once, reopening the path with plain 'r' mode.
            return self.__encapsulated_read(_open_file(filename_or_file_object, 'r'), *options)

    def __encapsulated_read(self, filehandle, header=True, dialect=None, skip_rows = 0, columns_to_include = None, columns_to_execlude = None, na_text = "NA", type_hints = None, handle_quotes = False, has_row_names = False, r_skipped_first_column_name = False, comment_character = None):
        """Parse an already opened file into a DataFrame (see read()).

        The handle is always closed, even if parsing fails.
        """
        if dialect is None:
            dialect = 'excel'
        if isinstance(dialect, str):
            dialect = csv.get_dialect(dialect)
        # Captured up front: the error message below needs it, and not every
        # stream has a name.
        file_name = getattr(filehandle, 'name', '<stream>')
        try:
            first_line, columns = self._read_raw(
                filehandle, header, dialect, skip_rows, na_text, handle_quotes,
                r_skipped_first_column_name, comment_character)
        finally:
            filehandle.close()

        fields = self._field_names(header, first_line, dialect, r_skipped_first_column_name)
        type_hints = self._convert_columns(columns, fields, na_text, type_hints)

        results = {}
        fields_used = []
        for i, field in enumerate(fields):
            if columns_to_include is not None and field not in columns_to_include:
                continue
            if columns_to_execlude is not None and field in columns_to_execlude:
                continue
            if i >= len(columns):
                raise ValueError("Could not find column for field %s, file: %s" % (field, file_name))
            fields_used.append(field)
            column = columns[i]
            if field in type_hints:
                try:
                    results[field] = numpy.array(column, type_hints[field])
                except ValueError:
                    # int columns that contained na_text hold NaN, which int32
                    # cannot represent - fall back to float64.
                    if type_hints[field] is numpy.int32:
                        results[field] = numpy.array(column, numpy.float64)
                    else:
                        raise
            elif isinstance(column, numpy.ma.core.MaskedArray):
                results[field] = column
            else:
                results[field] = list(column)

        row_names = None
        if has_row_names:
            if isinstance(has_row_names, str):
                row_name_field = has_row_names
            else:
                row_name_field = fields_used[0]
            row_names = results.pop(row_name_field)
            fields_used.remove(row_name_field)
        return DataFrame( results, columns_ordered=fields_used, row_names_ordered = row_names )

    def _read_raw(self, filehandle, header, dialect, skip_rows, na_text, handle_quotes, r_skipped_first_column_name, comment_character):
        """Return (first_line, columns) read from filehandle.

        first_line is the first non-comment line split into values (the
        header, or the first data row when header is None/False). columns
        holds the raw values of the data rows, column by column, with short
        rows padded with na_text.
        """
        if skip_rows:
            for _ in range(skip_rows):
                filehandle.readline()
        if handle_quotes:
            # newline problems sometimes happen here - de-Macify...
            split_lines = list(csv.reader(filehandle, dialect=dialect))
            if comment_character:
                # Drop leading comment rows; an empty row ends the scan.
                start = 0
                while (start < len(split_lines) and split_lines[start]
                       and split_lines[start][0].startswith(comment_character)):
                    start += 1
                split_lines = split_lines[start:]
            first_line = split_lines[0] if split_lines else []
            if split_lines and not (header is None or header is False):
                split_lines.pop(0)
            columns = list(itertools.izip_longest(*split_lines, fillvalue=na_text))
        else:
            current_pos = filehandle.tell()
            first_line = filehandle.readline().strip().split(dialect.delimiter)
            if comment_character:
                while first_line[0].startswith(comment_character):
                    first_line = filehandle.readline().strip().split(dialect.delimiter)
            if header is None or header is False:
                # No header line: rewind so the first line is read as data.
                filehandle.seek(current_pos, os.SEEK_SET)
            column_no = len(first_line)
            if r_skipped_first_column_name:
                column_no += 1
            columns = self._csvToColumns(filehandle, dialect, column_no, na_text, comment_character)
        return first_line, columns

    def _field_names(self, header, first_line, dialect, r_skipped_first_column_name):
        """Return a new list of column names; never aliases header or first_line.

        Raises ValueError if the number of names does not match the header
        line (plus one when R skipped the row-name column's name).
        """
        if not header:
            # Zero-padded names, wide enough for the largest index.
            num_fields = len(first_line)
            width = len(str(max(num_fields - 1, 0)))
            return ["column%0*d" % (width, i) for i in range(num_fields)]
        _unquote(dialect.quotechar, first_line)
        if header is True:
            fields = list(first_line)
        else:
            fields = list(header)
        expected = len(first_line)
        if r_skipped_first_column_name:
            expected += 1
            fields.insert(0, "Row")
        if len(fields) != expected:
            raise ValueError("Wrong number of headers. Should be %i, was %i (potentially including skipped first column)" % (expected, len(fields)))
        return fields

    def _convert_columns(self, columns, fields, na_text, type_hints):
        """Convert each column to ints, else floats, else keep it as strings.

        Replaces columns[i] in place, appends a 'Column_%i' name to fields for
        columns without a name, and returns a copy of type_hints extended with
        inferred dtypes: int32 for int columns, 'S<longest value>' for string
        columns (float columns get no hint).
        """
        type_hints = dict(type_hints) if type_hints else {}
        to_int = _intify(na_text)
        to_float = _floatify(na_text)
        for i, column in enumerate(columns):
            if i >= len(fields):
                # Ragged rows produced more columns than names.
                fields.append('Column_%i' % i)
            name = fields[i]
            try:
                columns[i] = [to_int(value) for value in column]
            except ValueError:
                pass
            else:
                type_hints.setdefault(name, numpy.int32)
                continue
            try:
                columns[i] = [to_float(value) for value in column]
            except ValueError:
                if name not in type_hints:  # keep as strings
                    max_len = max(len(max(column, key=len)), 1)
                    type_hints[name] = 'S%i' % max_len
        return type_hints

    def write(self, data_frame, filename_or_file_object, dialect='excel'):
        """Write a DataFrame to a comma separated value file.

        If the frame has row names they are written as a leading "Row"
        column. The handle is closed afterwards, even on error.
        """
        filehandle = _open_file(filename_or_file_object, 'wb')
        try:
            writer = csv.writer(filehandle, dialect=dialect)
            row_names = data_frame.row_names
            header = list(data_frame.columns_ordered)
            if row_names is not None:
                header.insert(0, "Row")
            writer.writerow(header)
            for i in range(data_frame.num_rows):
                row = data_frame.get_row_as_list(i)
                if row_names is not None:
                    row = [row_names[i]] + list(row)
                writer.writerow(row)
        finally:
            filehandle.close()

    def _csvToColumns(self, file_object, dialect, number_of_columns, na_text, comment_character):
        """Split the remaining lines of file_object into number_of_columns lists.

        Lines starting with comment_character are skipped. Values are split
        on dialect.delimiter without any quote handling; short rows are padded
        with na_text and surplus values are dropped.
        """
        columns = [[] for _ in range(number_of_columns)]
        sep = dialect.delimiter
        for line in file_object:
            if comment_character and line.startswith(comment_character):
                continue
            values = line.strip().split(sep)
            n_values = len(values)
            for i, column in enumerate(columns):
                column.append(values[i] if i < n_values else na_text)
        return columns


class DF2ARFF:
    """ARFF file format (Weka) to DataFrame converter"""

    # ARFF type keywords (compared lower-cased) that are read as float columns.
    _NUMERIC_TYPES = ('numeric', 'real', 'integer')

    def read(self, filename_or_file_object):
        """Read an ARFF file into a DataFrame.

        Supported attribute types:
        - numeric/real/integer, stored as masked float arrays;
        - nominal '{a,b,...}', stored as masked fixed-width string arrays
          sized to the longest declared value.

        Parsing rules:
        - '@attribute' and '@data' are matched case-insensitively;
        - blank lines and '%' comment lines in the data section are skipped;
        - '?' values (ARFF's missing marker) are masked.

        Raises ValueError if there is no @data or @attribute section, an
        @attribute line is malformed, or a type is unsupported.
        """
        filehandle = _open_file(filename_or_file_object,'rb')
        try:
            rows = filehandle.readlines()
        finally:
            filehandle.close()
        attributes = []
        data_start_row = None
        for row_no, row in enumerate(rows):
            keyword = row.lstrip().lower()
            if keyword.startswith('@attribute'):
                parts = row.split(None, 2)
                if len(parts) < 3:
                    raise ValueError("Not a valid arff file - malformed @attribute line: %s" % row.strip())
                rest = parts[2].strip()
                # Nominal types may contain spaces ('{a, b}'); others are one word.
                type_ = rest if rest.startswith('{') else rest.split()[0]
                attributes.append((parts[1], type_))
            elif keyword.startswith('@data'):
                data_start_row = row_no + 1
                break
        if data_start_row is None:
            raise ValueError("Not a valid arff file - no @data found")
        if not attributes:
            raise ValueError("Not a valid arff file - no @attribute found")

        data_rows = []
        for row in rows[data_start_row:]:
            row = row.strip()
            if row and not row.startswith('%'):
                data_rows.append(row)
        num_rows = len(data_rows)

        field_names = tuple(name for name, _ in attributes)
        fields = {}
        fields_in_order = []
        field_funcs_in_order = []
        for name, type_ in attributes:
            if type_.lower() in self._NUMERIC_TYPES:
                fields[name] = numpy.ma.zeros((num_rows, ), dtype=numpy.double)
                field_funcs_in_order.append(float)
            elif type_.startswith('{') and type_.endswith('}'):
                # TODO: properly parse the values - they may be quoted and contain commas.
                values = [value.strip() for value in type_[1:-1].split(",")]
                _unquote("'", values)
                max_len = len(max(values, key=len))
                fields[name] = numpy.ma.zeros((num_rows, ), dtype = 'S%i' % max_len)
                field_funcs_in_order.append(lambda x: _unquote_single_value("'", x))
            else:
                raise ValueError("Don't know yet how to handle type %s"  % type_)
            fields_in_order.append(fields[name])
        for row_no, row in enumerate(data_rows):
            for i, value in enumerate(row.split(',')):
                value = value.strip()
                if value == '?':
                    fields_in_order[i][row_no] = numpy.ma.masked
                else:
                    fields_in_order[i][row_no] = field_funcs_in_order[i](value)
        return DataFrame(fields, field_names)

    def write(self, data_frame, filename_or_file_object):
        """Write a DataFrame to an ARFF file for Weka (not implemented yet).

        Raises NotImplementedError.
        """
        raise NotImplementedError()

class DF2Excel:
    """Use xlrd and xlwt to read and write 'real' excel (.xls) files"""

    def read(self, filename, sheet_name = None, row_offset = 0, col_offset = 0, row_max = None, col_max = None, filter_empty = False, NA_string = None, handle_encoding_errors = lambda x: str(x).strip()):
        """Read one worksheet of an Excel file into a DataFrame (needs xlrd).

        sheet_name = sheet index (int) or sheet name; defaults to the first sheet.
        row_offset = row of the header; data starts on the row after it.
        col_offset = first column to read.
        row_max, col_max = exclusive upper bounds (default: the sheet size).
        filter_empty = drop rows in which every cell became None
            (rows of zeros are kept).
        NA_string = text cells equal to this become None.
        handle_encoding_errors = kept for backwards compatibility only; no
            longer called, since cell text is no longer forced through str().

        Empty, error and whitespace-only cells become None. Repeated header
        names get a suffix: name_2, name_3, ...
        """
        import xlrd
        wb = xlrd.open_workbook(filename)
        if sheet_name is None:
            ws = wb.sheet_by_index(0)
        elif type(sheet_name) == int:
            ws = wb.sheet_by_index(sheet_name)
        else:
            ws = wb.sheet_by_name(sheet_name)
        if row_max is None:
            row_max = ws.nrows
        if col_max is None:
            col_max = ws.ncols
        cols = {}
        col_names_in_order = []
        for header_cell in ws.row(row_offset)[col_offset:col_max]:
            name = header_cell.value
            counter = 2
            while name in col_names_in_order:
                name = "%s_%i" % (header_cell.value, counter)
                counter += 1
            col_names_in_order.append(name)
            cols[name] = []
        missing_cell_types = (xlrd.XL_CELL_EMPTY, xlrd.XL_CELL_ERROR)
        for row_no in range(row_offset + 1, row_max):
            row = ws.row(row_no)
            values = []
            for y in range(col_offset, col_max):
                cell = row[y]
                value = cell.value
                # Detect text cells without str(): str() raises
                # UnicodeEncodeError on non-ASCII unicode under Python 2.
                is_text = hasattr(value, 'strip')
                if cell.ctype in missing_cell_types or (is_text and not value.strip()):
                    value = None
                elif is_text and value == NA_string:
                    value = None
                values.append(value)
            if filter_empty and all(value is None for value in values):
                continue
            for name, value in zip(col_names_in_order, values):
                cols[name].append(value)
        return DataFrame(cols, col_names_in_order)

    def write(self, data_frame_or_dict_of_dataframes, filename_or_file_object, sheet_name = "DataFrame",
             highlight_columns = None):
        """Write a DataFrame, or a dict {sheet_name: DataFrame}, to an .xls file (needs xlwt).

        Each DataFrame becomes one worksheet with the column names in the
        first row; Factor columns are written as their levels. Columns listed
        in highlight_columns get the highlight style. The output file is only
        opened after all sheets were built. A handle opened here is closed;
        a caller-supplied file object stays open.

        Raises ValueError for unsupported input and ParserError if a frame
        has more rows than an .xls sheet can hold.
        """
        if isinstance(data_frame_or_dict_of_dataframes, DataFrame):
            data_frame_or_dict_of_dataframes = {sheet_name: data_frame_or_dict_of_dataframes}
        elif not isinstance(data_frame_or_dict_of_dataframes, dict):
            raise ValueError("DF2Excel only writes out Dataframes, or dicts of {sheet_name: dataframe}")
        import xlwt
        wb = xlwt.Workbook()
        style_normal = xlwt.XFStyle()
        style_highlight = xlwt.easyxf(""" 
             pattern: 
                 back_colour yellow, 
                 pattern solid, 
                 fore-colour lavender 
         """)
        translate_value = self.translate_value
        for current_sheet_name, data_frame in data_frame_or_dict_of_dataframes.items():
            # An .xls sheet has 65536 rows (indices 0..65535); row 0 is the header.
            if data_frame.num_rows > 65535:
                raise ParserError("Too many rows, excel only supports 65k")
            ws = wb.add_sheet(current_sheet_name)
            for i_col, column_name in enumerate(data_frame.columns_ordered):
                this_column = data_frame.get_column_view(column_name)
                if isinstance(this_column, Factor):
                    this_column = this_column.as_levels()
                if highlight_columns and column_name in highlight_columns:
                    style = style_highlight
                else:
                    style = style_normal
                ws.write(0, i_col, column_name, style)
                for i in range(data_frame.num_rows):
                    ws.write(i + 1, i_col, translate_value(this_column[i]), style)
        filehandle = _open_file(filename_or_file_object, 'wb')
        try:
            wb.save(filehandle)
        finally:
            if filehandle is not filename_or_file_object:
                filehandle.close()

    def translate_value(self, value):
        """Convert one DataFrame cell value into something xlwt can write.

        Conversions:
        - None, empty str and fully masked values become "";
        - NaN floats (any numpy floating type) become None, i.e. an empty cell;
        - other floats and numpy integers become plain Python float / int;
        - non-empty str becomes unicode;
        - bools and Python ints are returned unchanged.

        Anything else is written as a float if float() accepts it, otherwise
        as unicode text.
        """
        typ = type(value)
        if typ is str:
            return unicode(value) if value else ""
        if value is None:
            return ""
        if isinstance(value, numpy.ma.core.MaskedArray):
            # numpy.ma.masked is a MaskedConstant (a MaskedArray subclass),
            # so an exact type check would miss it.
            if numpy.all(numpy.ma.getmaskarray(value)):
                return ""
        elif typ is float or isinstance(value, numpy.floating):
            return None if numpy.isnan(value) else float(value)
        elif typ is int or typ is bool:
            return value
        elif isinstance(value, numpy.integer):
            return int(value)
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return unicode(value)

class DF2Sqlite:
    """Write DataFrames into an SQLite database, one table per DataFrame."""

    def write(self, data_frame_or_dict_of_dataframes, filename_or_file_object, sheet_name = 'DataFrame'):
        """Write a DataFrame, or a dict {table_name: DataFrame}, into a new SQLite database.

        A single DataFrame goes into table sheet_name. An existing file at the
        given path is deleted first. All tables are written in one explicit
        transaction, and the connection is closed even if writing fails.
        """
        # type(u'') is unicode on Python 2 and str on Python 3.
        if isinstance(filename_or_file_object, (str, type(u''))) and os.path.exists(filename_or_file_object):
            os.unlink(filename_or_file_object)
        if isinstance(data_frame_or_dict_of_dataframes, DataFrame):
            data_frame_or_dict_of_dataframes = {sheet_name: data_frame_or_dict_of_dataframes}
        conn = sqlite3.connect(filename_or_file_object,isolation_level="DEFERRED")
        try:
            cur = conn.cursor()
            cur.execute('BEGIN TRANSACTION')
            for table_name, data_frame in data_frame_or_dict_of_dataframes.items():
                cur.execute(self._create_table_statement(table_name, data_frame))
                placeholders = ",".join(["?"] * len(data_frame.columns_ordered))
                ins_statement = "INSERT INTO %s VALUES (%s)" % (self._quote_identifier(table_name), placeholders)
                cur.executemany(ins_statement, data_frame.iter_rows_list())
            conn.commit()
        finally:
            conn.close()

    def map_df_values(self, data_frame):
        """Yield data_frame's rows as lists of plain Python values.

        Values whose exact type is numpy int32/int64 become int, numpy float64
        becomes float and numpy.ma.core.MaskedArray becomes NaN; other values
        pass through unchanged.
        """
        converters = {
            numpy.int32: int,
            numpy.int64: int,
            numpy.float64: float,
            numpy.ma.core.MaskedArray: lambda value: numpy.nan,
        }
        for row in data_frame.iter_rows_list():
            converted = []
            for value in row:
                convert = converters.get(type(value))
                converted.append(value if convert is None else convert(value))
            yield converted

    @staticmethod
    def _quote_identifier(name):
        """Quote name as an SQLite identifier, doubling embedded double quotes."""
        text = '%s' % (name,)
        return '"' + text.replace('"', '""') + '"'

    def _create_table_statement(self, table_name, data_frame):
        """Return the CREATE TABLE statement for data_frame's columns.

        Integer dtypes map to 'int', floating dtypes to 'real' and everything
        else (strings, bools, objects) to 'text'. Table and column names are
        quoted so names containing quotes cannot break the SQL.
        """
        field_defs = []
        for column_name in data_frame.columns_ordered:
            dtype = data_frame.value_dict[column_name].dtype
            if numpy.issubdtype(dtype, numpy.integer):
                column_type = 'int'
            elif numpy.issubdtype(dtype, numpy.floating):
                column_type = 'real'
            else:
                column_type = 'text'
            field_defs.append("%s %s" % (self._quote_identifier(column_name), column_type))
        return 'CREATE TABLE %s (%s)' % (self._quote_identifier(table_name), ",\n".join(field_defs))





class Access2000Dialect(csv.Dialect):
    """csv dialect for Microsoft Access 2000 CSV exports in international
    locales.

    Fields are separated by ';' and quoted with '"' (doubled inside fields).
    QUOTE_NONNUMERIC makes the reader convert unquoted fields to float.
    """
    delimiter = ';'
    lineterminator = '\n'
    quotechar = '"'
    doublequote = True
    skipinitialspace = True
    quoting = csv.QUOTE_NONNUMERIC

class TabDialect(csv.Dialect):
    """csv dialect for tab separated files.

    Fields are quoted with '"' only when needed (QUOTE_MINIMAL).
    """
    delimiter = '\t'
    lineterminator = '\n'
    quotechar = '"'
    doublequote = True
    skipinitialspace = True
    quoting = csv.QUOTE_MINIMAL

class CommaDialect(csv.Dialect):
    """A dialect to interpret comma separated files.

    Fields are quoted with '"' only when needed (QUOTE_MINIMAL); spaces
    right after a comma are ignored when reading.
    """
    delimiter = ','
    quotechar = '"'
    doublequote = True
    quoting = csv.QUOTE_MINIMAL
    lineterminator = '\n'
    skipinitialspace = True


class SpaceDialect(csv.Dialect):
    """A dialect to interpret space separated files.

    Fields are quoted with '"' only when needed (QUOTE_MINIMAL).
    """
    delimiter = ' '
    quotechar = '"'
    doublequote = True
    quoting = csv.QUOTE_MINIMAL
    lineterminator = '\n'
    skipinitialspace = True

