#!/usr/bin/env python
"""
Generate Python source code: random function calls with random arguments.
Use "python" command line program.

Interesting modules: all modules written in C or having some code written
in C, see Modules/*.c in the Python source code.
"""

from fusil.python_tools import RUNNING_PYTHON3
from fusil.application import Application
from optparse import OptionGroup
from os.path import exists as path_exists, isabs, join as path_join
from fusil.process.watch import WatchProcess
from types import FunctionType, BuiltinFunctionType
from random import choice, randint
from fusil.bytes_generator import BytesGenerator
from fusil.unicode_generator import (
    UnicodeGenerator, IntegerGenerator, IntegerRangeGenerator,
    UnsignedGenerator, UnixPathGenerator,
    LETTERS, DECIMAL_DIGITS, UNICODE_65535, PRINTABLE_ASCII)
from fusil.process.stdout import WatchStdout
from fusil.project_agent import ProjectAgent
from fusil.process.create import CreateProcess
from inspect import ismethoddescriptor
from sys import executable, version as PYTHON_VERSION
from fusil.write_code import WriteCode
from os import listdir, stat
from stat import S_ISDIR
import re
import sys
if not RUNNING_PYTHON3:
    from sys import getfilesystemencoding

# Constants

# Default value of the --python option: the interpreter running this script
PYTHON = executable
# Default value of the --filenames option (comma separated list)
FILENAMES = "/etc/motd,/bin/sh"
# If true, getNbArg() reads the number of arguments of a function from the
# prototype found in its documentation string
PARSE_PROTOTYPE = True
# If true, generate smaller scripts (see NB_CALL, NB_METHOD, NB_CLASS) and
# show the process output like SHOW_STDOUT
DEBUG = False
# If true, set show_matching and show_not_matching on the stdout watcher
SHOW_STDOUT = False
# If true, the generated script prints "import <module>" before the import
VERBOSE = True
# If true, register the extra ignoreRegex() patterns (invalid number of
# arguments, invalid argument type, invalid operation) on the stdout watcher
STDOUT_IGNORE_REGEX = False
# Default value of the --timeout option (seconds)
TIMEOUT = 30.0
# Match "name(arguments)" at the current position and capture the text
# between the parentheses; used by parsePrototype() with match()
PROTOTYPE_REGEX = re.compile(r"[A-Za-z]+[A-Za-z_0-9]*\(([^)]*)\)", re.MULTILINE)

# Modules removed from the result of listAllModules() when all modules are
# tested (--modules=*)
MODULE_BLACKLIST = set((
    'this',         # nothing to fuzz and display lines of text
#    'logging',
#    'compileall',
))


# Functions and methods blacklist. Format:
#     module name => function names
# and module name:class name => method names
# The function sets below are shared by several entries of BLACKLIST.
CTYPES = set((
    # Read/write arbitrary memory
    'PyObj_FromPtr', 'string_at', 'wstring_at',
    'call_function', 'call_cdeclfunction',
    'Py_INCREF', 'Py_DECREF',
    'dlsym', 'dlclose',
    '_string_at_addr', '_wstring_at_addr',

    # _ctypes.dlopen("/bin/sh", -5):
    # Inconsistency detected by ld.so: dl-open.c: 652:
    # _dl_open: Assertion `_dl_debug_initialize (0, args.nsid)->r_state == RT_CONSISTENT' failed!
    # The bug is specific to Python (not reproducible outside Python)
    "dlopen",
))

SOCKET = set((
    # Avoid DNS request (timeout)
    "gethostbyname", "gethostbyname_ex", "gethostbyaddr",
    "getnameinfo", "getaddrinfo",
))

POSIX = set((
    # exit python
    "_exit", "abort",
    # read -> timeout
    "read",
    # truncate file, remove directory, remove file
    "ftruncate", "rmdir", "unlink",
    # send a signal to the current process or a process group
    "kill", "killpg",
    # create child process
    "fork", "forkpty", "system", "popen", "popen2", "popen3", "popen4",
    # wait process exit (-> timeout)
    "wait", "wait3", "waitpid",
    # break the terminal?
    "tcsetpgrp",
))

BUILTINS = set((
    # Create huge integer, very long string or list
    "pow", "round",
))

