# The MIT License (MIT)
#
# Copyright (c) 2014 JohnyMoSwag <johnymoswag@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from binascii import hexlify
import json
import logging
import os
import shutil

import Crypto
import Crypto.Hash.SHA256
import Crypto.PublicKey.RSA
import Crypto.Signature.PKCS1_v1_5

from not_so_tuf.compat import is_py3
from not_so_tuf.exceptions import KeyHandlerError
from not_so_tuf.filecrypt import FileCrypt

# Defensive guard: if PyCrypto were missing, the ``import Crypto`` above
# would already have failed, so this is effectively unreachable. It must
# still *raise* the error rather than merely construct it.
if Crypto is None:
    raise KeyHandlerError('You must have PyCrypto installed.',
                          expected=True)

# Python 3 has no separate ``long`` type; alias it so the public-key parsing
# code below works unchanged on both major versions.
if is_py3:
    long = int

log = logging.getLogger(__name__)


class KeyHandler(object):
    """KeyHandler object is used to manage keys used for signing updates

    Kwargs:
        app (obj): Config object to get config values from
    """

    def __init__(self, app=None):
        # Pre-declare the state later methods check, so that calling them
        # out of order raises a KeyHandlerError instead of an AttributeError.
        self.privkey = None
        self.pubkey = None
        self.update_data = None
        if app:
            self.init_app(app)

    def init_app(self, obj):
        """Sets up client with config values from obj

        Also creates the ``keys`` directory under ``<DEV_DATA_DIR>/nst-data``
        if it does not exist yet.

        Args:
            obj (instance): config object exposing a ``config`` mapping
        """
        # Copies and sets all needed config attributes
        # for this object
        self.app_name = obj.config.get('APP_NAME')
        self.data_dir = obj.config.get('DEV_DATA_DIR')
        self.data_dir = os.path.join(self.data_dir, 'nst-data')

        # Private key setup: default to <APP_NAME>.pem and always force
        # the .pem extension
        self.private_key_name = obj.config.get('PRIVATE_KEY_NAME')
        if self.private_key_name is None:
            self.private_key_name = self.app_name + '.pem'
        if not self.private_key_name.endswith('.pem'):
            self.private_key_name += '.pem'

        # Public key setup: default to <APP_NAME>.pub and always force
        # the .pub extension
        self.public_key_name = obj.config.get('PUBLIC_KEY_NAME')
        if self.public_key_name is None:
            self.public_key_name = self.app_name + '.pub'
        if not self.public_key_name.endswith('.pub'):
            self.public_key_name += '.pub'

        # An explicit KEY_LENGTH of None also falls back to 2048 bits
        self.key_length = obj.config.get('KEY_LENGTH', 2048)
        if self.key_length is None:
            self.key_length = 2048
        self.version_file = os.path.join(self.data_dir, 'version.json')
        self.keys_dir = os.path.join(self.data_dir, 'keys')
        if not os.path.exists(self.keys_dir):
            log.info('Creating keys directory')
            os.makedirs(self.keys_dir)

    def make_keys(self, overwrite=False, test=False):
        """Makes public and private keys for signing and verification

        Kwargs:
            overwrite (bool): Determines if existing keys are overwritten

            test (bool): If True the private key file is left unencrypted
        """
        # Makes a set of private and public keys
        # Used for authentication
        log.info('Making keys')
        rsa_key_object = Crypto.PublicKey.RSA.generate(int(self.key_length))
        # This is the private key, keep this secret. You'll need
        # it to sign new updates.
        self.privkey = rsa_key_object.exportKey(format='PEM')

        public_key_object = rsa_key_object.publickey()
        # This is the public key you must distribute with your
        # program and pass to rsa_verify. Stored as an (n, e) tuple.
        self.pubkey = (public_key_object.n, public_key_object.e)
        self._write_keys_to_file(overwrite, test)

    def _write_keys_to_file(self, overwrite=False, test=False):
        """Writes keys to disk

        The private key is written as PEM and then encrypted in place with
        FileCrypt (unless ``test`` is true); the public key is written as the
        string form of its ``(n, e)`` tuple.

        Args:
            overwrite (bool): Determines if existing keys are overwritten

            test (bool): If True the private key file is not encrypted
        """
        public = os.path.join(self.keys_dir, self.public_key_name)
        private = os.path.join(self.keys_dir, self.private_key_name)
        if os.path.exists(public) and os.path.exists(private):
            # Only an explicitly truthy ``overwrite`` may clobber keys
            if not overwrite:
                log.info('Cannot overwrite old key files.')
                log.debug('Pass overwrite=True to make_keys to overwrite')
                return
            log.warning('About to overwrite old keys')
        log.info('Writing keys to file')
        # exportKey() returns bytes on Python 3, so write in binary mode
        with open(private, 'wb') as pri:
            pri.write(self.privkey)
        # Mirrors the ``if test:`` check in _load_private_key so a key written
        # unencrypted is also read back unencrypted.
        if not test:
            fc = FileCrypt(private)
            fc.encrypt()
        with open(public, 'w') as pub:
            pub.write(str(self.pubkey))

    def sign_update(self, test=False):
        """Proxy method for :meth:`_load_private_key`, :meth:`_add_sig` &
        :meth:`_write_update_data`

        Kwargs:
            test (bool): If True the private key is read without decrypting
        """
        # Loads private key
        # Loads version file to memory
        # Signs Version file
        # Writes version file back to disk
        self._load_private_key(test)
        self._add_sig()
        self._write_update_data()

    # TODO: Search default locations for key file
    def _find_private_key(self):
        """Checks for private key (encrypted ``.enc`` or plain)

        Returns:
            (bool) Meanings::

                True - Private key found

                False - Private key not found
        """
        # Searches keys folder for private key to sign version file
        path = os.path.join(self.keys_dir, self.private_key_name)
        log.debug('private key path: {}'.format(path))
        if os.path.exists(path + '.enc') or os.path.exists(path):
            log.debug('Found private key')
            return True
        log.debug("Didn't find private key")
        return False

    def _find_public_key(self):
        """Checks for public key

        Returns:
            (bool) Meanings::

                True - Public key found

                False - Public key not found
        """
        path = os.path.join(self.keys_dir, self.public_key_name)
        log.debug('Public key path: {}'.format(path))
        if os.path.exists(path):
            log.debug('Found public key')
            return True
        log.debug("Didn't find public key")
        return False

    def _load_private_key(self, test):
        """Loads private key to memory

        Args:
            test (bool): If True the key file is read directly without
                decrypting it first

        Raises:
            KeyHandlerError: If private key cannot be found
        """
        log.info('Loading private key')
        if not self._find_private_key():
            raise KeyHandlerError(u"You don't have any keys",
                                  expected=True)
        privkey = os.path.join(self.keys_dir, self.private_key_name)
        if test:
            with open(privkey) as pk:
                self.privkey = Crypto.PublicKey.RSA.importKey(pk.read())
            return

        # NOTE(review): assumes FileCrypt.decrypt() restores the plaintext
        # key at ``privkey`` and encrypt() removes it again — confirm.
        fc = FileCrypt(privkey)
        fc.decrypt()
        try:
            with open(privkey) as pk:
                self.privkey = Crypto.PublicKey.RSA.importKey(pk.read())
        finally:
            # Always re-encrypt, even if reading/parsing fails, so the
            # decrypted key is not left on disk.
            fc.encrypt()

    def get_public_key(self):
        """Reads the public key file from disk

        Returns:
            (tuple): ``(n, e)`` parsed from the stored public key string

        Raises:
            KeyHandlerError: If the public key file does not exist
        """
        if not self._find_public_key():
            raise KeyHandlerError(u'You do not have a public key',
                                  expected=True)
        public_key = os.path.join(self.keys_dir, self.public_key_name)
        with open(public_key) as f:
            pub_key_data = f.read()
        return KeyHandler._pub_key_string_to_tuple(pub_key_data)

    @staticmethod
    def _pub_key_string_to_tuple(x):
        """Parses a ``"(n, e)"`` string (as written by make_keys) into a
        tuple of integers."""
        x = x.replace('(', '')
        x = x.replace(')', '')
        x = x.split(',')
        return (long(x[0]), long(x[1]))

    def _load_update_data(self):
        """Loads version file

        Returns:
            (dict): parsed JSON contents of the version file

        Raises:
            KeyHandlerError: Version file cannot be read or parsed
        """
        log.info("Loading version file")
        try:
            log.debug('Version file path: {}'.format(self.version_file))
            with open(self.version_file, 'r') as f:
                update_data = json.loads(f.read())
            log.debug('Version file loaded')
            return update_data
        except Exception as e:
            log.error(e)
            raise KeyHandlerError(u'Version file not found',
                                  expected=True)

    def _add_sig(self):
        """Adds signature to version file data held in memory

        Signs the canonical (``sort_keys=True``) JSON of the version data,
        excluding any previous ``sig``, with PKCS#1 v1.5 over SHA-256, and
        stores the result (with a hex ``sig``) in ``self.update_data``.

        Raises:
            KeyHandlerError: If the private key has not been loaded
        """
        log.info('Adding signature to version file...')
        if getattr(self, 'privkey', None) is None:
            log.warning('Private key not loaded')
            raise KeyHandlerError(u'You must load your privkey first',
                                  expected=True)

        update_data = self._load_update_data()
        # An old signature must not be part of the signed payload
        if 'sig' in update_data:
            log.info('Deleting sig')
            del update_data['sig']
        _data = json.dumps(update_data, sort_keys=True)
        # SHA256.new needs bytes on Python 3; json.dumps output is pure ASCII
        # (ensure_ascii defaults to True) so this encoding is lossless.
        _data_hash = Crypto.Hash.SHA256.new(_data.encode('utf-8'))
        signer = Crypto.Signature.PKCS1_v1_5.new(self.privkey)
        signature = signer.sign(_data_hash)

        update_data = json.loads(_data)
        # hexlify returns bytes on Python 3, which json cannot serialise
        update_data['sig'] = hexlify(signature).decode('ascii')
        log.info('Adding sig to update data')
        self.update_data = update_data

    def _write_update_data(self):
        """Writes update data to disk

        Raises:
            KeyHandlerError: If signature wasn't added to version file
        """
        update_data = getattr(self, 'update_data', None)
        if not update_data:
            msg = u'You must sign update data first'
            raise KeyHandlerError(msg, expected=True)
        # Write version file "with new sig" to disk
        with open(self.version_file, 'w') as f:
            f.write(json.dumps(update_data, indent=2, sort_keys=True))
        log.info('Wrote version data')

    def copy_decrypted_private_key(self, fc=None):
        """Writes a plaintext copy of the private key next to it

        The copy is named ``<private key path> copy``; the original key file
        is re-encrypted afterwards.

        Kwargs:
            fc (FileCrypt): instance to use; a new one is created if omitted
        """
        if fc is None:
            fc = FileCrypt()
        privkey = os.path.join(self.keys_dir, self.private_key_name)
        log.debug('Private Key Path: {}'.format(privkey))
        fc.new_file(privkey)
        fc.decrypt()
        try:
            shutil.copy(privkey, privkey + ' copy')
        finally:
            # Re-encrypt the original even if copying fails
            fc.encrypt()

    def print_keys_to_console(self):
        """Prints public key and private key data to console"""
        print('Private Key:\n{}\n\n'.format(self.privkey))
        print('Public Key:\n{}\n\n'.format(self.pubkey))

    def print_public_key(self):
        """Prints public key data to console"""
        public = os.path.join(self.keys_dir, self.public_key_name)
        if os.path.exists(public):
            with open(public) as f:
                key = f.read()
            print('Public Key:\n{}'.format(key))
        else:
            print('No Public Key Found')

    def print_key_names_to_console(self):
        """Prints name of public and private key to console"""
        print('Private Key:\n{}\n\n'.format(self.private_key_name))
        print('Public Key:\n{}\n\n'.format(self.public_key_name))
