#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
gnng - Fetches network device interfaces and displays them in a table view.

Fetches interface information from routing and firewall devices. This includes
network and IP information along with the inbound and outbound filters that
may be applied to the interface. Works on Juniper, NetScreen, Foundry, and Cisco
devices.
"""

__author__ = 'Jathan McCollum, Mark Ellzey Thomas'
__maintainer__ = 'Jathan McCollum'
__email__ = 'jmccollum@salesforce.com'
__copyright__ = 'Copyright 2003-2013, AOL Inc.; 2013 Salesforce.com'
__version__ = '1.3.1'

from collections import namedtuple
import csv
import cStringIO
import math
import operator
from optparse import OptionParser
import os
import re
from sqlite3 import dbapi2 as sqlite
import sys
from twisted.python import log
from trigger.cmds import NetACLInfo
from trigger.netdevices import NetDevices

# Put this here until the default changes to not load ACLs from redis.
from trigger.conf import settings
settings.WITH_ACLS = False

#log.startLogging(sys.stdout, setStdout=False)

# Module-level settings
# Debug output is switched on by giving the DEBUG environment variable any
# non-empty value; the code below only tests it for truthiness.
DEBUG = os.environ.get('DEBUG')
# Default ceiling on simultaneous device connections (default of -j/--jobs).
MAX_CONNS = 10
# Column headings of the per-device interface table, in column order.
ROW_LABELS = (
    'Interface',
    'Addresses',
    'Subnets',
    'ACLs IN',
    'ACLs OUT',
    'Description',
)

# Result containers
# RowData: what build_output() returns.
RowData = namedtuple('RowData', ['all_rows', 'subnet_table'])
# DottyData: what output_dotty() returns when at least one link was found.
DottyData = namedtuple('DottyData', ['graph', 'links'])

# Functions
def parse_args(argv):
    """
    Parse the command line and return the ``(opts, args)`` pair.

    :param argv:
        Full argument vector *including* the program name (e.g. ``sys.argv``),
        so ``args[0]`` of the result is the program name.

    Prints the usage text and exits with status 1 when nothing but the
    program name was supplied and ``--all`` was not requested.
    """
    description = '''GetNets-NG

Fetches interface information from routing and firewall devices. This includes
network and IP information along with the inbound and outbound filters that
may be applied to the interface. Skips un-numbered and disabled interfaces by
default. Works on Cisco, Foundry, Juniper, and NetScreen devices.'''
    parser = OptionParser(usage='%prog [options] [routers]',
                          description=description)

    # (option strings, add_option keyword arguments), in help-output order.
    option_specs = (
        (('-a', '--all'),
         {'action': 'store_true', 'help': 'run on all devices'}),
        (('-c', '--csv'),
         {'action': 'store_true',
          'help': 'output the data in CSV format instead.'}),
        (('-d', '--include-disabled'),
         {'action': 'store_true', 'help': 'include disabled interfaces.'}),
        (('-u', '--include-unnumbered'),
         {'action': 'store_true', 'help': 'include un-numbered interfaces.'}),
        (('-j', '--jobs'),
         {'type': 'int', 'default': MAX_CONNS,
          'help': 'maximum simultaneous connections to maintain.'}),
        # Stored as ``opts.nonprod``: True (the default) means production
        # devices only; passing -N flips it to False.
        (('-N', '--nonprod'),
         {'action': 'store_false', 'default': True,
          'help': 'Look for production and non-production devices.'}),
        (('-s', '--sqldb'),
         {'type': 'str', 'help': 'output to SQLite DB'}),
        (('', '--dotty'),
         {'action': 'store_true',
          'help': 'output connect-to information in dotty format.'}),
        (('', '--filter-on-group'),
         {'action': 'append',
          'help': 'Run on all devices owned by this group'}),
        (('', '--filter-on-type'),
         {'action': 'append',
          'help': 'Run on all devices with this device type'}),
    )
    for flags, attrs in option_specs:
        parser.add_option(*flags, **attrs)

    opts, args = parser.parse_args(argv)

    # Only the program name is left over: no routers were named.
    no_routers_given = len(args) == 1
    if no_routers_given and not opts.all:
        parser.print_help()
        sys.exit(1)

    return opts, args

def fetch_router_list(args):
    """Turns a list of device names into device objects, skipping unsupported,
    invalid, or filtered devices.

    :param args:
        Device names to look up. When empty or ``None``, every device known
        to NetDevices is considered instead, limited to the supported vendors
        and excluding devices whose short name contains ``oob``.

    :returns:
        The selected device objects, sorted in reverse order.

    Reads the module-level ``opts`` (``nonprod`` here, the filter options via
    pass_filters()).
    """
    nd = NetDevices(production_only=opts.nonprod)
    ret = []
    # Owning teams whose devices are left out of an "all devices" run.
    blocked_groups = []
    if args:
        for arg in args:
            # Look each name up exactly once. NetDevices.find() raises
            # KeyError for an unknown name; skip it with a warning instead of
            # aborting the whole run, as the docstring promises.
            try:
                device = nd.find(arg)
            except KeyError:
                sys.stderr.write(
                    'WARNING: device not found, skipping: %s\n' % (arg,))
                continue
            if pass_filters(device):
                ret.append(device)

    else:
        supported_vendors = ('cisco', 'foundry', 'juniper')
        for entry in nd.itervalues():
            if entry.owningTeam in blocked_groups:
                continue
            if entry.vendor not in supported_vendors:
                continue
            if 'oob' in entry.shortName:
                continue
            if pass_filters(entry):
                ret.append(entry)

    return sorted(ret, reverse=True)

def pass_filters(device):
    """Return True when ``device`` satisfies every filter given on the command
    line (``--filter-on-group`` / ``--filter-on-type``), False otherwise.

    A filter that was not supplied is not applied. Used by fetch_router_list();
    reads the module-level ``opts``.
    """
    # (option attribute holding the allowed values, device attribute to test)
    checks = (
        ('filter_on_group', 'owningTeam'),
        ('filter_on_type', 'deviceType'),
    )
    for option_name, device_attr in checks:
        allowed = getattr(opts, option_name)
        if allowed and getattr(device, device_attr) not in allowed:
            return False

    return True

def indent(rows, hasHeader=False, headerChar='-', delim=' | ', justify='left',
           separateRows=False, prefix='', postfix='', wrapfunc=lambda x:x, wraplast=True):
    """Indents a table by column.
       - rows: A sequence of sequences of items, one sequence per row.
       - hasHeader: True if the first row consists of the columns' names.
       - headerChar: Character to be used for the row separator line
         (if hasHeader==True or separateRows==True).
       - delim: The column delimiter.
       - justify: Determines how are data justified in their column.
         Valid values are 'left','right' and 'center'.
       - separateRows: True if rows are to be separated by a line
         of 'headerChar's.
       - prefix: A string prepended to each printed row.
       - postfix: A string appended to each printed row.
       - wrapfunc: A function f(text) for wrapping text; each element in
         the table is first wrapped by this function.
       - wraplast: True if the last column is wrapped like the others;
         False leaves the last item of every row untouched.

       Returns the formatted table as one string in which every printed line
       ends with a newline. An input without any rows yields ''.
       Raises KeyError if justify is not one of the valid values."""
    # select the appropriate justify method
    justifier = {'center': str.center, 'right': str.rjust,
                 'left': str.ljust}[justify.lower()]

    # closure for breaking logical rows to physical, using wrapfunc
    def rowWrapper(row):
        row = list(row)
        if not row:
            return []
        if wraplast:
            cells = [wrapfunc(item).split('\n') for item in row]
        else:
            cells = [wrapfunc(item).split('\n') for item in row[:-1]]
            cells.append([row[-1]])
        # Pad every cell to the height of the tallest one. Doing this by
        # index (instead of transposing) also works for one-column rows.
        height = max([len(cell) for cell in cells])
        return [[(cell[i] or '') if i < len(cell) else '' for cell in cells]
                for i in range(height)]

    # break each logical row into one or more physical ones
    logicalRows = [rowWrapper(row) for row in rows]
    physicalRows = [row for logical in logicalRows for row in logical]
    if not physicalRows:
        # Nothing to print; width computation below needs at least one row.
        return ''

    # get the maximum of each column by the string length of its items
    numColumns = max([len(row) for row in physicalRows])
    maxWidths = [max([len(str(row[i])) for row in physicalRows if i < len(row)])
                 for i in range(numColumns)]
    rowSeparator = headerChar * (len(prefix) + len(postfix) + sum(maxWidths) +
                                 len(delim) * (len(maxWidths) - 1))

    lines = []
    if separateRows:
        lines.append(rowSeparator)
    for logical in logicalRows:
        for row in logical:
            cells = [justifier(str(item), width)
                     for (item, width) in zip(row, maxWidths)]
            lines.append(prefix + delim.join(cells) + postfix)
        if separateRows or hasHeader:
            lines.append(rowSeparator)
            # the header separator is only wanted after the first row
            hasHeader = False
    return '\n'.join(lines) + '\n'

def wrap_onspace(text, width):
    """
    Word-wrap ``text`` at spaces, keeping existing line breaks and most
    spaces. Existing line breaks are expected to be POSIX newlines.

    A space is turned into a newline whenever the length of the current
    output line plus the length of the next word (up to that word's first
    embedded newline) is at least ``width``. Words longer than ``width`` are
    not split.

    Based on the recipe written by Mike Brown
    http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/148061
    """
    words = text.split(' ')
    wrapped = words[0]
    for word in words[1:]:
        # Length of the output line currently being built.
        current_line_len = len(wrapped) - (wrapped.rfind('\n') + 1)
        # Only the part of the word before any embedded newline lands on
        # the current line.
        next_piece_len = len(word.split('\n', 1)[0])
        if current_line_len + next_piece_len >= width:
            joiner = '\n'
        else:
            joiner = ' '
        wrapped = wrapped + joiner + word
    return wrapped

def wrap_onspace_strict(text, width):
    """Like wrap_onspace(), but the width constraint is enforced: every run
       of ``width`` or more non-space characters is first broken up with
       wrap_always(), then the result is wrapped on spaces."""
    overlong_word = re.compile(r'\S{' + str(width) + r',}')

    def hard_wrap(match):
        # Break one over-long word into pieces of at most ``width`` chars.
        return wrap_always(match.group(), width)

    pre_split = overlong_word.sub(hard_wrap, text)
    return wrap_onspace(pre_split, width)

def wrap_always(text, width):
    """Wrap ``text`` on exactly ``width`` characters, ignoring word
       boundaries: the text is cut into consecutive ``width``-sized pieces
       which are joined with newlines."""
    # Ceiling division: how many pieces are needed to cover the whole text.
    piece_count = -(-len(text) // width)
    pieces = []
    for index in range(piece_count):
        start = index * width
        pieces.append(text[start:start + width])
    return '\n'.join(pieces)

def _sql_text(value):
    """Render ``value`` as ``'%s'`` formatting would, as text sqlite can bind."""
    text = '%s' % (value,)
    if isinstance(text, bytes):
        # Python 2 byte string: decode so non-ASCII bytes are accepted as a
        # bound parameter.
        text = text.decode('utf-8', 'replace')
    return text


def write_sqldb(sqlfile, dev, rows):
    """
    Write device fields to sqlite db

    :param sqlfile:
        Path of the SQLite database file; created if it does not exist

    :param dev:
        The device the rows belong to; stored in its string form

    :param rows:
        Iterable of 6-item rows: interface name, addresses, subnets,
        inbound ACLs, outbound ACLs, description

    :raises ValueError:
        If a row does not have exactly six items

    The ``dev_nets`` table and its ``auto_date`` trigger (which stamps
    ``insert_date`` on every inserted row) are created when missing.
    """
    # Build the parameters first so a malformed row fails before the
    # database is touched.
    params = []
    for row in rows:
        iface, addrs, snets, inacl, outacl, desc = row
        params.append(tuple(
            _sql_text(field)
            for field in (dev, iface, addrs, snets, inacl, outacl, desc)))

    connection = sqlite.connect(sqlfile)
    try:
        cursor = connection.cursor()

        # IF NOT EXISTS covers a brand new file as well as an existing file
        # that does not contain the table yet.
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS dev_nets (
            id            INTEGER PRIMARY KEY,
            insert_date   DATE,
            device_name   VARCHAR(128),
            iface_name    VARCHAR(32),
            iface_addrs   VARCHAR(1024),
            iface_subnets VARCHAR(1024),
            iface_inacl   VARCHAR(32),
            iface_outacl  VARCHAR(32),
            iface_descr   VARCHAR(1024)
        );
        ''')
        cursor.execute('''
        CREATE TRIGGER IF NOT EXISTS auto_date AFTER INSERT ON dev_nets
        BEGIN
            UPDATE dev_nets SET insert_date = DATETIME('NOW')
                WHERE rowid = new.rowid;
        END;
        ''')

        # Bound parameters instead of string interpolation: values that
        # contain quotes (common in interface descriptions) can neither break
        # nor alter the statement.
        cursor.executemany('''
            INSERT INTO dev_nets (
                device_name,
                iface_name,
                iface_addrs,
                iface_subnets,
                iface_inacl,
                iface_outacl,
                iface_descr )
            VALUES (?, ?, ?, ?, ?, ?, ?);''', params)

        connection.commit()
        cursor.close()
    finally:
        # Release the database handle even when a statement fails.
        connection.close()

def get_interface_data(devices, production_only=True, max_conns=MAX_CONNS):
    """
    Collect interface information for ``devices`` using ``NetACLInfo``.

    :param devices:
        The devices to query; handed to ``NetACLInfo`` unchanged

    :param production_only:
        Whether to include only devices marked as "PRODUCTION"

    :param max_conns:
        Max number of simultaneous connections

    :returns:
        The ``config`` attribute of the ``NetACLInfo`` instance after it has
        run (build_output() consumes it as a mapping of device to per-interface
        data)

    Disabled interfaces are skipped unless the module-level
    ``opts.include_disabled`` is set.
    """
    collector = NetACLInfo(
        devices=devices,
        production_only=production_only,
        max_conns=max_conns,
        # "include disabled" on the command line is the inverse of "skip".
        skip_disabled=not opts.include_disabled,
    )
    collector.run()

    if DEBUG:
        print('NetACLInfo done!')

    return collector.config

class DeviceRows(list):
    """
    The table rows collected for one device.

    Behaves exactly like a plain list of rows, and additionally remembers the
    device the rows came from in ``device`` so that handle_output() can label
    CSV lines, database records and table headers.
    """
    def __init__(self, device, rows=()):
        super(DeviceRows, self).__init__(rows)
        self.device = device


def build_output(main_data, opts, labels=None):
    """
    Iterate the interface data, then build and return row data.

    :param main_data:
        Dictionary of interface data, mapping each device to a dictionary of
        its interfaces

    :param opts:
        OptionParser object; ``include_disabled``, ``include_unnumbered`` and
        ``csv`` are read

    :param labels:
        Row labels for table output. Accepted for backward compatibility;
        not used while building the rows.

    :returns:
        ``RowData(all_rows, subnet_table)``: ``all_rows`` holds one
        ``DeviceRows`` list per device, ``subnet_table`` maps each subnet to
        a list of ``(device, interface, addresses)`` tuples.
    """
    subnet_table = {}
    all_rows = []
    for dev, data in main_data.items():
        # Tag the rows with their device; handle_output() needs it.
        rows = DeviceRows(dev)

        for interface in sorted(data):
            iface = data[interface]

            # Maybe skip down interfaces
            if 'addr' not in iface and not opts.include_disabled:
                continue

            if DEBUG:
                print('>>>  %s' % (interface,))

            # Any of these may be missing, e.g. for a disabled interface that
            # is shown because of --include-disabled.
            addrs = iface.get('addr') or []
            subns = iface.get('subnets') or []
            acls_in = iface.get('acl_in') or []
            acls_out = iface.get('acl_out') or []
            description = iface.get('description') or []
            desctext = ' '.join(description).replace(' : ', ':')

            # Maybe skip un-numbered interfaces
            if not addrs and not opts.include_unnumbered:
                continue

            # Trim the description (CSV output keeps it in full)
            if not opts.csv:
                desctext = desctext[0:50]

            addresses = [a.strNormal() for a in addrs]
            subnets = []
            for s in subns:
                subnets.append(s.strNormal())
                subnet_table.setdefault(s, []).append((dev, interface, addrs))

            if DEBUG:
                print('\t in: %s' % (acls_in,))
                print('\t ou: %s' % (acls_out,))
            rows.append([interface, ' '.join(addresses), ' '.join(subnets),
                         '\n'.join(acls_in), '\n'.join(acls_out), desctext])

        all_rows.append(rows)

    return RowData(all_rows, subnet_table)

def handle_output(all_rows, opts):
    """
    Do stuff with the output data.

    Depending on ``opts`` each device's rows are written as CSV to stdout
    (``csv``), ignored (``dotty``), stored with write_sqldb() (``sqldb``), or
    printed as a table; the first option that is set wins, in that order.
    For database and table output a ``DEVICE: <name>`` line is printed ahead
    of each device.

    :param all_rows:
        A list of lists of row data, as returned by build_output(). The
        device is taken from each inner list's ``device`` attribute; a plain
        list without it is handled with an empty device name.

    :param opts:
        OptionParser object
    """
    writer = csv.writer(sys.stdout) if opts.csv else None

    for rows in all_rows:
        dev = getattr(rows, 'device', '')

        if opts.csv:
            for row in rows:
                writer.writerow([dev] + list(row))
        elif opts.dotty:
            continue
        elif opts.sqldb:
            print('DEVICE: %s' % (dev,))
            write_sqldb(opts.sqldb, dev, rows)
        else:
            # Print the header right above its own table.
            print('DEVICE: %s' % (dev,))
            print_table(rows)

def print_table(rows, labels=None, has_header=True, separate_rows=False,
                wrapfunc=None, delim=' | ', wrap_last=False):
    """
    Print the interface table for a device

    :param rows:
        List of table rows

    :param labels:
        Header row; defaults to ``ROW_LABELS``

    :param wrapfunc:
        Cell wrapping function; defaults to wrapping on spaces at 20
        characters

    ``has_header``, ``separate_rows``, ``delim`` and ``wrap_last`` are handed
    to indent() as ``hasHeader``, ``separateRows``, ``delim`` and
    ``wraplast``.
    """
    header = ROW_LABELS if labels is None else labels

    if wrapfunc is None:
        cell_wrapper = lambda x: wrap_onspace(x, 20)
    else:
        cell_wrapper = wrapfunc

    table = indent(
        [header] + rows,
        hasHeader=has_header,
        separateRows=separate_rows,
        wrapfunc=cell_wrapper,
        delim=delim,
        wraplast=wrap_last,
    )
    print(table)

def output_dotty(subnet_table, display=True):
    """
    Output and return dotty config for a ``subnet_table``

    :param subnet_table:
        Dict mapping subnets to devices and interfaces: each value is a list
        of tuples whose first item is the device

    :param display:
        Whether to print the generated graph

    :returns:
        ``DottyData(graph, links)``, or ``None`` (after printing a notice)
        when no subnet is shared by at least two entries

    Only the first two entries of each subnet are linked. The devices must
    provide a ``shortName`` attribute, which is used as the node name.
    """
    links = {}

    for devs in subnet_table.values():
        if len(devs) < 2:
            continue

        router1 = devs[0][0]
        router2 = devs[1][0]

        # Record the pair once, under whichever end is already a key.
        if router1 in links:
            if router2 not in links[router1]:
                links[router1].append(router2)
        elif router2 in links:
            if router1 not in links[router2]:
                links[router2].append(router1)
        else:
            links[router1] = [router2]

    if not links:
        print('No valid links for dotty generation.')
        return None

    header = '''graph network {
    overlap=scale; center=true; orientation=land;
    resolution=0.10; rankdir=LR; ratio=fill;
    node [fontname=Courier, fontsize=10]'''

    edges = []
    for leaf, subleaves in links.items():
        for subleaf in subleaves:
            edges.append('"%s"--"%s"\n' % (leaf.shortName, subleaf.shortName))

    # The newline after the header keeps the first edge on a line of its own
    # instead of glued to the "node [...]" statement.
    graph = header + '\n' + ''.join(edges) + '\n}'

    if display:
        print(graph)

    return DottyData(graph, links)

def main():
    """
    Run the tool: select the devices, fetch their interface data and emit it
    in the format chosen on the command line.

    Reads the module-level ``opts`` and ``args`` produced by parse_args().
    Exits with status 1 when no device is selected.
    """
    if opts.all:
        routers = fetch_router_list(None)
    else:
        # args[0] is the program name
        routers = fetch_router_list(args[1:])

    if not routers:
        sys.exit(1)

    # Honor -j/--jobs: without max_conns the option was parsed but the
    # connection limit always stayed at the built-in default.
    main_data = get_interface_data(devices=routers,
                                   production_only=opts.nonprod,
                                   max_conns=opts.jobs)
    all_rows, subnet_table = build_output(main_data, opts)
    handle_output(all_rows, opts)

    if opts.dotty:
        output_dotty(subnet_table)

if __name__ == '__main__':
    # ``opts`` and ``args`` are bound at module level on purpose: main(),
    # fetch_router_list(), pass_filters() and get_interface_data() read them
    # as globals. No ``global`` statement is needed at module scope.
    opts, args = parse_args(sys.argv)
    main()