# Keys are looked up by getFunctions() (module name) and by getMethods()
# ("module name:class name")
BLACKLIST = {
    # Dangerous module: ctypes
    'ctypes': CTYPES,
    '_ctypes': CTYPES,

    # Eat a lot of CPU with large arguments
    'itertools': set(("tee",)),
    'math': set(("factorial",)),
    'operator': set((
        # Create huge integer, very long string or list
        "pow", "__pow__",
        "ipow", "__ipow__",
        "mul", "__mul__",
        "repeat", "__repeat__",
    )),

    # Same set under the Python 2 and Python 3 names of the builtins module
    '__builtin__': BUILTINS,
    'builtins': BUILTINS,

    # Sleep
    'time': set(("sleep",)),
    'select': set(("epoll", "poll", "select")),
    'signal': set(("pause",)),

    '_socket': SOCKET,
    'socket': SOCKET,
    'posix': POSIX,
    'os': POSIX,

#     # set_authorizer: invalid callback
#     '_sqlite3:Connection': set(("set_authorizer",)),
#

    '_fileio:_FileIO': set((
        # timeout
        "read", "readall",
    )),

    '_multiprocessing:SemLock': set((
        # deadlock
        "acquire",
    )),
    '_multiprocessing:Connection': set((
        # timeout
        "recv", "recv_bytes",
    )),

    '_tkinter': set((
        # timeout
        'dooneevent',
        # no window -> no event -> loop on select()
        'mainloop',
        # issue #3880
        "_flatten",
    )),

    "termios": set((
        # tcflow(0, False) suspends output to stdout
        "tcflow",
    )),

    "dl": set((
        # dl.open("/bin/sh", -5):
        # Inconsistency detected by ld.so: dl-open.c: 652:
        # _dl_open: Assertion `_dl_debug_initialize (0, args.nsid)->r_state == RT_CONSISTENT' failed!
        # The bug is specific to Python (not reproducible outside Python)
        "open",
    )),

    "pprint": set(("_perfcheck",)),     # timeout (unlimited loop?)
    "tabnanny": set(("check",)),        # python 2.5.2 implementation is just slow

    # listen to a socket and wait for requests
    "BaseHTTPServer": set(("test",)),
    "CGIHTTPServer": set(("test",)),
    "SimpleHTTPServer": set(("test",)),
    "pydoc": set(("serve",)),
}

# Size of a generated script: number of function calls (NB_CALL), number of
# method calls per created object (NB_METHOD) and number of instantiated
# classes (NB_CLASS). Debug mode generates smaller scripts.
if not DEBUG:
    NB_CALL, NB_METHOD, NB_CLASS = 30, 10, 10
else:
    NB_CALL, NB_METHOD, NB_CLASS = 5, 1, 5

# Maximum number of arguments when the prototype is unknown (MAX_ARG) and
# number of extra arguments allowed for a "*args" parameter (MAX_VAR_ARG)
MAX_ARG, MAX_VAR_ARG = 6, 5

def listAllModules():
    """
    List the names of the modules and packages found in sys.path.

    A directory is a package if it contains __init__.py or __init__.pyc.
    A file is a module if its name ends with .py, .pyc or .so (__init__
    files are ignored); the module name is the filename without its last
    extension.

    Entries of sys.path which cannot be listed (missing directory, zip
    archive or egg file, unreadable directory) and directory entries which
    cannot be examined (eg. broken symbolic link) are skipped instead of
    raising OSError.

    Return a set of names.
    """
    modules = set()
    for directory in sys.path:
        try:
            filenames = listdir(directory)
        except OSError:
            # Not an existing, readable directory (eg. "" or "xxx.zip")
            continue
        for filename in filenames:
            fullpath = path_join(directory, filename)
            try:
                st = stat(fullpath)
            except OSError:
                # Broken symbolic link, or entry removed since listdir()
                continue
            if S_ISDIR(st.st_mode):
                # Only directories with an __init__ file are packages
                if path_exists(path_join(fullpath, "__init__.py")) \
                or path_exists(path_join(fullpath, "__init__.pyc")):
                    modules.add(filename)
                continue
            if filename in ("__init__.py", "__init__.pyc"):
                continue
            if filename.endswith((".py", ".pyc", ".so")):
                modules.add(filename.rsplit(".", 1)[0])
    return modules

