#!/usr/bin/env python
# Walk the libsmi-generated tree of MIB symbols and build a
# pysnmp.smi-compliant module
from exceptions import StandardError
from string import split, join, replace, find, atol, atoi
from types import StringType, DictType
import sys, time

# Converter version; shown in the usage text and in the header comment
# of the generated module
version = '0.0.9-alpha'
# When true, "if mibBuilder.loadTexts:" statements carrying the MIB texts
# are generated; cleared by the --no-text command-line option
genTextLoader = 1

# Raised when the text read from stdin fails to execute as a MIB module
class Error(StandardError): pass

# Command line: the only recognized argument is --no-text. Anything else
# prints the usage text on stderr and terminates with exit status -1.
if len(sys.argv) > 1:
    if sys.argv[1] != '--no-text':
        sys.stderr.write(
            'SNMP MIB to pysnmp objects converter, version %s.\n'
            'Usage:\n'
            '    %s [--no-text]\n'
            'Takes:\n'
            '    smidump -f python <mib-text-file.txt>\n'
            'program output on stdin, generates python code on stdout.\n'
            'The smidump tool is available at http://www.ibr.cs.tu-bs.de/projects/libsmi/\n'
            'The --no-text option disables code generation for MIB text comments.\n'
            % (version, sys.argv[0])
            )
        sys.exit(-1)
    genTextLoader = 0

# Slurp the whole smidump output. A read() without a size argument already
# returns everything up to EOF, so no accumulation loop is needed.
inputText = sys.stdin.read()

if not inputText:
    sys.stderr.write('Empty input\n')
    sys.exit(-1)

codeObj = compile(inputText, '<string>', 'exec')
    
g = {}

try:
    eval(codeObj, g)
except StandardError, why:
    raise Error(
        'MIB module load error: %s' % why
        )

mib = g['MIB']

dstModName = mib['moduleName']

out = sys.stdout

# SMI macro name -> names of the symbols standing for it in generated code
__symsTable = {
    'MODULE-IDENTITY': ('ModuleIdentity',),
    'OBJECT-TYPE': ('MibScalar', 'MibTable', 'MibTableRow', 'MibTableColumn'),
    'NOTIFICATION-TYPE': ('NotificationType',),
    'TEXTUAL-CONVENTION': ('TextualConvention',),
    'MODULE-COMPLIANCE': ('ModuleCompliance',),
    'OBJECT-GROUP': ('ObjectGroup',),
    'NOTIFICATION-GROUP': ('NotificationGroup',),
    'AGENT-CAPABILITIES': ('AgentCapabilities',),
    'OBJECT-IDENTITY': ('ObjectIdentity',),
    'TRAP-TYPE': ('NotificationType',)  # smidump always uses NotificationType
    }

def symTrans(symbol):
    """Translate an SMI macro name into the tuple of names replacing it.

    Returns the matching entry of the translation table, or a one-element
    tuple holding `symbol` itself when the table has no entry for it.
    """
    return __symsTable.get(symbol, (symbol,))

def transOpers(symbol):
    """Return `symbol` with every '-' replaced by '_'.

    Used wherever a MIB symbol name is written out as a Python identifier.
    """
    return symbol.replace('-', '_')

def addLabelForSymbol(symbol):
    """Return a '.setLabel("<symbol>")' call suffix for hyphenated names.

    Names without a '-' need no label; an empty string is returned for them.
    """
    if '-' in symbol:
        return '.setLabel(\"%s\")' % symbol
    return ''

def __oidToTuple(x):
    """Render a dotted OID string such as '1.3.6' as the source text of a
    tuple of integers, e.g. '(1, 3, 6)'."""
    return str(tuple(map(int, x.split('.'))))

def __reprIntVal(value):
    """Return the source-code representation of an integer given as text.

    `value` is parsed as a base-10 integer. Values strictly between
    -2147483647 and 2147483647 are rendered as they parse; anything else
    is rendered as a long. Text that does not parse as a number is
    returned through repr() unchanged.
    """
    try:
        intVal = int(value, 10)
    except ValueError:
        try:
            # second attempt with the long parser, as a last resort
            intVal = long(value, 10)
        except ValueError:
            return repr(value)
    if -2147483647 < intVal < 2147483647:
        return repr(intVal)
    return repr(long(intVal))
    
