'''
Author:      www.tropofy.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.

This source file is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.
'''

import os
import xlrd
from tropofy.database import DBSession
from pyramid import threadlocal
from tropofy.database import read_write_xl
from sqlalchemy.orm import _mapper_registry, mapperlib, class_mapper, ColumnProperty


class DbManager(object):
    '''Generic database interaction helpers built on the Tropofy ``DBSession``.'''

    # Keys that identify a row / data set. They are never copied from
    # client-supplied data onto an object; they are set explicitly instead.
    _PROTECTED_KEYS = ('id', 'data_set_id')

    @staticmethod
    def construct_object_and_add_to_db(source_class, data, data_set_id):
        """Build a ``source_class`` instance from ``data``, attach it to a data set and add it to the session.

        Assumes the keys of ``data`` (i.e. the grid's column headers) are a superset of the
        constructor arguments of ``source_class``. An object is constructed, rather than rows
        being inserted directly, because the table may belong to a class in an inheritance
        chain. In that case its columns may be only a subset of those needed to build the object.

        :param source_class: mapped class to instantiate.
        :param data: dict of constructor keyword arguments. Any ``id`` / ``data_set_id``
            entries are ignored. The caller's dict is not modified.
        :param data_set_id: value assigned to the new object's ``data_set_id``.
        :returns: the newly added (and flushed) object.
        """
        # Work on a copy so the caller's dict is left untouched.
        kwargs = dict(data)
        for key in DbManager._PROTECTED_KEYS:
            kwargs.pop(key, None)
        new_object = source_class(**kwargs)
        new_object.data_set_id = data_set_id
        session = DBSession()
        session.add(new_object)
        # Flush so the INSERT is issued immediately. This is presumably required so that
        # callers see database-generated values (e.g. the primary key) on the returned object.
        session.flush()
        return new_object

    @staticmethod
    def update_object_in_db(source_class, data):
        """Copy the values in ``data`` onto the existing ``source_class`` row whose id is ``data['id']``.

        Assumes the keys of ``data`` (the grid's column names) are attribute names of
        ``source_class``, i.e. a subset of the columns in its table hierarchy. The ``id`` and
        ``data_set_id`` keys are never written.

        :param source_class: mapped class to update.
        :param data: dict of new attribute values; must contain ``'id'``.
        :returns: the updated (and flushed) object.
        :raises KeyError: if ``data`` has no ``'id'`` key.
        Also propagates whatever ``Query.one()`` raises when zero or several rows match.
        """
        session = DBSession()
        existing_object = session.query(source_class).filter(source_class.id == data['id']).one()

        # items() rather than the Python-2-only iteritems() keeps this portable.
        for key, value in data.items():
            if key not in DbManager._PROTECTED_KEYS:
                setattr(existing_object, key, value)
        session.flush()
        return existing_object

    @staticmethod
    def delete_object_from_db(source_class, obj_id):
        """Mark the ``source_class`` row whose ``id`` equals ``obj_id`` for deletion in the session.

        Propagates whatever ``Query.one()`` raises when zero or several rows match.
        """
        session = DBSession()
        session.delete(session.query(source_class).filter(source_class.id == obj_id).one())

    @staticmethod
    def get_error_message(e):
        """Return a short, user-facing message for exception ``e``.

        If ``e`` wraps an underlying exception in a non-None ``e.orig`` attribute (as
        SQLAlchemy DBAPI errors do), that exception is used instead. The message is taken
        from the legacy ``message`` attribute when present and non-empty. Otherwise it comes
        from ``str()``; built-in exceptions on Python 3 have no ``message`` attribute.
        Only the text before the first ``:`` is returned.
        """
        source = getattr(e, 'orig', None)
        if source is None:
            source = e
        message = getattr(source, 'message', None)
        if not message:
            message = str(source)
        return message.split(':')[0]

    @staticmethod
    def _find_class_mapped_to_table(table_name):
        """Return the mapped class whose table's ``"<schema>.<name>"`` equals ``table_name``, or None.

        Mappers with no class, no mapped table, or a table lacking a schema or name are skipped.
        """
        # Iterate a snapshot of the registry's keys (the mappers); values are unused.
        for mapper in list(_mapper_registry):
            if mapper.class_ is None:
                continue
            table = mapper.mapped_table
            # Short-circuit keeps .name from being read when there is no table or schema.
            if table is None or table.schema is None or table.name is None:
                continue
            if table.schema + "." + table.name == table_name:
                return mapper.class_
        return None

    @staticmethod
    def sqla_cls_defined_column_names(cls):
        """Return the keys of all column-backed properties on the mapper of ``cls``."""
        return [prop.key for prop in class_mapper(cls).iterate_properties if isinstance(prop, ColumnProperty)]