class Fuzzer(Application):
    NAME = "python"

    def createFuzzerOptions(self, parser):
        options = OptionGroup(parser, "Python fuzzer")
        options.add_option("--modules",
            help="Tested Python module names separated by commas (default: test all modules)",
            type="str", default="*")
        options.add_option("--blacklist",
            help='Module blacklist separated by commas (eg. "_lsprof,_json")',
            type="str")
        options.add_option("--test-private",
            help="Test private methods (default: skip privates methods)",
            action="store_true")
        options.add_option("--timeout",
            help="Timeout in seconds (default: %.1f)" % TIMEOUT,
            type="float", default=TIMEOUT)
        options.add_option("--filenames",
            help="Names separated by commas of readable files (default: %s)" % FILENAMES,
            type="str", default=FILENAMES)
        options.add_option("--python",
            help="Python executable program path (default: %s)" % PYTHON,
            type="str", default=PYTHON)
        return options

    def setupProject(self):
        project = self.project

        project.error("Use python interpreter: %s" % self.options.python)
        version = ' -- '.join( line.strip() for line in PYTHON_VERSION.splitlines())
        project.error("Python version: %s" % version)
        PythonSource(project, self.options)
        process = PythonProcess(project, [self.options.python, '-u', '<source.py>'], timeout=self.options.timeout)
        WatchProcess(process, exitcode_score=0)

        stdout = WatchStdout(process)
        stdout.max_nb_line = None

        # Disable dummy error messages
        stdout.words = {
            'oops': 0.30,
            'bug': 0.30,
            'memory': 0.40,
            'fatal': 1.0,
            'assert': 1.0,
            'assertion': 1.0,
            'critical': 1.0,
            'panic': 1.0,
            'glibc detected': 1.0,
            'segfault': 1.0,
            'segmentation fault': 1.0,
        }

# TODO: ignore logging.fatal
# TODO: ignore logging.critical

        # CPython critical messages
        stdout.addRegex("^XXX undetected error", 1.0)

        # PyPy messages
        stdout.addRegex("Fatal RPython error", 1.0)

        if SHOW_STDOUT or DEBUG:
            stdout.show_matching = True
            stdout.show_not_matching = True

        stdout.ignoreRegex(ur"_ast.Assert()")

        # PyPy interact prompt
        # avoid false positive on "# assert did not crash"
        stdout.ignoreRegex(ur"^And now for something completely different:")

        if STDOUT_IGNORE_REGEX:
            # Invalid number of arguments
            stdout.ignoreRegex(ur"default __new__ takes no parameters")
            stdout.ignoreRegex(ur"takes no argument")
            stdout.ignoreRegex(ur"takes at least [0-9]+ argument")
            stdout.ignoreRegex(ur"takes exactly [0-9]+ argument")
            stdout.ignoreRegex(ur"takes at most [0-9]+ argument")
            stdout.ignoreRegex(ur"argument must be ")

            # Invalid argument type
            stdout.ignoreRegex(ur"\[TypeError\]")
            stdout.ignoreRegex(ur"argument [0-9]+ must be char, not ")

            # Invalid operation on an object
            stdout.ignoreRegex(ur"objects are unhashable")
            stdout.ignoreRegex(ur"operand does not support unary buffer")
            stdout.ignoreRegex(ur"unsupported operand type\(s\) for ")
            stdout.ignoreRegex(ur"'ascii' codec can't decode byte ")
            stdout.ignoreRegex(ur"object has no attribute ")
            stdout.ignoreRegex(ur"iteration over non-sequence")
            stdout.ignoreRegex(ur"long int too large to convert to int")

# Name of the function defined at the top of the generated script: it accepts
# any arguments and raises ValueError('error'). genErrback() passes it as an
# argument.
ERRBACK_NAME = u'errback'

# Known number of arguments of special methods, used by getNbArg():
# method name => number of arguments, or (minimum, maximum) tuple
METHODS_NB_ARG = {
    '__str__': 0,
    '__repr__': 0,
    '__hash__': 0,
    '__reduce__': 0,
    '__delattr__': 1,
    '__getattribute__': 1,
    '__getitem__': 1,
    '__getslice__': 2,
    '__reduce_ex__': (0, 1),
    '__getstate__': 0,
    '__setattr__': 2,
    '__setstate__': 1,
}


class PythonFuzzerError(Exception):
    """
    Error specific to the Python fuzzer, eg. raised when a module has no
    function and no class to test.
    """