def __hexToOctets(hexValue):
    """Decode a '0x'-prefixed hex string into a string made of the
    characters whose codes are given by each successive pair of digits."""
    return ''.join(
        [chr(int(hexValue[i:i+2], 16)) for i in range(2, len(hexValue), 2)]
        )

def __genDefVal(baseType, symDef):
    """Render symDef['default'] as the source text of a Python value.

    baseType selects the rendering:
      OctetString        -- repr() of the string, '0x...' values decoded first
      Integer, Integer32 -- integer literal
      ObjectIdentifier   -- tuple of integers
      IpAddress          -- quoted dotted string decoded from a '0x...' value
      Bits               -- tuple of quoted bit names
      Enumeration        -- number of the named value when the name is known
                            in symDef['syntax']['type'], else the quoted text
    For any other base type a warning is written to stderr and the kind of
    value is guessed: hex strings are decoded, numeric text is passed through
    as is and everything else goes through repr().

    Returns a string.
    """
    default = symDef['default']
    if baseType == 'OctetString':
        if default[:2] == '0x':
            return repr(__hexToOctets(default))
        return repr(default)
    elif baseType in ('Integer', 'Integer32'):
        return __reprIntVal(default)
    elif baseType == 'ObjectIdentifier':
        return __oidToTuple(default)
    elif baseType == 'IpAddress':
        # the value is expected in '0x'-prefixed hex form, one octet per
        # pair of digits
        octets = [
            str(int(default[i:i+2], 16)) for i in range(2, len(default), 2)
            ]
        return '\"%s\"' % '.'.join(octets)
    elif baseType == 'Bits':
        # Turn the punctuation into whitespace before splitting, so that
        # "(a, b)" and "(a,b)" both yield the bit names a and b
        bitNames = default
        for punctuation in ',()':
            bitNames = bitNames.replace(punctuation, ' ')
        return '(' + ''.join(
            ['\"%s\",' % bit for bit in bitNames.split()]
            ) + ')'
    elif baseType == 'Enumeration':
        enumDef = symDef['syntax']['type']
        if default in enumDef:
            return '%s' % enumDef[default]['number']
        return '\"%s\"' % default
    else:
        sys.stderr.write('WARNING: guessing DEFVAL type \'%s\' for %s\n' %
                         (default, baseType))
        if default[:2] == '0x':
            return repr(__hexToOctets(default))
        try:
            long(default, 10)
        except ValueError:
            return repr(default)
        # numeric text is emitted verbatim
        return default

# Ugly kludge against a smidump bug: its output does not distinguish
# size constraints from value range constraints. For the base types
# listed here a range is rendered as a size constraint.
__kludgyStringTypes = {
    'OctetString': 1
    }

# Type names that get replaced by another type name before code is
# generated for them
__buggySmiTypes = {
    'NetworkAddress': 'IpAddress' # this is up to smidump, but it does not care
    }

