"""
Author:      www.tropofy.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.

This source file is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.
"""
from sqlalchemy.schema import Column
from sqlalchemy.types import TypeDecorator, Integer
from sqlalchemy.dialects.postgresql import BYTEA
from Crypto.Cipher import AES
from sqlalchemy.ext.declarative import declarative_base, declared_attr
import validation
import base
from tropofy.database import db_utils


global_aes_key = None  # Default AES key for EncryptedValue when no aes_key kwarg is given; read at EncryptedValue instantiation, so it must be set first (presumably from the .ini setting encryption.global_aes_key — assigned outside this module).


def aes_encrypt(data, key):
    """Encrypt ``data`` with AES (ECB mode) after space-padding it to a whole block.

    AES only operates on complete blocks of ``AES.block_size`` (16) bytes, so
    1 to ``AES.block_size`` spaces are appended. When ``len(data)`` is already a
    multiple of the block size, a full block of spaces is added.
    :func:`aes_decrypt` strips the padding again.

    :param data: plaintext string to encrypt.
    :param key: AES key (PyCrypto requires 16, 24 or 32 bytes).
    :returns: the ciphertext.
    """
    # NOTE(review): ECB is deterministic (equal plaintexts -> equal ciphertexts).
    # It is kept because values already stored were written with it. The mode is
    # now explicit: PyCrypto 2.x defaults to ECB, PyCryptodome requires a mode.
    cipher = AES.new(key, AES.MODE_ECB)
    block_size = AES.block_size
    padding = block_size - (len(data) % block_size)
    return cipher.encrypt(data + ' ' * padding)

def aes_decrypt(data, key):
    """Decrypt ciphertext produced by :func:`aes_encrypt` and remove its padding.

    :func:`aes_encrypt` pads with spaces only, so only trailing spaces are
    stripped. Other trailing whitespace in the original plaintext (newlines,
    tabs) is preserved. Trailing spaces that were part of the original plaintext
    cannot be told apart from padding and are still lost.

    :param data: ciphertext as returned by :func:`aes_encrypt`.
    :param key: the AES key used for encryption.
    :returns: the decrypted plaintext without padding.
    """
    # Explicit ECB to match aes_encrypt (PyCrypto's implicit default).
    cipher = AES.new(key, AES.MODE_ECB)
    return cipher.decrypt(data).rstrip(b' ')


