# Copyright (C) 2011 Versile AS
# 
# This file is part of Versile Python.
# 
# Versile Python is free software: you can redistribute it and/or
# modify it under the terms of the GNU Affero General Public License
# as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Affero General Public License for more details.
# 
# You should have received a copy of the GNU Affero General Public
# License along with this program.  If not, see
# <http://www.gnu.org/licenses/>.
#
# Other Usage
# Alternatively, this file may be used in accordance with the terms
# and conditions contained in a signed written agreement between you
# and Versile AS.
#
# Versile Python implements Versile Platform which is a copyrighted
# specification that is not part of this software.  Modification of
# the software is subject to Versile Platform licensing, see
# https://versile.com/ for details. Distribution of unmodified
# versions released by Versile AS is not subject to Versile Platform
# licensing.
#

"""Classes for random and pseudo-random number generation."""


from os import urandom

from versile.internal import _b2s, _s2b, _vexport, _b_chr, _b_ord, _pyver
from versile.common.iface import abstract
from versile.common.util import VByteBuffer
import collections


# Names exported by this module; the list is passed through _vexport
# before it is bound to __all__.
__all__ = _vexport([
    'VByteGenerator',
    'VCombiner',
    'VConstantGenerator',
    'VHashReducer',
    'VIncrementer',
    'VProxyGenerator',
    'VPseudoRandomHMAC',
    'VTransformer',
    'VUrandom',
])


@abstract
class VByteGenerator(object):
    """A generator for byte output.

    .. note::

        This is an abstract class that must be implemented by its
        derived classes.

    .. automethod:: __call__

    """

    def __call__(self, num_bytes):
        r"""See :meth:`data`\ ."""
        return self.data(num_bytes)

    @abstract
    def data(self, num_bytes):
        """Generate byte data.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        """
        raise NotImplementedError()

    def number(self, min_num, max_num):
        """Returns an integer constructed from generated byte data.

        :param min_num: lowest integer to produce
        :type  min_num: int, long
        :param max_num: highest integer to produce
        :type  max_num: int, long
        :returns:       number with min_num <= num <= max_num
        :rtype:         int, long
        :raises:        :exc:`ValueError` if *min_num* > *max_num* or
                        if the generator produced no data

        The method will internally call :meth:`data` to generate data
        to create the output number. Generated bytes are interpreted
        as a big-endian unsigned integer; candidates which exceed
        ``max_num - min_num`` are discarded and new data is generated.

        .. warning::

            The method may need to generate multiple iterations of
            byte data until a number can be constructed. If generated
            bytes do not appear 'random' (such as e.g. a constant byte
            value generator), this method may never return.

        """
        if min_num == max_num:
            return min_num
        diff = max_num - min_num
        if diff < 0:
            raise ValueError('min_num must not be greater than max_num')

        # '%x' formatting yields bare hex digits (no '0x' prefix and no
        # python 2 'L' suffix), so no character stripping is required
        diff_hex = '%x' % diff
        if len(diff_hex) % 2:
            diff_hex = '0' + diff_hex
        num_bytes = len(diff_hex) // 2

        # Mask which clears the bits of the first generated byte that
        # lie above the highest set bit of the most significant byte of
        # diff; this keeps candidates in the same bit range as diff
        left_byte = int(diff_hex[:2], 16)
        mask = 0xff
        while (mask >> 1) >= left_byte:
            mask >>= 1

        while True:
            data = self(num_bytes)
            if not data:
                raise ValueError('Generator produced no data')
            num = 0
            first = True
            for d in data:
                b = _b_ord(d)
                if first:
                    b &= mask
                    first = False
                # Accumulate as a big-endian unsigned integer
                num = (num << 8) | b
            if num <= diff:
                return min_num + num


class VConstantGenerator(VByteGenerator):
    r"""Generates a constant repeating data output.

    :param pattern: a byte pattern
    :type  pattern: bytes
    :raises:        :exc:`TypeError` if *pattern* is not a non-empty
                    bytes object

    Generator output will be an infinitely recurring sequence of
    *pattern* \ .

    """

    def __init__(self, pattern):
        if not isinstance(pattern, bytes) or not pattern:
            raise TypeError('Pattern must be non-empty bytes object')
        self.__pattern = pattern
        self.__cache = VByteBuffer()

    def data(self, num_bytes):
        """Generate byte data by cycling through the pattern.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        Any unused tail of the pattern stays in an internal buffer, so
        the next call continues where this one stopped.

        """
        result = []
        bytes_left = num_bytes
        # '> 0' (rather than plain truthiness) so that a negative
        # request returns empty data instead of entering the loop
        while bytes_left > 0:
            if not self.__cache:
                # Buffer is drained, refill with one more pattern copy
                self.__cache.append(self.__pattern)
            data = self.__cache.pop(bytes_left)
            result.append(data)
            bytes_left -= len(data)
        return b''.join(result)
    