def __genTypeDef((symName, symDef), classMode=0):
    r = ''
    if classMode:
        typeDef = symDef
        identFiller = '    ';  identValue = 0
    else:
        typeDef = symDef['syntax']['type']
    if typeDef.has_key('name'):
        baseType = typeDef['name']
    if typeDef.has_key('basetype'):
        baseType = typeDef['basetype']
    if typeDef.has_key('parent module'):
        parentType = typeDef['parent module']['type']
    else:
        parentType = baseType
    # Ugly hack to overcome smidump bug in smiv1->smiv2 convertion
    if __buggySmiTypes.has_key(baseType):
        baseType = __buggySmiTypes[baseType]
    if __buggySmiTypes.has_key(parentType):
        parentType = __buggySmiTypes[parentType]
    if classMode:
        r = r + 'class %s(' % symName
        if typeDef.has_key('format'):
            r = r + '%s, ' % symTrans('TEXTUAL-CONVENTION')[0]
        identValue = identValue + 1
    if baseType in ('Enumeration', 'Bits'):
        if baseType == 'Enumeration':
            parentType = 'Integer'
        if classMode:
            r = r + '%s):\n' % parentType
            r = r + identFiller*identValue            
        else:
            r = r + ', %s()' % parentType
        if baseType == 'Enumeration':
            if classMode:
                r = r + 'subtypeSpec = %s.subtypeSpec+' % parentType
            else:
                r = r + '.subtype(subtypeSpec='
            # Python has certain limit on the number of func params
            if len(typeDef) > 127:
                r = r + 'constraint.ConstraintsUnion('
            r = r + 'constraint.SingleValueConstraint('
            cnt = 1
            for e, v in typeDef.items():
                if type(v) == DictType and v.has_key('nodetype') \
                       and v['nodetype'] == 'namednumber':
                    r = r + '%s,' % v['number']
                    if cnt % 127 == 0:
                        r = r + '), constraint.SingleValueConstraint('
                    cnt = cnt + 1
            if len(typeDef) > 127:
                r = r + ')'
            if classMode:
                r = r + ')\n'
                r = r + identFiller*identValue
            else:
                r = r + '))'
        if classMode:
            r = r + 'namedValues = namedval.NamedValues('
        else:
            r = r + '.subtype(namedValues=namedval.NamedValues('
        typedesc = typeDef.items()
        typedesc.sort(lambda x,y: cmp(x[1],y[1]))
        cnt = 1
        for e, v in typedesc:
            if type(v) == DictType and v.has_key('nodetype') \
              and v['nodetype'] == 'namednumber':
                r = r + '(\"%s\", %s), ' % (e, v['number'])
                if cnt % 127 == 0:
                    r = r + ') + namedval.NamedValues('
                cnt = cnt + 1                
        if classMode:
            r = r + ')\n'
            r = r + identFiller*identValue            
        else:
            r = r + '))'
    else:
        if classMode:
            r = r + '%s):\n' % parentType
            r = r + identFiller*identValue
        else:
            r = r + ', %s()' % parentType
        if classMode:
            if typeDef.has_key('format'):
                r = r + 'displayHint = \"%s\"\n' % typeDef['format']
                r = r + identFiller*identValue
        if __kludgyStringTypes.has_key(baseType):
            __subtypeSpec = 'constraint.ValueSizeConstraint'
        else:
            __subtypeSpec = 'constraint.ValueRangeConstraint'
 
        single_range = 0
        if typeDef.has_key('range'):
           single_range = 1
        # ATTENTION: libsmi-0.4.5 does not support "ranges". Use libsmi
        # SVN version or an older patch from Randy Couey:
        # http://www.glas.net/~ilya/download/tools/pysnmp/libsmi-0.4.5-perl_python_range_union.patch
        if typeDef.has_key('ranges'):
            # if more than one size/range is given, then we need to 
            # create a ContraintsUnion to hold all of them.
            if len(typeDef['ranges']) > 1:
                single_range = 0
                if classMode:
                    r = r + 'subtypeSpec = %s.subtypeSpec+constraint.ConstraintsUnion(' % parentType
                else:
                    r = r + '.subtype(subtypeSpec=constraint.ConstraintsUnion('
                for range in typeDef['ranges']:
                    r = r + '%s(%s,%s),' % (__subtypeSpec, __reprIntVal(range['min']), __reprIntVal(range['max']))
                if classMode:
                    r = r + ')\n'
                    r = r + identFiller*identValue
                else:
                    r = r + '))'
        # only one size/range contraint was given
        if single_range:
            if classMode:
                r = r + 'subtypeSpec = %s.subtypeSpec+%s(%s,%s)\n' % (parentType, __subtypeSpec, __reprIntVal(typeDef['range']['min']), __reprIntVal(typeDef['range']['max']))
                r = r + identFiller*identValue
                if __kludgyStringTypes.has_key(baseType) and \
                       typeDef['range']['min'] == typeDef['range']['max']:
                    r = r + 'fixedLength = %s\n' % typeDef['range']['min']
                    r = r + identFiller*identValue
            else:
                r = r + '.subtype(subtypeSpec=%s(%s, %s))' % (__subtypeSpec, __reprIntVal(typeDef['range']['min']), __reprIntVal(typeDef['range']['max']))
                if __kludgyStringTypes.has_key(baseType) and \
                       typeDef['range']['min'] == typeDef['range']['max']:
                    r = r + '.setFixedLength(%s)' % typeDef['range']['min']

    if symDef.has_key('default') and not symDef.has_key('basetype'):
        defVal = __genDefVal(baseType, symDef)
        if classMode:
            if defVal is not None:
                r = r + 'defaultValue = %s\n' % defVal
        else:
            if defVal is not None:
                r = r + '.clone(%s)' % defVal
    if classMode:
        r = r + 'pass\n\n'
        
    return r

