import wtforms
from cgi import FieldStorage
from wtforms.validators import *
from wtforms.validators import StopValidation, ValidationError

# Public names for ``from <this module> import *``: every validator that
# wtforms.validators exports (all re-imported by the wildcard import above),
# plus the two validation exceptions and the file validators defined below.
__all__ = wtforms.validators.__all__ + (
    'StopValidation',
    'ValidationError',
    'FileRequired',
    'FileAllowed',
)

class FileRequired(InputRequired):
    '''
    Validator requiring that ``field.data`` holds uploaded file content.

    Accepts a single ``cgi.FieldStorage`` or a non-empty list consisting only
    of ``cgi.FieldStorage`` instances, so it works for both single-file and
    multi-file fields. On failure the field's existing errors are cleared and
    ``StopValidation`` is raised.
    '''

    def _check_fieldstorage(self, data):
        '''
        Tell whether ``data`` is uploaded file content.

        :param data: A ``cgi.FieldStorage`` instance, or a list of them.
        :return: True for a FieldStorage, or for a non-empty list whose items
                 are all FieldStorage instances; False for anything else.
        '''
        if isinstance(data, FieldStorage):
            return True
        # Only a populated list can still qualify; an empty list means no upload.
        if not isinstance(data, list) or not data:
            return False
        return all(isinstance(item, FieldStorage) for item in data)

    def __call__(self, form, field):
        if self._check_fieldstorage(field.data):
            return

        message = self.message
        if message is None:
            message = field.gettext('This field is required for file(s).')

        # Clear errors left by earlier validators so only this one is reported.
        field.errors[:] = []
        raise StopValidation(message)

class FileAllowed:
    '''
    Check that the uploaded file(s) have one of the allowed filename extensions.

    Works for single-file fields (``field.data`` is a ``cgi.FieldStorage``) and
    multi-file fields (``field.data`` is a list of uploaded files). Extensions
    are compared case-insensitively. An empty list passes (nothing to reject);
    any other value that is not a ``cgi.FieldStorage`` fails.
    '''

    def __init__(self, allowed_types, message=None):
        '''
        :param allowed_types: A list/tuple of extension names without the leading dot, ex. ['jpg', 'png']
        :param message: Error message used when validation fails. Defaults to a
                        message listing ``allowed_types``.
        '''
        # Materialize once so a one-shot iterable is not exhausted by the
        # lower-casing below before the default message is built from it.
        allowed_types = list(allowed_types)
        self.allowed_types = [each_type.lower() for each_type in allowed_types]
        if message is None:
            message = 'Only these types are allowed: ' + ', '.join(allowed_types)
        # Always set, so __call__ can rely on it.
        self.message = message

    def _extension_allowed(self, upload):
        '''
        Return True if ``upload`` has a filename whose extension is allowed.

        :param upload: An uploaded-file object exposing a ``filename`` attribute.
        '''
        filename = getattr(upload, 'filename', None)
        if filename is None:
            # Not a file upload (plain form value, or a part without a filename).
            return False
        _, dot, extension = filename.rpartition('.')
        if not dot:
            # A name without any '.' has no extension; without this, a file
            # literally named "png" would be treated as a PNG.
            extension = ''
        return extension.lower() in self.allowed_types

    def _check_allowed_types(self, data):
        '''
        :param data: Could be instance of cgi.FieldStorage, or a list which contains instances of cgi.FieldStorage.
        :return: True when every file in ``data`` has an allowed extension.
        '''
        if isinstance(data, list):
            return all(self._extension_allowed(each_data) for each_data in data)
        if isinstance(data, FieldStorage):
            return self._extension_allowed(data)
        return False

    def __call__(self, form, field):
        if not self._check_allowed_types(field.data):
            raise ValidationError(self.message)
