from tropofy.file_io.read_write_csv import ClassTabularCsvMapping, CsvReader, CsvWriter
from os import remove
from sqlalchemy import Column, Table
from sqlalchemy.schema import CreateTable, DropTable
from datetime import datetime
from tropofy.file_io.exceptions import TropofyFileImportExportException
import csv
from csv import DictReader, DictWriter, writer, reader
import subprocess
import transaction

class ClassTabularCsvMappingPsql(ClassTabularCsvMapping):
    r"""Maps a Python class to a tabular csv file.

    Used to both write to and read from csv with :class:`CsvReaderPsql` and :class:`CsvWriterPsql`.

    :param class\_: Python class in mapping.
    :type class\_: class
    :param attribute_column_aliases: dict of class attribute to col alias.
    :type attribute_column_aliases: dict

    .. note:: To specify the order in which columns are written, use a collections.OrderedDict (Python built-in) for ``attribute_column_aliases``.
    .. note:: The psql based loader matches csv columns to table columns by position, so the csv header must contain
              exactly the required column names, each once, in the required order.
    """
    def __init__(self, class_, attribute_column_aliases):
        # NOTE(review): ``super(ClassTabularCsvMapping, self)`` skips ClassTabularCsvMapping.__init__ and calls the
        # initialiser of its base class directly, with no object-processing hooks. This looks deliberate (this mapping
        # never processes objects) but confirm before changing it to ``super(ClassTabularCsvMappingPsql, self)``.
        super(ClassTabularCsvMapping, self).__init__(class_, attribute_column_aliases, process_objects=None, get_objects=None, objects=None)

    def tabular_data_has_required_columns(self, tabular_data_column_names, data_container_name):
        """Check that the tabular data has exactly the required columns, each once, in the required order.

        :param tabular_data_column_names: List of column names in tabular data. ``None`` (e.g. the ``fieldnames`` of an
            empty csv file) is treated as "no columns".
        :type tabular_data_column_names: list
        :param data_container_name: Name of data container. E.g, an Excel worksheet name, or a CSV file name.
        :type data_container_name: str
        :returns: True when the columns match.
        :raises TropofyFileImportExportException: if a column is missing, unexpected, duplicated or out of order.
        """
        required_names = list(self.required_column_names)
        found_names = list(tabular_data_column_names or [])
        required_set = set(required_names)
        found_set = set(found_names)

        if required_set != found_set:
            # Sorted so the message is deterministic (set iteration order is not).
            mismatched = sorted(required_set.symmetric_difference(found_set))
            raise TropofyFileImportExportException(
                'The first row in a sheet must exactly match the column names for the table. Sheet {container} is inconsistent for the following columns: {columns}'.format(
                    container=data_container_name,
                    columns=", ".join(mismatched),
                ))

        # Same names but different lengths means a header repeats a column; positional loading would misalign data.
        if len(found_names) != len(required_names):
            duplicated = sorted(name for name in found_set if found_names.count(name) != required_names.count(name))
            raise TropofyFileImportExportException(
                'Columns in csv must match columns in data set. Sheet {container} has duplicated columns: {columns}'.format(
                    container=data_container_name,
                    columns=", ".join(duplicated),
                ))

        columns_out_of_order = [expected for expected, actual in zip(required_names, found_names) if expected != actual]
        if columns_out_of_order:
            raise TropofyFileImportExportException(
                'Columns in csv must match columns in data set. Sheet {container} is inconsistent for the following columns: {columns}'.format(
                    container=data_container_name,
                    columns=", ".join(columns_out_of_order),
                ))
        return True