# Preamble of the generated module: provenance comments, the pyasn1 import
# and the heading of the imports section
for headerLine in (
    "# PySNMP SMI module. Autogenerated from smidump -f python %s\n" % dstModName,
    "# by libsmi2pysnmp-%s at %s,\n" % (version, time.asctime(time.localtime())),
    "# Python version %s\n\n" % str(sys.version_info),
    '# Imported just in case new ASN.1 types would be created\n',
    'from pyasn1.type import constraint, namedval\n\n',
    '# Imports\n\n',
    ):
    out.write(headerLine)

# Collect the names to import, grouped by source module: a fixed set of
# always-imported symbols followed by whatever the MIB itself imports
imports = {}
for imp in (
    { 'module': 'ASN1', 'name': 'Integer' },
    { 'module': 'ASN1', 'name': 'OctetString' },
    { 'module': 'ASN1', 'name': 'ObjectIdentifier' },
    { 'module': 'SNMPv2-SMI', 'name': 'Bits' }, # XXX
    { 'module': 'SNMPv2-SMI', 'name': 'Integer32' }, # libsmi bug
    { 'module': 'SNMPv2-SMI', 'name': 'TimeTicks' }, # bug in some IETF MIB
    { 'module': 'SNMPv2-SMI', 'name': 'MibIdentifier' }, # OBJECT IDENTIFIER
    ) + mib.get('imports', ()):
    importedNames = imports.setdefault(imp['module'], [])
    if not imp['module']:
        sys.stderr.write('WARNING: empty MIB module name seen in smidump output at %s\n' % dstModName)
    importedNames.append(imp['name'])
for importedNames in imports.values():
    importedNames.sort()

# One importSymbols() statement per source module, in module name order
modNames = imports.keys(); modNames.sort()
for modName in modNames:
    translated = []
    for symName in imports[modName]:
        translated.extend(symTrans(symName))
    targets = ''.join(['%s, ' % transOpers(s) for s in translated])
    sources = ''.join([', \"%s\"' % s for s in translated])
    out.write(
        '( %s) = mibBuilder.importSymbols(\"%s\"%s)\n' % (
        targets, modName, sources
        ))

# Type definitions of the MIB, ordered by name; an empty tuple when the
# MIB defines none
typedefs = ()
if 'typedefs' in mib:
    typedefs = mib['typedefs'].items()
    typedefs.sort()

if typedefs:
    out.write('\n# Types\n\n')
    # every type definition becomes a class statement
    for typeEntry in typedefs:
        out.write(__genTypeDef(typeEntry, 1))

# Name of the node carrying the module identity, or '' when the MIB
# description does not mention one
moduleIdentityNode = mib.get(dstModName, {}).get('identity node', '')

def __oid2num(o):
    """Split a dotted OID string into a list of numbers."""
    return map(long, o.split('.'))

def __cmpNodesByOid(x, y):
    """Order two (name, definition) pairs by the numeric value of the OIDs
    found in their definitions."""
    return cmp(__oid2num(x[1].get('oid')), __oid2num(y[1].get('oid')))

# MIB nodes in OID order; an empty tuple when the MIB has none
if 'nodes' in mib:
    nodes = mib['nodes'].items()
    nodes.sort(__cmpNodesByOid)
else:
    nodes = ()
    