class VIncrementer(VByteGenerator):
    """Generates blocks of incrementing block integer value.

    Generator output is a series of block data, where each block is an
    'incrementation' of the previous block.

    :param blocksize:   block size
    :type  blocksize:   int
    :param start_value: if not None, an initial block to generate
    :type  start_value: bytes
    :raises:            :exc:`TypeError` if *blocksize* is not positive
                        or *start_value* is not *blocksize* bytes

    If *start_value* is None, the start block will be only zero-bytes.

    .. automethod:: _increment

    """

    def __init__(self, blocksize, start_value=None):
        if blocksize <= 0:
            raise TypeError('Blocksize must be >= 1')
        self._blocksize = blocksize
        if start_value is not None:
            if (not isinstance(start_value, bytes)
                or len(start_value) != blocksize):
                raise TypeError('Start value must be blocksize bytes')
            self._block = start_value
        else:
            self._block = b'\x00'*blocksize
        # The start block is queued right away so it is the first
        # output produced
        self.__cache = VByteBuffer()
        self.__cache.append(self._block)

    def data(self, num_bytes):
        """Generate byte data from the sequence of incremented blocks.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        Unused bytes of the current block stay in an internal buffer
        and are returned first by the next call.

        """
        result = []
        bytes_left = num_bytes
        # '> 0' (rather than plain truthiness) so that a negative
        # request returns empty data instead of entering the loop
        while bytes_left > 0:
            if not self.__cache:
                self._block = self._increment(self._block)
                self.__cache.append(self._block)
            data = self.__cache.pop(bytes_left)
            result.append(data)
            bytes_left -= len(data)
        return b''.join(result)

    def _increment(self, block):
        """Performs an increment of the previously generated block.

        :param block: input block
        :returns:     output block

        Default is interpret the input block as an integer and to
        increment it by '1'. Derived classes can override to implement
        a different incrementation strategy.

        The block is treated as big-endian (the last byte is the least
        significant). A block of only 0xff bytes wraps around to a
        block of only zero-bytes.

        """
        nums = [_b_ord(c) for c in block]
        # Add one with carry, starting at the least significant byte.
        # If every byte overflows, all positions have already been
        # reset to zero when the loop ends, which is the wrapped value.
        for i in range(len(nums) - 1, -1, -1):
            if nums[i] == 255:
                nums[i] = 0
            else:
                nums[i] += 1
                break
        return b''.join([_s2b(_b_chr(n)) for n in nums])
        

class VTransformer(VByteGenerator):
    """Applies a block transformation to another generator's output.

    :param in_gen:    generator for transform input data
    :type  in_gen:    :class:`VByteGenerator`
    :param transform: a transform function
    :type  transform: callable
    :param blocksize: blocksize for transform input data
    :type  blocksize: int

    The parameter *transform* should be a callable which takes a bytes
    object as input and returns an output block of the transform.

    If *blocksize* is None then it is assumed that the transform
    operates on arbitrary-length inputs. Otherwise input data from
    *in_gen* is fed to the transform in blocks of *blocksize* length.

    """
    def __init__(self, in_gen, transform, blocksize=None):
        self.__input = in_gen
        self.__transform = transform
        self.__blocksize = blocksize
        self.__cache = VByteBuffer()

    def data(self, num_bytes):
        """Generate byte data by transforming data from the input generator.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        Transform output which is not consumed by this call stays in
        an internal buffer and is returned first by the next call.

        """
        result = []
        bytes_left = num_bytes
        # '> 0' (rather than plain truthiness) so that a negative
        # request returns empty data instead of entering the loop
        while bytes_left > 0:
            if not self.__cache:
                if self.__blocksize is None:
                    # Arbitrary-length transform: request as much input
                    # as there is output still missing
                    data = self.__input(bytes_left)
                else:
                    data = self.__input(self.__blocksize)
                block = self.__transform(data)
                self.__cache.append(block)
            data = self.__cache.pop(bytes_left)
            result.append(data)
            bytes_left -= len(data)
        return b''.join(result)


class VHashReducer(VByteGenerator):
    """Performs hash reduction on data from another generator.

    Operates on a hash method with a given blocksize. Each new block
    of output data is generated by taking a number of blocks of data
    from the input generator and computing a hash digest. By tuning
    the ratio between data which is fed into the hash function and the
    output that is provided, the entropy per output byte can be
    increased. This can be useful for converting a source of low but
    uniform entropy into higher-entropy data.

    :param in_gen: a generator of input data for the hasher
    :type  in_gen: :class:`VByteGenerator`
    :param hasher: a hasher object
    :type  hasher: :class:`versile.crypto.VHash`
    :param ratio:  the ratio of input data read to output data produced
    :type  ratio:  int

    """

    def __init__(self, in_gen, hasher, ratio):
        self.__input = in_gen
        self.__hasher = hasher
        self.__digest_size = hasher.digest_size()
        self.__ratio = ratio
        self.__cache = VByteBuffer()

    def data(self, num_bytes):
        """Generate byte data from hash digests of input generator data.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        For each new digest, *ratio* times the digest size of input
        data is read and fed to the hasher. The same hasher object is
        updated for every digest; it is not re-created by this class.
        Digest bytes not consumed by this call stay in an internal
        buffer and are returned first by the next call.

        """
        result = []
        bytes_left = num_bytes
        # '> 0' (rather than plain truthiness) so that a negative
        # request returns empty data instead of entering the loop
        while bytes_left > 0:
            if not self.__cache:
                data = self.__input(self.__ratio*self.__digest_size)
                self.__hasher.update(data)
                self.__cache.append(self.__hasher.digest())
            data = self.__cache.pop(bytes_left)
            result.append(data)
            bytes_left -= len(data)
        return b''.join(result)