class PythonSource(ProjectAgent):
    def __init__(self, project, options):
        ProjectAgent.__init__(self, project, "python_source")
        self.options = options
        if self.options.modules != "*":
            self.modules = []
            for module in self.options.modules.split(","):
                module = module.strip()
                if not len(module):
                    continue
                self.modules.append(module)
        else:
            self.modules = listAllModules() - MODULE_BLACKLIST
            self.modules = list(self.modules)
        blacklist = self.options.blacklist
        if blacklist:
            blacklist = set(blacklist.split(","))
            removed =  set(self.modules) & set(blacklist)
            self.error("Blacklist modules: %s" % removed)
            self.modules = list( set(self.modules) - set(blacklist) )

        self.filenames = self.options.filenames
        if not RUNNING_PYTHON3:
            encoding = getfilesystemencoding()
            self.filenames = unicode(self.filenames, encoding)
        self.filenames = self.filenames.split(u",")
        for filename in self.filenames:
            if not isabs(filename):
                raise ValueError("Filename %r is not an absolute path! Fix the --filenames option" % filename)
            if not path_exists(filename):
                raise ValueError("File doesn't exist: %s! Use different --filenames option" % filename)
        project.error(u"Use filenames: %s" % u', '.join(self.filenames))

    def loadModule(self, module_name):
        self.module_name = module_name
        self.debug("Import %s" % self.module_name)
        self.module = __import__(self.module_name)
        for name in self.module_name.split(".")[1:]:
            self.module = getattr(self.module, name)
        try:
            self.warning("Module filename: %s" % self.module.__file__)
        except AttributeError:
            pass
        self.write = WritePythonCode(self, self.filename, self.module, self.module_name)

    def on_session_start(self):
        self.filename = self.session().createFilename(u'source.py')

        while self.modules:
            name = choice(self.modules)
            try:
                self.loadModule(name)
                break
            except Exception, err:
                self.error("Skip module %s: %s" % (name, err))
                self.modules.remove(name)
        if not self.modules:
            self.error("There is no more modules!")
            self.send('project_stop')
            return
        self.error("Test module %s" % name)
        self.send('session_rename', name)

        self.write.writeSource()
        self.send('python_source', self.filename)

class PythonProcess(CreateProcess):
    """
    Python process started when the script of the session is written.
    """
    def on_python_source(self, filename):
        """
        Event "python_source": replace the last argument of the command
        line by the script filename and create the process.
        """
        arguments = self.cmdline.arguments
        arguments[-1] = filename
        self.createProcess()

# Characters escaped with a backslash: >'<, >"< and >\<
ESCAPE_CHARARACTERS = u"'\"\\"

def formatCharacter(char):
    """
    Format a character to be written in a quoted string literal using only
    printable ASCII characters:

     - quotes and backslash: >\\"<
     - printable ASCII (codes 32..126): >a<
     - other codes up to 255: >\\xEF<
     - other codes up to 65535: >\\u0101<
     - bigger codes: >\\U00010FA3<
    """
    if char in ESCAPE_CHARARACTERS:
        return u'\\' + char
    code = ord(char)
    if code > 65535:
        return u'\\U%08X' % code
    if code > 255:
        return u'\\u%04X' % code
    if code < 32 or code > 126:
        return u'\\x%02X' % code
    return char

def escapeUnicode(text):
    """
    Escape each character of the text with formatCharacter() and return
    the concatenation of the results.
    """
    escaped = [formatCharacter(char) for char in text]
    return ''.join(escaped)