# Emit one assignment per MIB node, then the statements that hook
# augmenting table rows up to the rows they extend
if nodes:
    out.write('\n# Objects\n\n')
    # OID of each table row emitted so far -> its 'create' attribute
    # ('false' when absent), consulted when the row's columns are emitted
    row_create = {}
    for symName, symDef in nodes:
        out.write('%s = ' % transOpers(symName))
        nodeType = symDef['nodetype']
        if nodeType == 'node':
            if symName == moduleIdentityNode:
                out.write('ModuleIdentity(%s)' % __oidToTuple(symDef['oid']))
                m = mib.get(dstModName, {})
                if 'revisions' in m:
                    out.write('.setRevisions((')
                    for r in m["revisions"]:
                        out.write('\"%s\",' % r["date"])
                    out.write('))')
                out.write(addLabelForSymbol(symName))
                if genTextLoader:
                    # module-level texts, newlines kept as \n escapes
                    for textKey, setter in (
                        ('organization', 'setOrganization'),
                        ('contact', 'setContactInfo'),
                        ('description', 'setDescription'),
                        ):
                        if textKey in m:
                            out.write('\nif mibBuilder.loadTexts: %s.%s("%s")' % (transOpers(symName), setter, m[textKey].replace('\n', '\\n')))
                out.write('\n')
                continue
            elif 'description' in symDef:
                out.write('ObjectIdentity(%s)' % __oidToTuple(symDef['oid']))
            else:
                out.write('MibIdentifier(%s)' % __oidToTuple(symDef['oid']))
        elif nodeType == 'scalar':
            out.write('MibScalar(%s' % __oidToTuple(symDef['oid']))
            out.write('%s)' % __genTypeDef((symName, symDef)))
            out.write('.setMaxAccess(\"%s\")' % symDef['access'])
            if 'units' in symDef:
                out.write('.setUnits(\"%s\")' % symDef['units'])
        elif nodeType == 'table':
            out.write('MibTable(%s)' % __oidToTuple(symDef['oid']))
        elif nodeType == 'row':
            # determine if row creation is permitted, and store
            # status for later inspection by column nodes.
            row_create[symDef['oid']] = symDef.get('create', 'false')
            out.write('MibTableRow(%s)' % __oidToTuple(symDef['oid']))
            if symDef['linkage'] and isinstance(symDef['linkage'][0], StringType):
                out.write('.setIndexNames(')
                cnt = 0
                for idx in symDef['linkage']:
                    if cnt:
                        out.write(', ')
                    else:
                        cnt = cnt + 1
                    # smidump does not distinguish outer/inner indices:
                    # an index found among the imported names belongs to
                    # the module it was imported from, any other index
                    # to this module
                    modName = dstModName
                    for _modName, _symNames in imports.items():
                        if idx in _symNames:
                            modName = _modName
                            break
                    # only the last index may be an implied one
                    if idx == symDef['linkage'][-1] and \
                       symDef.get("implied") == "true":
                        impliedFlag = 1
                    else:
                        impliedFlag = 0

                    out.write('(%d, \"%s\", \"%s\")' % (
                        impliedFlag, modName, idx
                        ))
                out.write(')')
        elif nodeType == 'column':
            out.write('MibTableColumn(%s' % __oidToTuple(symDef['oid']))
            out.write('%s)' % __genTypeDef((symName, symDef)))
            # smidump does not tag columns as read-create.
            # we must check the parent row object to determine if column is
            # createable
            parent = '.'.join(symDef['oid'].split('.')[:-1])
            if row_create[parent] == 'true' and symDef['access']=='readwrite':
                out.write('.setMaxAccess(\"%s\")' % 'readcreate')
            else:
                out.write('.setMaxAccess(\"%s\")' % symDef['access'])

        out.write(addLabelForSymbol(symName))

        if genTextLoader:
            if 'description' in symDef:
                out.write('\nif mibBuilder.loadTexts: %s.setDescription("%s")' % (transOpers(symName), symDef['description'].replace('\n', '\\n')))

        out.write('\n')

    out.write('\n# Augmentions\n')
    # Walk the nodes in their sorted order to keep the output stable
    for symName, symDef in nodes:
        if symDef['nodetype'] != 'row':
            continue
        linkage = symDef['linkage']
        # rows extending another row carry dictionaries in their linkage
        if not (linkage and isinstance(linkage[0], DictType)):
            continue
        for idx in linkage:
            for relatedModule, indices in idx.items():
                relatedNode = indices['relatedNode']
                # Python-side names go through transOpers(), matching the
                # names the assignments above have created
                if relatedModule != dstModName:
                    out.write(
                        '%s, = mibBuilder.importSymbols(\"%s\", \"%s\")\n'%(
                        transOpers(relatedNode), relatedModule, relatedNode
                        ))
                out.write(
                    '%s.registerAugmentions((\"%s\", \"%s\"))\n' % (
                    transOpers(relatedNode), dstModName, symName
                    ))
                out.write(
                    'apply(%s.setIndexNames, %s.getIndexNames())\n' % (
                    transOpers(symName), transOpers(relatedNode)
                    ))