class VCombiner(VByteGenerator):
    """Byte generator which xor-combines output of other generators.

    :param in_gens: generator(s)
    :type  in_gens: :class:`VByteGenerator`

    Each output byte is the xor of the bytes at the same position
    produced by every input generator. With no input generators the
    output is only zero-bytes.

    """

    def __init__(self, *in_gens):
        self.__generators = in_gens

    def data(self, num_bytes):
        """Generate byte data as the xor of all input generators' data.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        """
        combined = [0]*num_bytes
        for source in self.__generators:
            values = [_b_ord(c) for c in source(num_bytes)]
            # Indexing into 'values' is deliberate: a generator which
            # returns too little data triggers an IndexError
            for pos, current in enumerate(combined):
                combined[pos] = current ^ values[pos]
        return b''.join(_s2b(_b_chr(v)) for v in combined)
        

class VProxyGenerator(VByteGenerator):
    """Byte generator wrapper for a byte data source function.

    :param provider: provider of byte data
    :type  provider: callable
    :raises:         :exc:`TypeError` if *provider* is not callable

    *provider* must be a callable which takes a number of bytes as
    input and returns a bytes object of that length.

    """

    def __init__(self, provider):
        """Sets up with a byte generating function for the proxy.

        """
        # Use the callable() builtin; the collections.Callable alias
        # was removed from the 'collections' module in python 3.10
        if not callable(provider):
            raise TypeError('Provider must be callable')
        self.__provider = provider

    def data(self, num_bytes):
        """Generate byte data by calling the provider.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         data returned by the provider
        :rtype:           bytes

        The provider's return value is passed through unmodified; its
        length is not validated.

        """
        return self.__provider(num_bytes)
    

class VPseudoRandomHMAC(VByteGenerator):
    r"""Generates pseudo-random data based on :term:`HMAC`\ .

    :param hash_cls: hash class for :term:`HMAC` algorithm
    :type  hash_cls: :class:`versile.crypto.VHash`
    :param secret:   :term:`HMAC` secret
    :type  secret:   bytes
    :param seed:     PRF seed
    :type  seed:     bytes

    Implements the :term:`HMAC` based pseudo-random function defined
    by :rfc:`5246`\ .

    """

    def __init__(self, hash_cls, secret, seed):
        self._hash_cls = hash_cls
        self._secret = secret
        self._seed = seed
        self._pr_data_buff = VByteBuffer()
        self._a = seed  # A_i from RFC 5246, initialized as A_0

    # This is a concrete implementation, so it is intentionally not
    # tagged as abstract
    def data(self, num_bytes):
        """Generate pseudo-random byte data.

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        Output is produced one HMAC block at a time. Bytes of a block
        which are not consumed by this call stay in an internal buffer
        and are returned first by the next call.

        """
        result = []
        num_left = num_bytes
        while num_left > 0:
            if self._pr_data_buff:
                d = self._pr_data_buff.pop(num_left)
                result.append(d)
                num_left -= len(d)
            else:
                # Compute next A_i
                #
                # NOTE(review): RFC 5246 section 5 defines
                # A(i) = HMAC_hash(secret, A(i-1)), whereas this code
                # feeds A(i-1) + seed. Left unchanged because altering
                # it would change all generated output - confirm
                # against the protocol specification before changing.
                self._a = self._hash_cls.hmac(self._secret,
                                              self._a + self._seed)
                # Generate next block of output data
                hmac = self._hash_cls.hmac(self._secret, self._a + self._seed)
                self._pr_data_buff.append(hmac)
        return b''.join(result)


class VUrandom(VByteGenerator):
    """Byte generator backed by :func:`os.urandom`

    May not be available on all platforms.

    """
    def data(self, num_bytes):
        """Generate byte data by reading from :func:`os.urandom`

        :param num_bytes: number of bytes to generate
        :type  num_bytes: int
        :returns:         generated data
        :rtype:           bytes

        """
        raw = urandom(num_bytes)
        if _pyver != 2:
            return raw
        # On python 2 the result is passed through _s2b before returning
        return _s2b(raw)