class WritePythonCode(WriteCode):
    """
    Write the Python script testing one module: the script imports the
    module, then calls randomly chosen functions, classes and methods of
    the module with randomly generated arguments.

    The argument generators (gen* methods) return a list of source code
    lines (unicode strings).
    """
    def __init__(self, parent, filename, module, module_name):
        """
        Arguments:
         - parent: object with "filenames" and "options" attributes
         - filename: name of the written script
         - module: module object
         - module_name: name of the module

        Raise PythonFuzzerError if the module has no function and no class.
        """
        WriteCode.__init__(self)
        self.filename = filename
        self.filenames = parent.filenames
        self.options = parent.options
        # Generators of one-value arguments, also used for the items of
        # lists, tuples and dictionaries
        self.simple_argument_generators = (
            self.genNone,
            self.genBool,
            self.genSmallUint,
            self.genInt,
            self.genLetterDigit,
            self.genBytes,
            self.genString,
            self.genAsciiString,
            self.genUnixPath,
            self.genFloat,
            self.genExistingFilename,
            self.genErrback,
#            self.genOpenFile,
#            self.genException,
        )
        # Generators of function arguments: simple values and containers
        self.complex_argument_generators = self.simple_argument_generators + (
            self.genList,
            self.genTuple,
            self.genDict,
        )
        self.smallint_generator = IntegerRangeGenerator(-19, 19)
        self.int_generator = IntegerGenerator(20)
        self.bytes_generator = BytesGenerator(0, 20)
        self.unicode_generator = UnicodeGenerator(1, 20, UNICODE_65535)
        self.ascii_generator = UnicodeGenerator(0, 20, PRINTABLE_ASCII)
        self.unix_path_generator = UnixPathGenerator(100)
        self.letters_generator = UnicodeGenerator(1, 8, LETTERS | DECIMAL_DIGITS)
        # Integer part and fractional part of genFloat() values
        self.float_int_generator = IntegerGenerator(3)
        self.float_float_generator = UnsignedGenerator(3)
        self.module = module
        self.module_name = module_name

        self.functions, self.classes = self.getFunctions()
        if not self.functions and not self.classes:
            raise PythonFuzzerError("Module %s has no function and no class!" % self.module_name)

    def writePrint(self, level, arguments):
        """
        Write a statement printing the arguments (source code of an
        expression) to stderr, using the Python 2 or Python 3 syntax.
        """
        if RUNNING_PYTHON3:
            code = u"print (%s, file=stderr)" % arguments
        else:
            code = u"print >>stderr, %s" % arguments
        self.write(level, code)

    def writeSource(self):
        """
        Write the whole script: imports, the errback function, the
        callMethod() and callFunc() helpers, then the random calls.
        """
        self.createFile(self.filename)
        self.write(0, u"# -*- coding: ASCII -*-")
        self.write(0, u"from gc import collect")
        self.write(0, u"from sys import stderr")
        if VERBOSE:
            self.writePrint(0, u'"import %s"' % self.module_name)
        self.write(0, u"import %s" % self.module_name)
        self.emptyLine()
        # Callback raising an error, used as argument by genErrback()
        self.write(0, u"def %s(*args, **kw):" % ERRBACK_NAME)
        self.write(1, u"raise ValueError('error')")
        self.emptyLine()

        self.write(0, "def callMethod(prefix, object, name, *arguments):")
        level = self.addLevel(1)
        self.writeCallMethod()
        self.restoreLevel(level)
        self.emptyLine()

        # callFunc() calls an attribute of the module using callMethod()
        self.write(0, "def callFunc(prefix, name, *arguments):")
        self.write(1, "return callMethod(prefix, %s, name, *arguments)" % self.module_name)
        self.emptyLine()

        self.writeCode(u'', self.module, self.functions, self.classes, 1, NB_CALL)
        self.close()

    def writeCallMethod(self):
        """
        Write the body of the callMethod() function of the script: print
        the called name, call the attribute, print the exception if the
        call fails (the result is None in this case), run the garbage
        collector and return the result.
        """
        self.write(0, u'funcname = "%s.%%s()" %% name' % self.module_name)
        self.write(0, u'message = "[%s] %s" % (prefix, funcname)')
        self.writePrint(0, u'message')
        self.write(0, u'func = getattr(object, name)')
        self.write(0, u'try:')
        self.write(1, u'result = func(*arguments)')

        # SystemExit and KeyboardInterrupt are also caught by the script
        exceptions = u'(Exception, SystemExit, KeyboardInterrupt)'
        if RUNNING_PYTHON3:
            self.write(0, u'except %s as err:' % exceptions)
        else:
            self.write(0, u'except %s, err:' % exceptions)
        self.write(1, u'errmsg = repr(err)')
        if RUNNING_PYTHON3:
            self.write(1, u"errmsg = errmsg.encode('ASCII', 'replace')")
        self.writePrint(1, u'"[%s] %s => %s: %s" % (prefix, funcname, err.__class__.__name__, errmsg)')
        self.write(1, u'result = None')

        self.writePrint(0, u'"[%s] -garbage collector-" % prefix')
        self.write(0, u'collect()   # explicit call to the garbage collector')
        self.write(0, u'return result')

    def getFunctions(self):
        """
        List the tested attributes of the module: names of __all__ if the
        module has this attribute, names of dir() otherwise, without the
        blacklisted names.

        Return the (function names, class names) tuple.
        """
        classes = []
        functions = []
        try:
            blacklist = BLACKLIST[self.module_name]
        except KeyError:
            blacklist = set()

        if not hasattr(self.module, "__all__"):
            names = set(dir(self.module))
        else:
            names = set(self.module.__all__)
        names -= set(("__builtins__", "__doc__", "__file__", "__name__"))
        names -= blacklist
        for name in names:
            attr = getattr(self.module, name)
            if isinstance(attr, (FunctionType, BuiltinFunctionType)):
                functions.append(name)
            elif isinstance(attr, type):
                classes.append(name)
        return functions, classes

    def getMethods(self, object, class_name):
        """
        List the names of the method descriptors of the object (a class),
        without the methods of the "module:class" blacklist. Names starting
        with "__" are skipped unless the test_private option is set.
        """
        try:
            key = "%s:%s" % (self.module_name, class_name)
            blacklist = BLACKLIST[key]
        except KeyError:
            blacklist = set()
        methods = []
        for name in dir(object):
            if name in blacklist:
                continue
            if (not self.options.test_private) and name.startswith("__"):
                continue
            attr = getattr(object, name)
            if not ismethoddescriptor(attr):
                continue
            methods.append(name)
        return methods

    def createComplexArgument(self):
        """
        Create a random argument (simple value or container) and return
        its source code as a list of lines.
        """
        callback = choice(self.complex_argument_generators)
        return callback()

    def createArgument(self):
        """
        Create a random simple argument and return its source code as a
        list of lines.

        Raise ValueError if the generator returned a line which is not a
        unicode string.
        """
        callback = choice(self.simple_argument_generators)
        value = callback()
        for item in value:
            if not isinstance(item, unicode):
                raise ValueError("%s returned type %s" % (callback, type(item)))
        return value

    def getNbArg(self, func, func_name, min_arg):
        """
        Get the (minimum, maximum) number of arguments of a function. Use,
        in this order: the METHODS_NB_ARG table, the prototype of the
        documentation string (if PARSE_PROTOTYPE is set), or the default
        (min_arg, MAX_ARG).
        """
        try:
            # Known method of arguments?
            value = METHODS_NB_ARG[func_name]
            if isinstance(value, tuple):
                min_arg, max_arg = value
            else:
                min_arg = max_arg = value
            return min_arg, max_arg
        except KeyError:
            pass

        if PARSE_PROTOTYPE:
            # Try using the documentation
            args = parseDocumentation(func.__doc__, MAX_VAR_ARG)
            if args:
                return args

        return min_arg, MAX_ARG

    def callFunction(self, prefix, func_index, func_name, func, min_arg):
        """
        Write a call with a random number of random arguments: a
        callMethod() call on "obj" if prefix is set, a callFunc() call
        otherwise. The prefix of the call is the given prefix (or "f")
        followed by func_index + 1.
        """
        min_arg, max_arg = self.getNbArg(func, func_name, min_arg)
        nb_arg = randint(min_arg, max_arg)

        if prefix:
            prefix += str(1+func_index)
            first_line = u'callMethod("%s", obj, "%s"' % (prefix, func_name)
        else:
            prefix = "f%s" % (1+func_index)
            first_line = u'callFunc("%s", "%s"' % (prefix, func_name)
        if nb_arg:
            self.write(0, first_line + u',')
            level = self.addLevel(1)
            last_char = u','
            for index in xrange(nb_arg):
                # The last argument closes the call
                if index == nb_arg-1:
                    last_char = u')'
                self.writeArgument(1, last_char)
            self.restoreLevel(level)
        else:
            self.write(0, first_line + ')')
        self.emptyLine()

    def writeArgument(self, level, last_char=u','):
        """
        Write a random argument at the given level; last_char is appended
        to its last line.
        """
        lines = self.createComplexArgument()
        lines[-1] += last_char
        for line in lines:
            self.write(level, line)

    def useClass(self, cls_index, cls, class_name):
        """
        Write the creation of an object ("obj") of the class with 1 to
        MAX_ARG random arguments. If the class has tested methods, write
        random method calls (executed only if the object is true), then
        the deletion of the object and a garbage collection.
        """
        nb_arg = randint(1, MAX_ARG)

        prefix = 'o%s' % (1 + cls_index)
        self.writePrint(0, u'"[%s] Create object %s"' % (prefix, 1 + cls_index))

        obj_name = u'obj'
        self.write(0, u'%s = callFunc("%s", "%s",' % (obj_name, prefix, class_name))

        for index in xrange(nb_arg):
            self.write(2, u'# argument %s/%s' % (1+index, nb_arg))
            self.writeArgument(2)
        self.write(1, u')')

        methods = self.getMethods(cls, class_name)
        if methods:
            self.write(0, u'if %s:' % obj_name)
            level = self.addLevel(1)
            self.writeCode(prefix+'m', cls, methods, tuple(), 0, NB_METHOD)
            self.write(0, u'del %s' % obj_name)
            self.writePrint(0, u'"[%s] -garbage collector -"' % prefix)
            self.write(0, u'collect()   # explicit call to the garbage collector')
            self.restoreLevel(level)
        self.emptyLine()

    def writeCode(self, prefix, object, functions, classes, func_min_arg, nb_call):
        """
        Write nb_call calls to random functions (names of attributes of
        object), then the usage of NB_CLASS random classes.
        """
        if functions:
            for index in xrange(nb_call):
                func_name = choice(functions)
                func = getattr(object, func_name)
                self.callFunction(prefix, index, func_name, func, func_min_arg)
        if classes:
            self.nb_class = NB_CLASS
            for index in xrange(self.nb_class):
                class_name = choice(classes)
                cls = getattr(object, class_name)
                self.useClass(index, cls, class_name)

    def genNone(self):
        """Argument: None."""
        return [u'None']

    def genBool(self):
        """Argument: True or False."""
        if randint(0, 1) == 1:
            return [u'True']
        else:
            return [u'False']

    def genSmallUint(self):
        """Argument: value of the small integer generator (range -19..19)."""
        return [self.smallint_generator.createValue()]

    def genInt(self):
        """Argument: value of the integer generator."""
        return [self.int_generator.createValue()]

    def genBytes(self):
        """Argument: byte string literal, each byte is written as \\xHH."""
        # Bytes string
        bytes = self.bytes_generator.createValue()
        if RUNNING_PYTHON3:
            text = ''.join( ur"\x%02X" % byte for byte in bytes)
            text = 'b"%s"' % text
        else:
            text = u''.join( ur"\x%02X" % ord(byte) for byte in bytes)
            text = u'"%s"' % text
        return [text]

    def genUnixPath(self):
        """Argument: string literal of a generated UNIX path."""
        # NOTE(review): the path is written without escaping, presumably the
        # generator never creates quotes or backslashes -- confirm
        path = self.unix_path_generator.createValue()
        return [u'"%s"' % path]

    def _genUnicode(self, generator):
        """
        Argument: character string literal of a value of the generator,
        escaped with escapeUnicode(). The literal has the "u" prefix on
        Python 2.
        """
        # (Unicode) character string
        text = generator.createValue()
        text = escapeUnicode(text)
        if RUNNING_PYTHON3:
            text = u'"%s"' % text
        else:
            text = u'u"%s"' % text
        return [text]

    def genLetterDigit(self):
        """Argument: string of 1 to 8 letters and decimal digits."""
        return self._genUnicode(self.letters_generator)

    def genString(self):
        """Argument: string of the UNICODE_65535 character set."""
        return self._genUnicode(self.unicode_generator)

    def genAsciiString(self):
        """Argument: string of the PRINTABLE_ASCII character set."""
        return self._genUnicode(self.ascii_generator)

    def genFloat(self):
        """Argument: "<integer>.<unsigned integer>" floating point number."""
        int_part = self.float_int_generator.createValue()
        float_part = self.float_float_generator.createValue()
        return [u"%s.%s" % (int_part, float_part)]

    def genExistingFilename(self):
        """Argument: string literal of one of the --filenames names."""
        filename = choice(self.filenames)
        return [u"'%s'" % filename]

    def genErrback(self):
        """Argument: the errback function defined by the script."""
        return ["%s" % ERRBACK_NAME]

    def genOpenFile(self):
        """
        Argument: file object opening one of the --filenames names
        (generator not registered in the generator lists).
        """
        filename = choice(self.filenames)
        if RUNNING_PYTHON3:
            instr = "open('%s')" % filename
        else:
            instr = "open(u'%s')" % filename
        return [instr]

    def genException(self):
        """
        Argument: Exception object (generator not registered in the
        generator lists).
        """
        return ["Exception('pouet')"]

    def _genList(self, open_text, close_text, empty, is_dict=False):
        """
        Create a container of 0 to 10 simple items and return its lines.

        Arguments:
         - open_text, close_text: text written before the first item and
           after the last item
         - empty: source code of the empty container
         - is_dict: if true, the items are "key: value" pairs
        """
        # Generate values of the same type 10 times out of 11
        same_type = (randint(0, 10) != 0)
        nb_item = randint(0, 10)
        if not nb_item:
            return [empty]
        items = []
        if same_type:
            # Choose the generators once for all items
            if is_dict:
                key_callback = choice(self.simple_argument_generators)
            value_callback = choice(self.simple_argument_generators)
            for index in xrange(nb_item):
                if is_dict:
                    item = self.createDictItem(key_callback, value_callback)
                else:
                    item = value_callback()
                items.append(item)
        else:
            for index in xrange(nb_item):
                if is_dict:
                    item = self.createDictItem()
                else:
                    item = self.createArgument()
                items.append(item)
        lines = []
        for item_index, item_lines in enumerate(items):
            if item_index:
                # Separate the item from the previous one
                lines[-1] += u","
                for index, line in enumerate(item_lines):
                    # Add ' ' prefix to all lines
                    item_lines[index] = u' ' + line
            lines.extend(item_lines)
        # A tuple of one item needs a trailing comma: (item,)
        if nb_item == 1 and empty == u'tuple()':
            lines[-1] += u','
        lines[0] = open_text + lines[0]
        lines[-1] += close_text
        return lines

    def createDictItem(self, key_callback=None, value_callback=None):
        """
        Create the lines of a "key: value" dictionary item. The key and the
        value are created by the callbacks if they are set, by
        createArgument() otherwise.
        """
        if key_callback:
            key = key_callback()
        else:
            key = self.createArgument()
        if value_callback:
            value = value_callback()
        else:
            value = self.createArgument()
        # Join the last line of the key with the first line of the value
        key[-1] += u": " + value[0]
        key.extend(value[1:])
        return key

    def genList(self):
        """Argument: list of simple values."""
        return self._genList(u'[', u']', u'[]')

    def genTuple(self):
        """Argument: tuple of simple values."""
        return self._genList(u'(', u')', u'tuple()')

    def genDict(self):
        """Argument: dictionary of simple keys and values."""
        return self._genList(u'{', u'}', u'{}', True)