class EncryptedValue(TypeDecorator):
    """Column type that stores values AES-encrypted in a PostgreSQL BYTEA column.

    On write, a value is validated against :attr:`python_type`, converted to
    ``str`` and encrypted with :func:`aes_encrypt`. On read, the stored bytes
    are decrypted with :func:`aes_decrypt` and validated/converted back to
    :attr:`python_type`. SQL NULL (Python ``None``) is passed through
    unencrypted in both directions.

    Subclasses must override :attr:`python_type`.

    :param aes_key: key for this column; defaults to the module-level
        ``global_aes_key`` as it is when the type is instantiated. All other
        args/kwargs are passed on to :class:`TypeDecorator`.
    :raises Exception: if neither ``aes_key`` nor ``global_aes_key`` is set.
    """
    impl = BYTEA

    def __init__(self, *args, **kwargs):
        # pop rather than get: TypeDecorator forwards its remaining args/kwargs to
        # the BYTEA impl constructor, which does not accept an 'aes_key' kwarg.
        self.aes_key = kwargs.pop('aes_key', global_aes_key)
        if self.aes_key is None:
            raise Exception('No encryption key found. In .ini file set encryption.global_aes_key to a 16 char str.')
        super(EncryptedValue, self).__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Return the encrypted form of ``value`` for storage, or ``None`` for NULL."""
        if value is None:
            # Store NULL instead of encrypting the text 'None'.
            return None
        # 1. Validate that the value passed is of the correct python_type.
        value = validation.Validator.validate(value, self.python_type)

        # 2. Cast value to str ready for encryption.
        return aes_encrypt(self._cast_value_to_str(value), self.aes_key)

    def process_result_value(self, value, dialect):
        """Decrypt a stored value back to :attr:`python_type`; NULL stays ``None``."""
        if value is None:
            # NULL was never encrypted; passing None to aes_decrypt would fail.
            return None
        return self._cast_str_to_value(aes_decrypt(value, self.aes_key))

    def _cast_value_to_str(self, value):
        """Can only encrypt strings. Must convert value to str before encryption."""
        return str(value)

    def _cast_str_to_value(self, str_):
        """Convert from str back to original value as passed to _cast_value_to_str."""
        return validation.Validator.validate(str_, self.python_type)

    @property
    def python_type(self):
        """Python type this encrypted column maps to; subclasses must override."""
        raise NotImplementedError('Specify the Python Type that this encrypted value maps to. Used in encryption conversions and form defaults.')


class EncryptedInteger(EncryptedValue):
    """Encrypted column whose decrypted values are ``int``."""

    # Read-only property: instances report ``int`` as their Python type.
    python_type = property(lambda _self: int)


class EncryptedFloat(EncryptedValue):
    """Encrypted column whose decrypted values are ``float``."""

    # Read-only property: instances report ``float`` as their Python type.
    python_type = property(lambda _self: float)


class EncryptedText(EncryptedValue):
    """Encrypted column whose decrypted values are ``str``."""

    # Read-only property: instances report ``str`` as their Python type.
    python_type = property(lambda _self: str)


class ClassProperty(property):
    """Read-only property evaluated against the owner class instead of an instance.

    The wrapped getter is expected to be a ``classmethod``. On every attribute
    access, whether through the class or through an instance, it is bound to
    the owner class and called with no arguments::

        @ClassProperty
        @classmethod
        def d(cls):
            return cls.something
    """

    def __get__(self, cls, owner):
        getter = self.fget
        # Bind the wrapped classmethod to the owning class, then invoke it.
        bound_getter = getter.__get__(None, owner)
        return bound_getter()


# Separate declarative base for the decrypted counterparts of encrypted ORM
# classes (DecryptedMixin derives from it). declarative_base() creates its own
# MetaData; orm_decrypt_data_to_new_in_memory_db uses that metadata to build
# the in-memory database.
BaseForDecrypted = declarative_base(cls=base.ORMAncestor)


class EncryptedMixin(object):
    """Mixin for ORM classes that contain encrypted columns.

    Subclasses set ``__decrypted_cls__`` to the ORM class that mirrors them
    with plain (decrypted) columns. That class is then reachable as
    ``SomeEncryptedClass.d``, which is ``None`` until it is set.

    ``Text``, ``Float`` and ``Integer`` are class-level aliases for the
    encrypted column types (e.g. ``EncryptedMixin.Text`` is
    :class:`EncryptedText`).
    """
    # Explicitly new-style (Python 2): ``d`` is a descriptor (ClassProperty),
    # and descriptors are specified for new-style classes only.
    __decrypted_cls__ = None
    # NOTE(review): how __base_for_decrypted__ is consumed is not visible here.
    __base_for_decrypted__ = BaseForDecrypted

    @ClassProperty
    @classmethod
    def d(cls):
        """The decrypted counterpart class (``__decrypted_cls__``), or ``None``."""
        return cls.__decrypted_cls__

    Text = EncryptedText
    Float = EncryptedFloat
    Integer = EncryptedInteger


class DecryptedMixin(BaseForDecrypted):
    """Abstract declarative base for the decrypted mirrors of encrypted ORM classes.

    Marked ``__abstract__`` so that no table is created for this class itself.
    Each concrete subclass gets:

    * a non-nullable integer ``data_set_id`` column, and
    * ``__table_args__`` taken from :meth:`get_table_args`, which subclasses
      may override (the default is an empty tuple).
    """
    __abstract__ = True
    DECRYPTED = True  # marker flag; its consumers are not visible in this module

    @declared_attr
    def data_set_id(cls):
        data_set_id_column = Column(Integer, nullable=False)
        return data_set_id_column

    @classmethod
    def get_table_args(cls):
        """Return extra ``__table_args__`` for the subclass's table (none by default)."""
        return tuple()

    @declared_attr
    def __table_args__(cls):
        table_args = cls.get_table_args()
        return table_args


def orm_decrypt_data_to_new_in_memory_db(data_set, encrypted_orm_classes, **kw):
    """Copy a data set's rows from encrypted ORM tables into a new in-memory database.

    A new in-memory database is created via ``db_utils.create_in_memory_db``,
    using ``BaseForDecrypted.metadata`` and the tables of the decrypted
    counterparts (``.d``) of ``encrypted_orm_classes``. The rows for
    ``data_set.id`` are then copied table-by-table with
    ``db_utils.copy_data_between_similar_dbs``. Both connections used for the
    copy are closed before returning.

    :param data_set: data set whose rows are copied. ``data_set.db_session()``
        supplies the source database and ``data_set.id`` selects its rows.
    :param encrypted_orm_classes: ORM classes that have encrypted columns. Each
        must have __decrypted_cls__ set (otherwise the matching decrypted class
        is unknown).
    :param kw: accepted for backward compatibility; unused.
    :returns: in_memory_db_engine
    :raises ValueError: if any class has no ``__decrypted_cls__``.
    """
    # Materialise once: the classes are iterated more than once below.
    encrypted_orm_classes = list(encrypted_orm_classes)

    missing = [e.__name__ for e in encrypted_orm_classes if e.d is None]
    if missing:
        raise ValueError(
            'No __decrypted_cls__ set for encrypted ORM class(es): %s' % ', '.join(missing))

    # (encrypted table, decrypted table) pairs that the copy matches up.
    table_matching = [(e.__table__, e.d.__table__) for e in encrypted_orm_classes]

    in_memory_db_engine = db_utils.create_in_memory_db(
        metadata=BaseForDecrypted.metadata,
        tables=[decrypted_table for _, decrypted_table in table_matching],
    )

    source_db_con = data_set.db_session().get_bind().connect()
    try:
        destination_db_con = in_memory_db_engine.connect()
        try:
            db_utils.copy_data_between_similar_dbs(
                source_db_con=source_db_con,
                destination_db_con=destination_db_con,
                table_matching=table_matching,
                data_set_id=data_set.id,
            )
        finally:
            # Return the connections to their pools deterministically instead
            # of waiting for garbage collection.
            destination_db_con.close()
    finally:
        source_db_con.close()

    return in_memory_db_engine