class CsvReaderPsql(CsvReader):
    """
    Utility class for reading data from .csv files. Works with PostgreSQL databases only. Faster than CsvReader.

    The ``class_tabular_csv_mapping`` parameter of all methods needs to be a :class:`ClassTabularCsvMappingPsql`
    object, which has no methods for processing or inspecting objects.
    """

    @staticmethod
    def _sql_string_literal(text):
        """Return ``text`` as a single-quoted literal for psql/SQL, doubling any embedded single quotes."""
        return "'" + text.replace("'", "''") + "'"

    @classmethod
    def _load_tabular_data_from_csv_file(cls, data_set, csv_file, class_tabular_csv_mapping, delimiter=',', user_feedback_csv_file_identifier_str=''):
        """Bulk-load the rows of ``csv_file`` into the table mapped by ``class_tabular_csv_mapping`` for ``data_set``.

        The header is validated by the mapping. The rows are then re-written to a comma-delimited scratch file in the
        app folder, and that file is loaded into the table by :meth:`_copy_csv_file_into_table`. The scratch file is
        removed afterwards, whether or not the load succeeded.

        :param data_set: Data set the imported rows belong to; its ``id`` is written to ``data_set_id``.
        :param csv_file: Open csv file object to read from.
        :param class_tabular_csv_mapping: Mapping whose ``class_`` is a SQLA ORM class.
        :type class_tabular_csv_mapping: :class:`ClassTabularCsvMappingPsql`
        :param delimiter: (Optional) delimiter used in ``csv_file``. Default ','
        :type delimiter: str
        :param user_feedback_csv_file_identifier_str: Name of the file, used in error messages.
        :type user_feedback_csv_file_identifier_str: str
        :raises TropofyFileImportExportException: if the header is rejected by the mapping, or psql fails to stage the data.
        """
        csv_reader = DictReader(csv_file, delimiter=delimiter)
        # fieldnames is None for an empty file; pass [] so the mapping reports missing columns instead of crashing.
        field_names = csv_reader.fieldnames or []
        if not class_tabular_csv_mapping.tabular_data_has_required_columns(field_names, user_feedback_csv_file_identifier_str):
            return

        time_stamp = datetime.now().strftime('%H_%M_%S_%f')
        # Unique per data set and call, so concurrent imports do not overwrite each other's scratch file.
        data_file_path = data_set.app.get_path_of_file_in_app_folder(
            'data_{data_set_id}_{time_stamp}.csv'.format(data_set_id=data_set.id, time_stamp=time_stamp)
        )
        try:
            # Normalise to ',' whatever the input delimiter was; _copy_csv_file_into_table tells psql to expect ','.
            with open(data_file_path, 'wb') as data_file:
                csv_writer = DictWriter(data_file, fieldnames=field_names, delimiter=',')
                csv_writer.writeheader()
                csv_writer.writerows(csv_reader)
            cls._copy_csv_file_into_table(data_set, data_file_path, class_tabular_csv_mapping.class_.__table__, time_stamp)
        finally:
            try:
                remove(data_file_path)
            except OSError:
                pass  # Best-effort cleanup: the scratch file may never have been created.

    @classmethod
    def _copy_csv_file_into_table(cls, data_set, data_file_path, data_table, time_stamp):
        """Load a comma-delimited csv file with a header row into ``data_table``, setting ``data_set_id`` to ``data_set.id``.

        The rows are first copied into a staging table. That table has every column of ``data_table`` except ``id`` and
        ``data_set_id``, all nullable, in table order, and csv columns are matched to them by position.
        """
        bind = data_set.db_session.get_bind()
        preparer = bind.dialect.identifier_preparer
        temp_table_name = 'csv_temp_table_{data_set_id}_{time_stamp}'.format(data_set_id=data_set.id, time_stamp=time_stamp)
        connection = data_set.db_session.connection()

        table_columns = [Column(col.name, col.type, nullable=True) for col in data_table.columns if col.name not in ('id', 'data_set_id')]
        # A TEMPORARY table cannot be used here: it would be private to, and dropped with, the psql process creating it.
        temp_table = Table(
            temp_table_name,
            data_set.metadata,
            *table_columns,
            schema=data_table.schema
        )
        try:
            # psql subprocesses are used because \copy is a psql command (it reads the file client-side). The staging
            # table is also created/dropped through psql: if it were created in this session's uncommitted
            # transaction, the \copy process could not see it.
            psql_command_line_args = get_psql_command_line_args(data_set)
            create_query = str(CreateTable(temp_table).compile(bind))
            drop_query = str(DropTable(temp_table).compile(bind))
            qualified_temp_table = preparer.format_table(temp_table)

            if subprocess.call(psql_command_line_args + ['-c', create_query]) != 0:
                raise TropofyFileImportExportException(
                    'Could not create staging table {table} for csv import.'.format(table=qualified_temp_table))

            copy_command = '\\COPY {table} FROM {file_path} DELIMITER \',\' CSV HEADER'.format(
                table=qualified_temp_table,
                file_path=cls._sql_string_literal(data_file_path),
            )
            if subprocess.call(psql_command_line_args + ['-c', copy_command]) != 0:
                # Safe to drop: this session has not touched the staging table yet, so it holds no lock on it.
                subprocess.call(psql_command_line_args + ['-c', drop_query])
                raise TropofyFileImportExportException(
                    'Could not copy csv data into staging table {table}.'.format(table=qualified_temp_table))

            column_list = ", ".join(preparer.quote(col.name) for col in table_columns)
            query = """
                INSERT INTO {table} ({columns}, data_set_id)
                SELECT {columns}, {data_set_id} FROM {temp_table};
                """.format(
                    table=preparer.format_table(data_table),
                    columns=column_list,
                    data_set_id=int(data_set.id),
                    temp_table=qualified_temp_table,
            )
            # If this INSERT fails, the staging table is deliberately not dropped: this session may hold a lock on it,
            # and a DROP issued from a psql subprocess would then block waiting for it.
            connection.execute(query)
            transaction.commit()  # commit required here as the connection.execute(query) is holding a lock on the staging table, stopping us from dropping it
            # Best effort: the data is already committed, so a failed drop only leaves a stray staging table behind.
            subprocess.call(psql_command_line_args + ['-c', drop_query])
        finally:
            # Table(...) registered the staging table on data_set.metadata; remove it so it does not linger there
            # (e.g. to be re-created by a later metadata.create_all()).
            data_set.metadata.remove(temp_table)