def parseArguments(arguments, defaults):
    """
    Parse a comma separated list of arguments of a prototype and generate
    the argument names.

    Spaces, newlines and square brackets around an argument are removed,
    empty arguments are skipped. For a "name=value" argument, the default
    value is stored in the defaults dictionary (name => value as text).
    Spaces around the "=" sign are ignored: "mode = 'r'" is the argument
    "mode" with the default value "'r'".

    >>> defaults = {}
    >>> list(parseArguments("obj, file, protocol = 0", defaults))
    ['obj', 'file', 'protocol']
    >>> defaults
    {'protocol': '0'}
    """
    for arg in arguments.split(","):
        arg = arg.strip(" \t\r\n[]")
        if not arg:
            continue
        if "=" in arg:
            arg, value = arg.split("=", 1)
            # Remove the spaces written around the "=" sign, otherwise the
            # name ("mode ") doesn't match the real argument name
            arg = arg.strip()
            defaults[arg] = value.strip()
        yield arg

def parsePrototype(doc):
    r"""
    Parse the function prototype written at the beginning of a
    documentation string.

    Return None if doc is empty, is not a string, doesn't start with a
    prototype, or if the arguments are written "...". Otherwise, return
    the (arguments, vararg, varkw, defaults) tuple:
     - arguments: names of the mandatory arguments
     - vararg: last mandatory argument if it starts with "*", or None
     - varkw: names of the optional arguments (written in square brackets
       or having a default value)
     - defaults: dictionary name => default value (text)

    >>> parsePrototype("test([x])")
    ((), None, ('x',), {})
    >>> parsePrototype('dump(obj, file, protocol=0)')
    (('obj', 'file'), None, ('protocol',), {'protocol': '0'})
    >>> parsePrototype('setitimer(which, seconds[, interval])')
    (('which', 'seconds'), None, ('interval',), {})
    >>> parsePrototype("decompress(string[, wbits[, bufsize]])")
    (('string',), None, ('wbits', 'bufsize'), {})
    >>> parsePrototype("decompress(string,\nwbits)")
    (('string', 'wbits'), None, (), {})
    >>> parsePrototype("get_referents(*objs)")
    ((), '*objs', (), {})
    >>> parsePrototype("nothing")
    """
    if not doc or not isinstance(doc, (str, unicode)):
        return None
    found = PROTOTYPE_REGEX.match(doc.strip())
    if not found:
        return None
    text = found.group(1)
    if text == '...':
        return None

    # Everything after the first "[" is optional
    defaults = {}
    mandatory_text, bracket, optional_text = text.partition("[")
    arguments = tuple(parseArguments(mandatory_text, defaults))
    varkw = tuple(parseArguments(optional_text, defaults))

    # Trailing arguments with a default value are optional too
    first_default = len(arguments)
    while first_default and arguments[first_default - 1] in defaults:
        first_default -= 1
    varkw = arguments[first_default:] + varkw
    arguments = arguments[:first_default]

    vararg = None
    if arguments and arguments[-1].startswith("*"):
        vararg = arguments[-1]
        arguments = arguments[:-1]
    return (arguments, vararg, varkw, defaults)

def parseDocumentation(doc, max_var_arg):
    """
    Compute the number of arguments of a function from the prototype of
    its documentation string.

    Arguments:
     - doc: documentation string
     - max_var_arg: maximum number of arguments for variable argument,
       eg. test(*args).

    Return the (minimum, maximum) tuple, or None if the documentation has
    no prototype.
    """
    parsed = parsePrototype(doc)
    if not parsed:
        return None

    mandatory, star_argument, optional, defaults = parsed
    lowest = len(mandatory)
    highest = lowest + len(optional)
    if star_argument:
        highest = highest + max_var_arg
    return lowest, highest

# Run the fuzzer when the file is executed as a program
if __name__ == "__main__":
    application = Fuzzer()
    application.main()