# Notifications in OID order; an empty tuple when the MIB has none
if 'notifications' in mib:
    notifications = mib['notifications'].items()
    __oid2num = lambda o: map(long, o.split('.'))
    # sort the notifications themselves (this used to re-sort `nodes`,
    # which also failed when the MIB had no nodes at all)
    notifications.sort(lambda x,y,f=__oid2num: cmp(
        f(x[1].get('oid')), f(y[1].get('oid'))
        ))
else:
    notifications = ()

# One assignment per notification, listing the objects it carries as
# (module, name) pairs
if notifications:
    out.write('\n# Notifications\n\n')
    for symName, symDef in notifications:
        statement = '%s = ' % transOpers(symName)
        if symDef['nodetype'] == 'notification':
            objectList = ''.join([
                '(\"%s\", \"%s\"), ' % (objDef['module'], objName)
                for objName, objDef in symDef['objects'].items()
                ])
            statement = statement + 'NotificationType(%s).setObjects(%s)%s' % (
                __oidToTuple(symDef['oid']), objectList,
                addLabelForSymbol(symName)
                )
        out.write(statement + '\n')

# Groups in OID order; an empty tuple when the MIB has none
if 'groups' in mib:
    groups = mib['groups'].items()
    __oid2num = lambda o: map(long, o.split('.'))
    # sort the groups themselves (this used to re-sort `nodes`, which
    # also failed when the MIB had no nodes at all)
    groups.sort(lambda x,y,f=__oid2num: cmp(
        f(x[1].get('oid')), f(y[1].get('oid'))
        ))
else:
    groups = ()

# One assignment per group, listing its members as (module, name) pairs
if groups:
    out.write('\n# Groups\n\n')
    for symName, symDef in groups:
        statement = '%s = ' % transOpers(symName)
        if symDef['nodetype'] == 'group':
            members = symDef['members']
            # a group whose first member is a notification of this MIB is
            # rendered as a notification group, any other as an object group
            if mib.has_key('notifications') and \
                   mib['notifications'].has_key(members.keys()[0]):
                groupClass = 'NotificationGroup'
            else:
                groupClass = 'ObjectGroup'
            memberList = ''.join([
                '(\"%s\", \"%s\"), ' % (objDef['module'], objName)
                for objName, objDef in members.items()
                ])
            statement = statement + '%s(%s).setObjects(%s)%s' % (
                groupClass, __oidToTuple(symDef['oid']), memberList,
                addLabelForSymbol(symName)
                )
        out.write(statement + '\n')

out.write('\n# Exports\n\n')

# The module identity node, when there is one, is exported under the
# PYSNMP_MODULE_ID name
if moduleIdentityNode:
    out.write(
        '# Module identity\n'
        'mibBuilder.exportSymbols(\"%s\", PYSNMP_MODULE_ID=%s)\n\n' % (
        dstModName, transOpers(moduleIdentityNode)
        ))
    
def __writeExports(title, pyNames, trailer):
    """Write a '# <title>' comment followed by exportSymbols() statements
    exporting every name in pyNames under itself, then write trailer.

    A fresh exportSymbols() call is started before every 127th name, so
    that no single call grows beyond that many keyword arguments.
    """
    out.write('# %s\n' % title)
    out.write('mibBuilder.exportSymbols(\"%s\"' % dstModName)
    position = 1
    for pyName in pyNames:
        if position % 127 == 0:
            out.write(')\n')
            out.write('mibBuilder.exportSymbols(\"%s\"' % dstModName)
        out.write(', %s=%s' % (pyName, pyName))
        position = position + 1
    out.write(trailer)

if typedefs:
    # type names are exported as they are
    __writeExports(
        'Types', [symName for symName, symObj in typedefs], ')\n\n'
        )

if nodes:
    __writeExports(
        'Objects',
        [transOpers(symTrans(symName)[0]) for symName, symObj in nodes],
        ')\n\n'
        )

if notifications:
    __writeExports(
        'Notifications',
        [transOpers(symName) for symName, symObj in notifications],
        ')\n\n'
        )

if groups:
    # last section of the file: a single trailing newline only
    __writeExports(
        'Groups',
        [transOpers(symName) for symName, symObj in groups],
        ')\n'
        )

# TODO: implement API version checking