class CsvWriterPsql(CsvWriter):
    """Utility class for writing data to .csv files via psql ``\\copy``. Works with PostgreSQL databases only.

    The export runs in a separate psql session, so rows not yet committed by ``data_set.db_session`` are not included.
    """

    class _StringSink(object):
        """Minimal write-only file-like object that accumulates everything written to it (target for csv.writer)."""

        def __init__(self):
            self._parts = []

        def write(self, text):
            self._parts.append(text)

        def getvalue(self):
            return ''.join(self._parts)

    @staticmethod
    def _sql_string_literal(text):
        """Return ``text`` as a single-quoted literal for psql/SQL, doubling any embedded single quotes."""
        return "'" + text.replace("'", "''") + "'"

    @classmethod
    def write_tabular_data_to_csv_file_on_disk(cls, data_set, file_path, class_tabular_csv_mapping, delimiter=','):
        """Write data to a csv file in a table format. Sets first row as headings, and each subsequent row as data corresponding to the headings.

        ``class_tabular_csv_mapping.class_`` must be a SQLA ORM class; all of its rows belonging to ``data_set`` are written.

        .. note:: Use a collections.OrderedDict in class_tabular_csv_mapping.attribute_column_alias to specify the output order.

        :param data_set: Data set whose data is written.
        :type data_set: :class:`tropofy.app.AppDataSet`
        :param file_path: Path on disk to the csv file.
        :type file_path: str
        :param class_tabular_csv_mapping: Describes how a class maps to headings and data in the csv file.
        :type class_tabular_csv_mapping: :class:`ClassTabularCsvMappingPsql`
        :param delimiter: (Optional) csv delimiter. Default ','
        :type delimiter: str
        :raises TropofyFileImportExportException: if psql fails to export the data.
        """
        file_as_string = cls.read_psql_table_into_csv_string(data_set, class_tabular_csv_mapping, delimiter)
        with open(file_path, 'wb') as data_file:
            data_file.write(file_as_string)

    @classmethod
    def read_psql_table_into_csv_string(cls, data_set, class_tabular_csv_mapping, delimiter=','):
        """Return the data of the mapped class in ``data_set`` as a csv-formatted string.

        The first row holds the column aliases (values of ``attribute_column_aliases``, in their order). Each following
        row holds one database row, with every column of the table except ``id`` and ``data_set_id`` in table order.
        Fields are separated by ``delimiter`` and quoted as needed by the :mod:`csv` module. Lines end with ``'\\n'``.

        .. note:: Use a collections.OrderedDict in class_tabular_csv_mapping.attribute_column_alias to specify the output order.

        :param data_set: Data set whose data is written.
        :type data_set: :class:`tropofy.app.AppDataSet`
        :param class_tabular_csv_mapping: Describes how a class maps to headings and data in the csv file.
        :type class_tabular_csv_mapping: :class:`ClassTabularCsvMappingPsql`
        :param delimiter: (Optional) csv delimiter. Default ','
        :type delimiter: str
        :raises TropofyFileImportExportException: if psql fails to export the data.
        """
        temp_file = cls._create_temporary_data_csv_from_table(data_set, class_tabular_csv_mapping, delimiter)
        try:
            sink = cls._StringSink()
            # Re-serialise with csv.writer so values containing the delimiter, quotes or newlines stay quoted.
            csv_writer = writer(sink, delimiter=delimiter, lineterminator='\n')
            csv_writer.writerow(list(class_tabular_csv_mapping.attribute_column_aliases.values()))
            csv_writer.writerows(reader(temp_file, delimiter=delimiter))
        finally:
            temp_file.close()
            remove(temp_file.name)
        return sink.getvalue()

    @classmethod
    def _create_temporary_data_csv_from_table(cls, data_set, class_tabular_csv_mapping, delimiter=','):
        """Export ``data_set``'s rows of the mapped table to a header-less scratch csv file and return it open for reading.

        Only rows whose ``data_set_id`` equals ``data_set.id`` are exported. The exported columns are every column except
        ``id`` and ``data_set_id``, in table order (the same columns :class:`CsvReaderPsql` loads). Rows are ordered
        by ``id`` when the table has one. The caller must close and remove the returned file.

        :returns: The scratch file, opened in ``'rb'`` mode.
        :raises TropofyFileImportExportException: if psql fails to export the data.
        """
        data_table = class_tabular_csv_mapping.class_.__table__
        preparer = data_set.db_session.get_bind().dialect.identifier_preparer
        temp_file_path = data_set.app.get_path_of_file_in_app_folder('data_{data_set_id}_{time_stamp}.csv'.format(
            data_set_id=data_set.id, time_stamp=datetime.now().strftime('%H_%M_%S_%f')))

        exported_column_names = [col.name for col in data_table.columns if col.name not in ('id', 'data_set_id')]
        select_query = 'SELECT {columns} FROM {table} WHERE data_set_id = {data_set_id}'.format(
            columns=", ".join(preparer.quote(name) for name in exported_column_names),
            table=preparer.format_table(data_table),
            data_set_id=int(data_set.id),
        )
        if any(col.name == 'id' for col in data_table.columns):
            select_query += ' ORDER BY id'  # deterministic output order

        # \copy must be a single line; the query form restricts the export to this data set.
        copy_command = '\\COPY ({query}) TO {file_path} DELIMITER {delimiter} CSV'.format(
            query=select_query,
            file_path=cls._sql_string_literal(temp_file_path),
            delimiter=cls._sql_string_literal(delimiter),
        )
        if subprocess.call(get_psql_command_line_args(data_set) + ['-c', copy_command]) != 0:
            try:
                remove(temp_file_path)
            except OSError:
                pass  # Best-effort cleanup: psql may not have created the file.
            raise TropofyFileImportExportException(
                'Could not export table {table} to csv.'.format(table=data_table.name))
        return open(temp_file_path, 'rb')


def get_psql_command_line_args(data_set):
    """Build the ``psql`` argv prefix for connecting to the database of ``data_set``.

    Connection details come from the SQLAlchemy URL of the session's bind. Any of database, user, host or port that is
    not set on the URL is omitted, so psql falls back to its own default (e.g. the local unix socket when no host is
    given). Without this, a ``None`` argument would reach ``subprocess``.

    ``-w`` makes psql never prompt for a password. The URL's password is not passed, so authentication must succeed
    without one. NOTE(review): presumably via ``.pgpass`` or trust auth; confirm for each deployment.

    :param data_set: Data set whose ``db_session`` is bound to the target database.
    :returns: list of str, e.g. ``['psql', '-d', db, '-U', user, '-h', host, '-w', '-p', port]``.
    """
    url_object = data_set.db_session.get_bind().url
    args = ['psql']
    for flag, value in (('-d', url_object.database), ('-U', url_object.username), ('-h', url_object.host)):
        if value:
            args += [flag, value]
    args.append('-w')
    if url_object.port:
        args += ['-p', str(url_object.port)]
    return args