# The MIT License (MIT)
#
# Copyright (c) 2014 JohnyMoSwag <johnymoswag@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from base64 import urlsafe_b64encode
import getpass
import hashlib
import logging
import os
import sys
import time

from cryptography.hazmat.backends import openssl
from cryptography.fernet import Fernet

from not_so_tuf.exceptions import FileCryptError


# Module-level logger. The NullHandler is attached so that this module emits
# nothing unless a handler is added elsewhere (e.g. by the script entry point
# at the bottom of this file).
log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())


# Shortcut to test different backends: every Fernet object created in this
# module is passed this backend, so it only needs to be swapped here.
my_backend = openssl.backend


class FileCrypt(object):
    """Small wrapper around cryptography to make it easier to use
    with not-so-tuf.

    A plaintext file and its encrypted counterpart (the same path with a
    ``.enc`` suffix) are tracked as a pair.  The password is requested
    interactively with ``getpass``, converted to a Fernet key by
    ``gen_pass`` and cached on the instance.

    Args:
        filename (str): The name of the file to encrypt

        password_timeout (int): The number of seconds before
        needing to reenter password. DEFAULT is 30.

        max_tries (int): The number of password attempts before
        program exits.  DEFAULT is 2
    """

    def __init__(self, filename=None, password_timeout=30, max_tries=2):
        self.password = None
        self.password_timer = 0
        self.password_max_tries = max_tries
        # NOTE: the attribute name is misspelled ("passwrod"); it is kept
        # as-is because code outside this class may already read it.
        self.passwrod_timeout = password_timeout
        self.test = False
        # Always define both names so encrypt()/decrypt() can detect a
        # missing filename instead of failing with AttributeError.
        self.filename = None
        self.enc_filename = None
        self.new_file(filename)

    def new_file(self, filename=None):
        """Adds filename internally to be used for encryption and
        decryption. Also derives the encrypted filename, which is the
        plaintext filename with .enc appended.

        If filename is None a warning is logged and the currently
        stored filenames are left untouched.

        Args:

            filename (str): Path of file to be encrypted/decrypted.
            May be given with or without the .enc suffix.
        """
        if filename is None:
            log.warning('No file to process yet.')
            return
        # _set_filenames already handles a name that ends with '.enc';
        # do not append the suffix a second time.
        self.filename, self.enc_filename = self._set_filenames(filename)

    def encrypt(self):
        """Will encrypt the file

        Reads the plaintext file, writes the Fernet token to the .enc
        file and then removes the plaintext file.

        Raises:
            FileCryptError: If no filename has been set or the
            plaintext file does not exist.
        """
        if self.filename is None:
            raise FileCryptError('Must set filename with new_file '
                                 'method call before you can encrypt',
                                 expected=True)

        if not os.path.exists(self.filename):
            raise FileCryptError('No file to encrypt.')

        # Binary mode: Fernet operates on bytes, and text mode would
        # translate newlines on some platforms.
        with open(self.filename, 'rb') as plain_file:
            plain_data = plain_file.read()
            log.debug('Got plain text')

        log.debug('Lets start this encryption process.')
        if self.password is None:
            self._get_password()

        fernet = Fernet(self.password, backend=my_backend)
        enc_data = fernet.encrypt(plain_data)

        with open(self.enc_filename, 'wb') as enc_file:
            enc_file.write(enc_data)
            log.debug('Wrote encrypted data to %s', self.enc_filename)
        os.remove(self.filename)
        log.debug('Removed original file')
        self._del_internal_password()

    def decrypt(self):
        """Will decrypt the file

        Reads the .enc file and writes the decrypted content to the
        plaintext filename.  The user gets ``max_tries`` password
        attempts; when all of them fail the process exits through
        ``sys.exit(0)``.

        Raises:
            FileCryptError: If no filename has been set or the
            encrypted file does not exist.
        """
        if self.filename is None:
            raise FileCryptError('Must set filename with new_file '
                                 'method call before you can decrypt',
                                 expected=True)

        if not os.path.exists(self.enc_filename):
            raise FileCryptError('No encrypted file to decrypt')

        with open(self.enc_filename, 'rb') as enc_file:
            log.debug('Grabbing ciphertext.')
            enc_data = enc_file.read()

        plain_data = None
        tries = 0
        while tries < self.password_max_tries:
            log.debug('Tries = %s', tries)
            if self.password is None:
                self._get_password()
            try:
                log.debug('Going to attempt to decrypt this ish')
                fernet = Fernet(self.password, backend=my_backend)
                plain_data = fernet.decrypt(enc_data)
                break
            except Exception as e:
                # Any failure is treated as a wrong password: forget the
                # cached key so the next iteration asks for it again.
                self.password = None
                self._prompt('\nInvalid Password.  Press enter to try again')
                log.warning('Invalid Password')
                log.error(str(e), exc_info=True)
                tries += 1

        if plain_data is not None:
            log.debug('Writing plaintext to file.')
            with open(self.filename, 'wb') as plain_file:
                plain_file.write(plain_data)
            log.debug('Done writing to file.')
        else:
            log.warning('You entered to many wrong passwords.')
            # Exit status is left at 0 to match the existing behavior.
            sys.exit(0)

    @staticmethod
    def gen_pass(x):
        """Turns a password into a key usable by Fernet.

        The key is the url-safe base64 encoding of the first 32
        characters of the hex SHA-256 digest of the password.

        Args:
            x (str or bytes): The password. Text is encoded as UTF-8
            before hashing.

        Returns:
            The 44 character base64 encoded key.
        """
        # NOTE(review): a single unsalted SHA-256 is a weak key derivation.
        # The output is intentionally left unchanged because a different
        # scheme could not decrypt files that were already encrypted.
        if not isinstance(x, bytes):
            x = x.encode('utf-8')
        big_key = hashlib.sha256(x).hexdigest()
        # hexdigest is pure ASCII, so this gives exactly 32 bytes
        key = big_key[:32].encode('ascii')
        return urlsafe_b64encode(key)

    @staticmethod
    def _prompt(message):
        # Helper to wait for user input on both Python 2 and Python 3.
        # raw_input only exists on Python 2.
        try:
            reader = raw_input
        except NameError:
            reader = input
        return reader(message)

    def _set_filenames(self, filename):
        # Helper function to correctly set filename and
        # enc_filename instance attributes. Returns the tuple
        # (plaintext filename, encrypted filename)
        if filename.endswith('.enc'):
            filename_ = filename[:-4]
            enc_filename = filename
        else:
            filename_ = filename
            enc_filename = filename + '.enc'
        return filename_, enc_filename

    def _get_password(self):
        # Gets user password without echoing to the console, stores
        # the derived key and starts the password timer
        log.debug('Getting user password')
        pass_ = getpass.getpass('Enter password:\n-->')
        self.password = FileCrypt.gen_pass(pass_)
        log.debug('Got you pass')
        self._update_timer()

    def _update_timer(self):
        # Updates internal timer if not already past current time
        if self.password_timer < time.time():
            log.debug('Updating your internal timer.')
            self.password_timer = time.time() + float(self.passwrod_timeout)

    def _del_internal_password(self):
        # Deletes user password once its not needed.
        # i.e. when the file been encrypted or timer expired
        if self.password_timer < time.time():
            log.debug('About to delete internal password')
            self.password = None


if __name__ == '__main__':
    # When run as a script, send every record of this module's logger
    # (DEBUG and above) to a stream handler.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    log = logging.getLogger(__name__)
    log.addHandler(console)
    log.setLevel(logging.DEBUG)
