# Copyright (c) 2014 JohnyMoSwag <johnymoswag@gmail.com>

from binascii import hexlify

import json
import logging
import os
import rsa
import shutil

from . import KeyHandlerError
from .utils import FileCrypt

log = logging.getLogger(__name__)


class KeyHandler(object):
    """KeyHandler object is used to manage keys used for signing updates

    Keys are stored in ``<DEV_DATA_DIR>/nst-data/keys``. The private key is
    handled through :class:`FileCrypt`: it is encrypted right after being
    written, looked up on disk as ``<PRIVATE_KEY_NAME>.enc``, and decrypted
    only while it is being read or copied.

    Kwargs:
        app (obj): Config object to get config values from
    """

    def __init__(self, app=None):
        if app:
            self.init_app(app)

    def init_app(self, obj):
        """Sets up client with config values from obj

        Reads ``APP_DIR``, ``APP_NAME``, ``DEV_DATA_DIR``,
        ``PRIVATE_KEY_NAME`` (default ``<APP_NAME>.pem``),
        ``PUBLIC_KEY_NAME`` (default ``<APP_NAME>.pub``) and ``KEY_LENGTH``
        (default 2048) from ``obj.config``, then creates the keys directory
        if it does not exist yet.

        Args:
            obj (instance): config object exposing a dict-like ``config``
        """
        # Copies and sets all needed config attributes
        # for this object
        self.app_dir = obj.config.get('APP_DIR')
        self.app_name = obj.config.get('APP_NAME')
        # NOTE(review): DEV_DATA_DIR must be set; os.path.join fails on None
        self.data_dir = obj.config.get('DEV_DATA_DIR')
        self.data_dir = os.path.join(self.data_dir, 'nst-data')
        self.private_key_name = obj.config.get('PRIVATE_KEY_NAME')
        if self.private_key_name is None:
            self.private_key_name = self.app_name + '.pem'
        self.public_key_name = obj.config.get('PUBLIC_KEY_NAME')
        if self.public_key_name is None:
            self.public_key_name = self.app_name + '.pub'
        self.key_length = obj.config.get('KEY_LENGTH', 2048)
        self.version_file = os.path.join(self.data_dir, 'version.json')
        self.keys_dir = os.path.join(self.data_dir, 'keys')
        if not os.path.exists(self.keys_dir):
            log.info('Creating keys directory')
            os.makedirs(self.keys_dir)

    def make_keys(self, overwrite=False):
        """Makes public and private keys for signing and verification

        After this call ``self.privkey`` holds the PKCS#1-serialized private
        key and ``self.pubkey`` holds the ``(n, e)`` tuple of the public key.

        Kwargs:
            overwrite (bool): Determines if existing keys are overwritten
        """
        log.info('Making keys')
        (pubkey, privkey) = rsa.newkeys(int(self.key_length))
        # This is the private key, keep this secret. You'll need
        # it to sign new updates.
        self.privkey = privkey.save_pkcs1()
        # This is the public key you must distribute with your
        # program and pass to rsa_verify.
        self.pubkey = (pubkey.n, pubkey.e)
        self._write_keys_to_file(overwrite)

    def _write_keys_to_file(self, overwrite):
        """Writes keys to disk

        The private key is written in plain text and then encrypted with
        :class:`FileCrypt`. The public key is written as the string form of
        the ``(n, e)`` tuple.

        Args:
            overwrite (bool): Determines if existing keys are overwritten
        """
        public = os.path.join(self.keys_dir, self.public_key_name)
        private = os.path.join(self.keys_dir, self.private_key_name)
        # The stored private key is looked up as ``<name>.enc`` (see
        # _find_private_key), so checking only the plain path would let an
        # existing key pair be replaced without overwrite=True.
        private_exists = (os.path.exists(private) or
                          os.path.exists(private + '.enc'))
        if os.path.exists(public) and private_exists:
            if not overwrite:
                log.info('Cannot overwrite old key files.')
                log.debug('Pass overwrite=True to make_keys to overwrite')
                return
            log.warning('About to overwrite old keys')
        log.info('Writing keys to file')
        with open(private, 'w') as pri:
            pri.write(self.privkey)
        fc = FileCrypt(private)
        fc.encrypt()
        with open(public, 'w') as pub:
            pub.write(str(self.pubkey))

    def sign_update(self):
        """Proxy method for :meth:`_load_private_key`, :meth:`_add_sig` &
        :meth:`_write_update_data`

        Loads the private key, signs the version file in memory and writes
        the signed version file back to disk.
        """
        self._load_private_key()
        self._add_sig()
        self._write_update_data()

    # TODO: Search default locations for key file
    def _find_private_key(self):
        """Checks for the encrypted private key (``<name>.enc``)

        Returns:
            (bool) Meanings::

                True - Private key found

                False - Private key not found
        """
        path = os.path.join(self.keys_dir, self.private_key_name)
        log.debug('private key path: {}'.format(path))
        if os.path.exists(path + '.enc'):
            log.debug('Found private key')
            return True
        log.debug("Didn't find private key")
        return False

    def _load_private_key(self):
        """Loads private key to memory

        Decrypts the key file, parses it into ``self.privkey`` and encrypts
        it again. On ``ValueError`` the user is prompted and the whole
        sequence is retried.

        Raises:
            KeyHandlerError: If private key cannot be found
        """
        log.info('Loading private key')
        if not self._find_private_key():
            raise KeyHandlerError(u"You don't have any keys",
                                  expected=True)
        privkey = os.path.join(self.keys_dir, self.private_key_name)
        # NOTE(review): any ValueError (not only a wrong password) restarts
        # the loop, so a corrupt key file would prompt forever.
        while True:
            try:
                fc = FileCrypt(privkey)
                fc.decrypt()
                with open(privkey) as pk:
                    self.privkey = rsa.PrivateKey.load_pkcs1(pk.read())
                fc.encrypt()
                break
            except ValueError:
                raw_input('Wrong password... Press enter to try again')

    def _load_update_data(self):
        """Loads version file

        Returns:
            (version data) parsed JSON content of the version file

        Raises:
            KeyHandlerError: Version file cannot be read or is not valid JSON
        """
        log.info("Loading version file")
        log.debug('Version file path: {}'.format(self.version_file))
        try:
            with open(self.version_file, 'r') as f:
                update_data = json.load(f)
        except (IOError, OSError) as e:
            log.error(e)
            raise KeyHandlerError(u'Version file not found',
                                  expected=True)
        except ValueError as e:
            # json raises ValueError (JSONDecodeError on py3) for bad input
            log.error(e)
            raise KeyHandlerError(u'Version file is not valid JSON',
                                  expected=True)
        log.debug('Version file loaded')
        return update_data

    def _add_sig(self):
        """Adds signature to version file data

        Any existing ``sig`` entry is removed, the remaining data is
        serialized with sorted keys and signed with SHA-256. The signed
        data (with the hex-encoded signature under ``sig``) is stored in
        ``self.update_data``.

        Raises:
            KeyHandlerError: If the private key has not been loaded
        """
        log.info('Adding signature to version file...')
        # getattr: privkey only exists after a key was made or loaded
        if not getattr(self, 'privkey', None):
            log.warning('Private key not loaded')
            raise KeyHandlerError(u'You must load your privkey first',
                                  expected=True)

        update_data = self._load_update_data()
        if 'sig' in update_data:
            log.info('Deleting sig')
            del update_data['sig']
        # The signature covers this exact canonical (sorted-keys) encoding
        _data = json.dumps(update_data, sort_keys=True)
        signature = hexlify(rsa.pkcs1.sign(_data,
                            self.privkey, 'SHA-256'))

        update_data = json.loads(_data)
        update_data['sig'] = signature
        log.info('Adding sig to update data')
        self.update_data = update_data

    def _write_update_data(self):
        """Writes signed update data to the version file

        Raises:
            KeyHandlerError: If signature wasn't added to version file
        """
        # getattr: update_data only exists after _add_sig has run
        update_data = getattr(self, 'update_data', None)
        if not update_data:
            msg = u'You must sign update data first'
            raise KeyHandlerError(msg, expected=True)
        with open(self.version_file, 'w') as f:
            f.write(json.dumps(update_data, indent=4, sort_keys=True))
        log.info('Wrote version data')

    def copy_decrypted_private_key(self):
        """Writes a plain-text copy of the private key next to it

        The copy is named ``<PRIVATE_KEY_NAME> copy``. After decrypting,
        the key file is passed to ``FileCrypt.encrypt`` again even if the
        copy fails.
        """
        privkey = os.path.join(self.keys_dir, self.private_key_name)
        fc = FileCrypt(privkey)
        fc.decrypt()
        try:
            shutil.copy(privkey, privkey + ' copy')
        finally:
            # Don't leave the decrypted key behind if copying fails
            fc.encrypt()

    def print_keys_to_console(self):
        """Prints public key and private key data to console

        Expects ``self.privkey`` and ``self.pubkey`` to be set, e.g. by
        :meth:`make_keys`.
        """
        print('Private Key:\n{}\n\n'.format(self.privkey))
        print('Public Key:\n{}\n\n'.format(self.pubkey))

    def print_public_key(self):
        """Prints public key file contents to console, if the file exists"""
        public = os.path.join(self.keys_dir, self.public_key_name)
        if os.path.exists(public):
            with open(public) as f:
                key = f.read()
            print('Public Key:\n{}'.format(key))
        else:
            print('No Public Key Found')

    def print_key_names_to_console(self):
        """Prints name of public and private key to console"""
        print('Private Key:\n{}\n\n'.format(self.private_key_name))
        print('Public Key:\n{}\n\n'.format(self.public_key_name))
