#!/usr/bin/env python
#
# create-nodes.py: create undercloud nodes for Openstack
#
# Copyright (C) 2014 Ravello Systems, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import sys
import json
import logging
import errno
import six
import copy
import socket

from argparse import ArgumentParser
from getpass import getpass
from fnmatch import fnmatch
from collections import namedtuple
from six.moves import reduce

from ravello_sdk import RavelloClient


# CPUID overrides attached to a VM template as its 'cpuIds' when --svm
# (nested virtualization) is requested.  Each entry pairs a CPUID leaf index
# (hex string) with a 128-bit value (hex string).  The leaf "0" value embeds
# the ASCII bytes of "AuthenticAMD".
# NOTE(review): the exact meaning of the remaining values is not documented
# here -- presumably they advertise AMD SVM to the guest; confirm against the
# Ravello documentation before changing them.
magic_svm_cpuids = [
    {"index": "0", "value": "0000000768747541444d416369746e65"},
    {"index": "1", "value": "000006fb00000800c0802000078bfbfd"},
    {"index": "8000000a", "value": "00000001000000400000000000000088"},
    {"index": "80000000", "value": "8000000a000000000000000000000000"},
    {"index": "80000001", "value": "00000000000000000000001520100800"},
]

# Base MAC address; per-interface and per-VM offsets are added to it when
# static network interfaces are personalized.
mac_base = '2c:c2:60:00:00:00'


def create_parser():
    """Build and return the ``ArgumentParser`` for this script.

    The valid values of ``--flavor`` are taken from the names in the
    module-level ``flavors`` list.  The only positional argument is
    ``count``, the number of nodes to create.
    """
    parser = ArgumentParser(description='Create a number of nodes in Ravello')
    add = parser.add_argument
    flavor_names = [flavor.name for flavor in flavors]

    # Debugging and API credentials.
    add('-d', '--debug', help='enable debugging mode', action='store_true')
    add('-u', '--username', help='Ravello API username')
    add('-p', '--password', help='Ravello API password ($RAVELLO_PASSWORD)')

    # Node hardware.
    add('-f', '--flavor', default='empty', choices=flavor_names,
        help='install this OS flavor')
    add('-c', '--cpus', help='number of CPUs', type=int, default=1)
    add('-m', '--memory', help='node memory (MB)', type=int, default=2048)
    add('-s', '--disk', help='root disk size (GB)', type=int, default=50)
    add('--nic-model', default='virtio', choices=('virtio', 'e1000'),
        help='NIC device model')
    add('--disk-model', default='virtio', choices=('virtio', 'ide'),
        help='disk device model')

    # Node personalization and networking.
    add('--keypair', help='inject this keypair')
    add('--network', action='append',
        help='define a network ("dhcp" or network/bits)')
    add('--service', action='append', metavar='PORT',
        help='expose IP service on PORT on head node')
    add('--public-ip', action='store_true',
        help='use public IP for head node')
    add('--all-as-head', action='store_true',
        help='all nodes are identical to head node')
    add('--svm', action='store_true', help='enable nested virtualization')
    add('--default-user', help='change default user')

    # Output file and target application.
    add('-o', '--data-file', default='testenv.json',
        help='test environment datafile')
    add('-a', '--application', help='add nodes to existing application')

    # Publishing.
    add('--cloud', help='the cloud to deploy to',
        choices=('AMAZON', 'GOOGLE', 'HPCLOUD'))
    add('--region', help='the region to deploy to')
    add('--optimization', help='the optimization to use when publishing',
        choices=('COST', 'PERFORMANCE'))
    add('--no-start', help='do not start the nodes',
        action='store_true', default=False)
    add('--autostop', help='Autostop in minutes', type=int)

    add('count', help='the number of nodes to create', type=int)
    return parser


def setup_logger(args):
    """Configure the root logger to write to standard output.

    The level is ``DEBUG`` when ``args.debug`` is set and ``INFO`` otherwise.
    Every call adds one more stream handler to the root logger.
    """
    level = logging.INFO
    if args.debug:
        level = logging.DEBUG
    stream = logging.StreamHandler(sys.stdout)
    stream.setFormatter(logging.Formatter('%(levelname)s %(message)s'))
    root = logging.getLogger()
    root.setLevel(level)
    root.addHandler(stream)


def check_arguments(args):
    """Validate the parsed command-line arguments and fill in defaults.

    *args* is modified in place:

    * ``username`` / ``password`` fall back to ``$RAVELLO_USERNAME`` /
      ``$RAVELLO_PASSWORD``; a missing password is prompted for when
      standard input is a terminal.
    * ``optimization`` defaults to ``'PERFORMANCE'`` when ``--cloud`` was
      given and to ``'COST'`` otherwise.

    Raises a ``ValueError`` if there is an error.
    """
    # Credentials: command line first, then the environment.
    if args.username is None:
        args.username = os.environ.get('RAVELLO_USERNAME')
    if args.username is None:
        raise ValueError('specify either --username or set $RAVELLO_USERNAME')
    if args.password is None:
        args.password = os.environ.get('RAVELLO_PASSWORD')
    if args.password is None:
        # Prompting only makes sense when a user is at the terminal.
        if not sys.stdin.isatty():
            raise ValueError('specify either --password, set $RAVELLO_PASSWORD,'
                             ' or run interactively')
        args.password = getpass('Please enter Ravello API password: ')

    # Hardware limits.
    if not 1 <= args.cpus <= 4:
        raise ValueError('--cpus must be between 1 and 4')
    if not 512 <= args.memory <= 16384:
        raise ValueError('--memory must be between 512 and 16384 MB')
    if not 4 <= args.disk <= 450:
        raise ValueError('--disk must be between 4 and 450 GB')

    # Publishing options: COST optimization excludes an explicit placement.
    if args.optimization == 'COST':
        if args.cloud:
            raise ValueError('cannot specify --cloud with COST optimization')
        if args.region:
            raise ValueError('cannot specify --region with COST optimization')
    if not args.optimization:
        args.optimization = 'PERFORMANCE' if args.cloud else 'COST'

    # Options that depend on what the selected flavor supports.
    flavor = get_flavor(args)
    if args.keypair and not flavor.cloudinit:
        raise ValueError("flavor {0} does not support --keypair".format(flavor.name))
    if args.default_user and not flavor.cloudinit:
        raise ValueError("flavor {0} does not support --default-user".format(flavor.name))
    if args.network and not flavor.get_cc_network:
        raise ValueError("flavor {0} does not support --network".format(flavor.name))


def get_flavor(args):
    """Look up the entry of ``flavors`` whose name equals ``args.flavor``.

    Raises ``AssertionError`` when no flavor has that name.
    """
    matching = [candidate for candidate in flavors
                if candidate.name == args.flavor]
    if not matching:
        raise AssertionError
    return matching[0]


def mac_aton(s):
    """Convert a Mac address to an integer.

    *s* must consist of exactly six colon-separated hexadecimal groups, each
    in the range 00-ff.  The first group becomes the most significant byte
    of the result.

    Raises ``ValueError`` if *s* is not a well-formed Mac address.
    """
    try:
        octets = [int(group, 16) for group in s.split(':')]
    except ValueError:
        raise ValueError('illegal Mac: {0}'.format(s))
    # Reject extra/missing groups and groups that do not fit in one byte;
    # either would silently corrupt the packed value.
    if len(octets) != 6 or not all(0 <= octet <= 0xff for octet in octets):
        raise ValueError('illegal Mac: {0}'.format(s))
    mac = 0
    for octet in octets:
        mac = (mac << 8) | octet
    return mac

def mac_ntoa(i):
    """Format the low 48 bits of integer *i* as a colon-separated Mac address.

    The most significant byte comes first; every byte is rendered as two
    lowercase hexadecimal digits.
    """
    groups = []
    for shift in (40, 32, 24, 16, 8, 0):
        groups.append('%02x' % ((i >> shift) & 0xff))
    return ':'.join(groups)


def inet_aton(s):
    """Convert a dotted-quad to an int.

    *s* must consist of exactly four dot-separated decimal numbers, each in
    the range 0-255.  The first number becomes the most significant byte of
    the result.

    Raises ``ValueError`` if *s* is not a well-formed dotted quad.
    """
    try:
        octets = [int(part) for part in s.split('.')]
    except ValueError:
        raise ValueError('illegal IP: {0}'.format(s))
    # Reject extra/missing octets and octets that do not fit in one byte;
    # either would silently produce a wrong address.
    if len(octets) != 4 or not all(0 <= octet <= 255 for octet in octets):
        raise ValueError('illegal IP: {0}'.format(s))
    addr = 0
    for octet in octets:
        addr = (addr << 8) | octet
    return addr

def inet_ntoa(i):
    """Format the low 32 bits of integer *i* as a dotted-quad string.

    The most significant byte comes first.
    """
    octets = []
    for shift in (24, 16, 8, 0):
        octets.append(str((i >> shift) & 0xff))
    return '.'.join(octets)


def parse_cidr(network):
    """Parse a address/bits CIDR type network address.

    Returns a ``(network, netmask)`` tuple of integers, where *network* is
    the address with its host bits cleared.

    Raises ``ValueError`` if *network* is not of the form ``address/bits``,
    if the prefix length is not an integer between 0 and 32, or if the
    address part is rejected by ``inet_aton()``.
    """
    try:
        addr, bits = network.split('/')
        bits = int(bits)
        # A prefix length outside 0..32 would give a negative shift or an
        # all-zero mask below, so reject it up front.
        if not 0 <= bits <= 32:
            raise ValueError('prefix length out of range')
        addr = inet_aton(addr)
    except ValueError:
        raise ValueError('illegal network: {0}'.format(network))
    netmask = (0xffffffff << (32 - bits)) & 0xffffffff
    network = addr & netmask
    return network, netmask


def get_interfaces(args):
    """Turn the ``--network`` values into a list of interface descriptions.

    Without any ``--network`` a single DHCP interface is described.  The
    interfaces are named ``eth0``, ``eth1``, ... in the order given.  A
    value of ``dhcp`` gives a DHCP interface; any other value is parsed as
    ``network/bits`` and gives a static interface with ``network`` and
    ``netmask`` entries.  The first interface, when static, also gets a
    ``gateway`` and ``dns`` entry, both set to the network address plus one.
    """
    specs = args.network or ['dhcp']
    result = []
    for index, spec in enumerate(specs):
        entry = {'name': 'eth{0}'.format(index)}
        if spec != 'dhcp':
            entry['bootproto'] = 'static'
            net, mask = parse_cidr(spec)
            entry['network'] = inet_ntoa(net)
            entry['netmask'] = inet_ntoa(mask)
            if index == 0:
                first_host = inet_ntoa(net + 1)
                entry['gateway'] = first_host
                entry['dns'] = first_host
        else:
            entry['bootproto'] = 'dhcp'
        result.append(entry)
    return result


def getservbyport(port):
    """Return the service name for *port*.

    Works like ``socket.getservbyport()``, except that a ``socket.error``
    from the lookup yields the placeholder name ``port-<port>`` instead of
    propagating.
    """
    try:
        return socket.getservbyport(port)
    except socket.error:
        return 'port-{0}'.format(port)


def getservbyname(name):
    """Return the port number of the service called *name*.

    Works like ``socket.getservbyname()``, except that a ``socket.error``
    from the lookup is turned into a ``ValueError`` with a descriptive
    message.
    """
    try:
        return socket.getservbyname(name)
    except socket.error:
        raise ValueError('unknown service: {0}'.format(name))


def get_services(args):
    """Turn the ``--service`` values into a list of service descriptions.

    Each value is either a port number (all digits), for which a name is
    looked up, or a service name, for which the port is looked up.  The
    result is a list of ``{'name': ..., 'portRange': ...}`` dictionaries
    with an integer ``portRange``; it is empty without any ``--service``.
    """
    result = []
    for spec in (args.service or []):
        if not spec.isdigit():
            name = spec
            port = getservbyname(name)
        else:
            port = int(spec)
            name = getservbyport(port)
        result.append({'name': name, 'portRange': int(port)})
    return result


def connect_api(args):
    """Open a Ravello API session and return the client.

    A new ``RavelloClient`` is connected and then logged in with
    ``args.username`` and ``args.password``.
    """
    api = RavelloClient()
    api.connect()
    api.login(args.username, args.password)
    return api


def new_name(prefix, existing):
    """Pick the first ``<prefix>-<n>`` (n = 0, 1, ...) not in *existing*.

    The chosen name is added to the set *existing* before it is returned,
    so repeated calls with the same set give distinct names.
    """
    # With len(existing) + 1 candidates at least one of them must be free.
    last = len(existing)
    index = 0
    name = '{0}-{1}'.format(prefix, index)
    while name in existing and index < last:
        index += 1
        name = '{0}-{1}'.format(prefix, index)
    existing.add(name)
    return name


def create_or_get_application(client, args):
    """Return the application the new nodes are added to, as a dictionary.

    With ``--application`` the first application returned for that name is
    reloaded and used; a ``ValueError`` is raised if none is returned.
    Otherwise a new application is created under an unused ``Cloud-<n>``
    name.  The name of the application is printed.
    """
    if not args.application:
        taken = set(entry['name'] for entry in client.get_applications())
        spec = {'name': new_name('Cloud', taken),
                'description': 'Created by create-nodes.py'}
        app = client.create_application(spec)
    else:
        found = client.get_applications({'name': args.application})
        if not found:
            raise ValueError('application {0} not found'.format(args.application))
        app = client.reload(found[0])
    print('Using application {0}.'.format(app['name']))
    return app


def get_disk_image(client, flavor, args):
    """Find the disk image that belongs to *flavor*.

    Returns ``None`` when the flavor has no image name.  Otherwise the
    images listed by the API are searched in order and the first one whose
    name matches the flavor's image name (an ``fnmatch`` pattern) is
    returned.  Raises ``ValueError`` if no image matches.
    """
    pattern = flavor.imagename
    if not pattern:
        return None
    for candidate in client.request('GET', '/diskImages'):
        if fnmatch(candidate['name'], pattern):
            return candidate
    message = 'flavor {0}: need a disk image with name {1!r}'
    raise ValueError(message.format(flavor.name, pattern))


def get_keypair(client, args):
    """Look up the key pair named by ``--keypair``.

    Returns ``None`` when no key pair was requested, otherwise the first
    key pair the API returns for that name.  Raises ``ValueError`` if the
    API returns none.
    """
    wanted = args.keypair
    if not wanted:
        return None
    found = client.get_keypairs({'name': wanted})
    if len(found) > 0:
        return found[0]
    raise ValueError('keypair not found: {0}'.format(wanted))


def convert_size(size, newunit):
    """Convert a disk/memory size structure to a specific unit.

    *size* is a dictionary with a numeric ``'value'`` and a ``'unit'``;
    *newunit* is the unit to convert to.  Supported units are ``'KB'``,
    ``'MB'`` and ``'GB'`` (multiples of 1024).  The result is rounded down
    to a whole number of *newunit*.

    Raises ``ValueError`` for an unsupported unit rather than silently
    treating it as gigabytes.
    """
    kb_per_unit = {'KB': 1, 'MB': 1024, 'GB': 1024 * 1024}
    value, unit = size['value'], size['unit']
    for name in (unit, newunit):
        if name not in kb_per_unit:
            raise ValueError('unsupported size unit: {0!r}'.format(name))
    sizekb = value * kb_per_unit[unit]
    return sizekb // kb_per_unit[newunit]


def create_vm_template(image, keypair, flavor, interfaces, args):
    """Create a VM template.

    The template is a dictionary describing one VM; it is built from:

    * *image*: disk image dictionary or ``None``.  When given, its ``id`` is
      set as ``baseDiskImageId`` on the last drive, which is the CD-ROM for
      ``CDROM`` flavors and the root disk otherwise.
    * *keypair*: key pair dictionary or ``None``.
    * *flavor*: supplies ``cloudinit`` and ``imagetype``.
    * *interfaces*: interface descriptions; one network connection is
      created per interface, in order.
    * *args*: supplies ``svm``, ``disk_model``, ``disk``, ``memory``,
      ``cpus`` and ``nic_model``.

    Mac addresses and static IP addresses are not set here.  An external
    ``ssh`` service on TCP port 22 is always supplied.
    """
    template = {'baseVmId': 0,
                'description': 'VM created by ravello-create-nodes'}
    if args.svm:
        template['cpuIds'] = magic_svm_cpuids
    if flavor.cloudinit:
        template['supportsCloudInit'] = True
    if keypair:
        template['keypairId'] = keypair['id']

    # Drives: always a bootable root disk, plus a CD-ROM for CDROM flavors.
    drives = template['hardDrives'] = []
    drives.append({'index': 1,
                   'type': 'DISK',
                   'name': 'root',
                   'boot': True,
                   'controller': args.disk_model,
                   'size': {'value': args.disk, 'unit': 'GB'}})
    if flavor.imagetype == 'CDROM':
        drives.append({'index': 2,
                       'type': 'CDROM',
                       'name': 'cdrom',
                       'controller': 'ide'})
        # A plain list; a trailing comma here would wrap it in a tuple.
        template['bootOrder'] = ['CDROM', 'DISK']
    if image:
        drives[-1]['baseDiskImageId'] = image['id']

    template['memorySize'] = {'value': args.memory, 'unit': 'MB'}
    template['numCpus'] = args.cpus
    template['os'] = 'linux_manuel'  # sic.

    # One network connection per interface.
    connections = template['networkConnections'] = []
    for ix, iface in enumerate(interfaces):
        nic = {'index': ix,
               'deviceType': args.nic_model,
               'useAutomaticMac': False}
        ipcfg = {'hasPublicIp': False}
        if iface['bootproto'] == 'dhcp':
            ipcfg['autoIpConfig'] = {}
        else:
            static = ipcfg['staticIpConfig'] = {'mask': iface['netmask']}
            for key in ('gateway', 'dns'):
                if key in iface:
                    static[key] = iface[key]
        connections.append({'name': iface['name'],
                            'device': nic,
                            'ipConfig': ipcfg})

    template['suppliedServices'] = [{'name': 'ssh',
                                     'portRange': '22',
                                     'protocol': 'TCP',
                                     'external': True}]
    return template


def get_cc_user(username):
    """Build the cloud-config section that names the default user *username*.

    The result has the shape
    ``{'system_info': {'default_user': {'name': username}}}``.
    """
    default_user = {'name': username}
    system_info = {'default_user': default_user}
    return {'system_info': system_info}


def get_cc_network_redhat(n, interfaces):
    """Return a cloud-config section for creating the static networking
    definitions for VM number *n* (Red Hat version).

    One ``/etc/sysconfig/network-scripts/ifcfg-<name>`` file is written per
    entry of *interfaces*, and a reboot one minute later is scheduled via
    ``runcmd``.  *n* is accepted for interface compatibility but not used.

    ``HWADDR``, ``IPADDR``, ``NETMASK``, ``GATEWAY`` and ``DNS1`` are only
    emitted when the interface description has the corresponding entry, so
    DHCP interfaces (which carry no ``'mac'``) are supported.  The settings
    are written in a fixed order.
    """
    files = []
    for iface in interfaces:
        # Ordered (key, value) pairs keep the generated file deterministic.
        settings = [('NAME', iface['name']),
                    ('BOOTPROTO', iface['bootproto'])]
        if 'mac' in iface:
            settings.append(('HWADDR', iface['mac']))
        settings.append(('ONBOOT', 'yes'))
        for key in ('ipaddr', 'netmask', 'gateway'):
            if key in iface:
                settings.append((key.upper(), iface[key]))
        if 'dns' in iface:
            settings.append(('DNS1', iface['dns']))
        ifcfg = {'path': '/etc/sysconfig/network-scripts/ifcfg-{0[name]}'.format(iface),
                 'permissions': 493}  # 0755
        ifcfg['content'] = ''.join(['{0}={1}\n'.format(*item) for item in settings])
        files.append(ifcfg)
    cloudcfg = {'write_files': files}
    cloudcfg['runcmd'] = ['shutdown -r +1']
    return cloudcfg


def get_cc_network_debian(n, interfaces):
    """Build the cloud-config section that sets up networking on a
    Debian/Ubuntu VM.

    A single ``/etc/network/interfaces`` file is written that lists every
    entry of *interfaces* on one ``auto`` line, followed by one ``iface``
    stanza per interface.  Address, netmask, gateway and DNS lines are only
    emitted when the interface description has the corresponding entry.
    A reboot one minute later is scheduled via ``runcmd``.  *n* is accepted
    for interface compatibility but not used.
    """
    # Optional per-interface lines, in the order they are written.
    optional = (('ipaddr', '  address {0[ipaddr]}\n'),
                ('netmask', '  netmask {0[netmask]}\n'),
                ('gateway', '  gateway {0[gateway]}\n'),
                ('dns', '  dns-nameservers {0[dns]}\n'))
    names = ' '.join([iface['name'] for iface in interfaces])
    stanzas = ['auto {0}\n'.format(names)]
    for iface in interfaces:
        stanzas.append('iface {0[name]} inet {0[bootproto]}\n'.format(iface))
        for key, line in optional:
            if key in iface:
                stanzas.append(line.format(iface))
        if 'dns' in iface:
            stanzas.append('  dns-domain localdomain\n')
    etcnetif = {'path': '/etc/network/interfaces',
                'permissions': 493}  # 0755
    etcnetif['content'] = ''.join(stanzas)
    cloudcfg = {'write_files': [etcnetif]}
    cloudcfg['runcmd'] = ['shutdown -r +1']
    return cloudcfg


def merge_cc(a, b):
    """Merge cloud-config *b* into *a*.

    *a* is updated in place and nothing is returned.  For every key of *b*:

    * a key missing from *a* is copied over (the value is shared, not
      copied);
    * lists are concatenated, *b*'s items after *a*'s;
    * dictionaries are merged recursively;
    * any other value in *a* is replaced by the value from *b*.

    Raises ``ValueError`` when a list in *a* meets a non-list in *b*, or a
    dictionary in *a* meets a non-dictionary in *b*.
    """
    for key in b:
        if key not in a:
            a[key] = b[key]
        elif isinstance(a[key], list):
            if not isinstance(b[key], list):
                raise ValueError('cannot merge list with non-list')
            a[key] += b[key]
        elif isinstance(a[key], dict):
            # Both sides must be dictionaries to recurse into them.
            if not isinstance(b[key], dict):
                raise ValueError('cannot merge dict with non-dict')
            merge_cc(a[key], b[key])
        else:
            a[key] = b[key]


# VM flavors

# Description of an installable OS flavor:
#   name           - value accepted by --flavor
#   imagename      - fnmatch pattern for the disk image name, or None when
#                    the flavor needs no image
#   imagetype      - 'CDROM' adds a CD-ROM drive that carries the image;
#                    otherwise the image goes on the root disk
#   cloudinit      - whether cloud-init options (--keypair, --default-user)
#                    are supported
#   get_cc_network - callable building the cloud-config networking section;
#                    a false value means --network is rejected, a true
#                    non-callable value means --network is accepted but no
#                    cloud-config is generated for it
Flavor = namedtuple(
    'Flavor',
    ('name', 'imagename', 'imagetype', 'cloudinit', 'get_cc_network'))

flavors = [
    Flavor(name='fedora',
           imagename='Fedora-x86_64-*-sda',
           imagetype='DISK',
           cloudinit=True,
           get_cc_network=get_cc_network_redhat),
    Flavor(name='pxe',
           imagename='ipxe',
           imagetype='CDROM',
           cloudinit=False,
           get_cc_network=True),
    Flavor(name='ubuntu',
           imagename='trusty-server-cloudimg-amd64-disk1',
           imagetype='DISK',
           cloudinit=True,
           get_cc_network=get_cc_network_debian),
    Flavor(name='empty',
           imagename=None,
           imagetype='DISK',
           cloudinit=False,
           get_cc_network=None),
]


def add_vms_to_application(app, template, flavor, interfaces, services, args):
    """Add ``args.count`` personalized copies of *template* to *app*.

    The VMs are appended to ``app['design']['vms']`` (created when missing)
    and *app* is returned.  Each VM gets an unused ``VM-<n>`` name.  The
    first VM -- or every VM with ``--all-as-head`` -- is a head node and
    additionally supplies the external TCP *services*.

    For every static interface a Mac address (derived from ``mac_base``, the
    interface index and the VM index) and an IP address (network address
    plus 10 plus the VM index) are assigned.  Both are also stored on the
    shared entries of *interfaces*, where the flavor's cloud-config network
    function reads them.  The address of the first interface, when static,
    is set on all supplied services, and on head nodes its public IP flag
    follows ``--public-ip``.

    When any cloud-config was produced it is attached to the VM as user
    data.
    """
    vm_list = app.setdefault('design', {}).setdefault('vms', [])
    taken_names = set(entry['name'] for entry in vm_list)
    for vm_index in range(args.count):
        is_head = vm_index == 0 or args.all_as_head
        node = copy.deepcopy(template)
        node['name'] = new_name('VM', taken_names)
        config = {}
        if args.default_user:
            merge_cc(config, get_cc_user(args.default_user))
        if is_head:
            # Services are added before the interface loop so that they also
            # receive the IP of the first interface below.
            node['suppliedServices'].extend(
                [{'name': svc['name'],
                  'portRange': svc['portRange'],
                  'protocol': 'TCP',
                  'external': True} for svc in services])
        for nic_index, iface in enumerate(interfaces):
            if iface['bootproto'] != 'dhcp':
                iface['mac'] = mac_ntoa(
                    mac_aton(mac_base) + (nic_index << 16) + vm_index)
                iface['ipaddr'] = inet_ntoa(
                    inet_aton(iface['network']) + 10 + vm_index)
                connection = node['networkConnections'][nic_index]
                connection['device']['mac'] = iface['mac']
                connection['ipConfig']['staticIpConfig']['ip'] = iface['ipaddr']
                if nic_index == 0:
                    for supplied in node['suppliedServices']:
                        supplied['ip'] = iface['ipaddr']
                    if is_head:
                        connection['ipConfig']['hasPublicIp'] = args.public_ip
        if args.network and callable(flavor.get_cc_network):
            merge_cc(config, flavor.get_cc_network(vm_index, interfaces))
        if config:
            pieces = ['#cloud-config\n', json.dumps(config, indent=2), '\n']
            node['configurationManagement'] = {'userData': ''.join(pieces)}
        vm_list.append(node)
    return app


def create_publish_request(args):
    """Build the request body used to publish a new application.

    ``startAllVms`` is the inverse of ``--no-start``.  The preferred cloud,
    preferred region and optimization level are only included when the
    corresponding option has a value.
    """
    request = {'startAllVms': not args.no_start}
    placement = (('preferredCloud', args.cloud),
                 ('preferredRegion', args.region))
    for key, value in placement:
        if value:
            request[key] = value
    if args.optimization:
        request['optimizationLevel'] = '{0}_OPTIMIZED'.format(args.optimization)
    return request


def update_and_publish(client, app, args):
    """Save *app* and publish it, then return the reloaded application.

    The steps are, in order: update the application; when ``--autostop``
    was given, set its expiration (minutes converted to seconds, or -1 for
    a value of 0); publish it.  An application that has a true
    ``published`` entry gets its updates published (without starting the
    draft VMs when ``--no-start`` is set); any other application is
    published with a request built by ``create_publish_request()``.
    """
    client.update_application(app)
    app_id = app['id']
    if args.autostop is not None:
        seconds = 60 * args.autostop if args.autostop else -1
        client.request('POST',
                       '/applications/{0}/setExpiration'.format(app_id),
                       {'expirationFromNowSeconds': seconds})
    if app.get('published'):
        target = '/applications/{0}/publishUpdates'.format(app_id)
        if args.no_start:
            target += '?startAllDraftVms=false'
        payload = None
    else:
        target = '/applications/{0}/publish'.format(app_id)
        payload = create_publish_request(args)
    client.request('POST', target, payload)
    reloaded = client.reload(app)
    print('Added {0} VMs to application.'.format(args.count))
    return reloaded


def update_datafile(app, args):
    """Record the application's nodes in the environment description file.

    The JSON file ``args.data_file`` is read when it exists; a missing file
    gives an empty description, any other I/O error is re-raised.  The
    ``nodes`` entry is then rebuilt from every VM in
    ``app['design']['vms']`` and the description is written back, leaving
    its other entries in place.

    For each VM the first drive of type ``DISK`` and the first network
    connection are used.
    """
    desc = {}
    try:
        with open(args.data_file, 'r') as fin:
            desc = json.load(fin)
    except (IOError, OSError) as e:
        # Only a missing file is acceptable; it is simply created below.
        if e.errno != errno.ENOENT:
            raise
    node_list = desc['nodes'] = []
    for vm in app['design']['vms']:
        disks = [candidate for candidate in vm['hardDrives']
                 if candidate['type'] == 'DISK']
        root = disks[0]
        first_nic = vm['networkConnections'][0]['device']
        node_list.append({
            'cpu': vm['numCpus'],
            'memory': convert_size(vm['memorySize'], 'KB'),
            'disk': convert_size(root['size'], 'GB'),
            'arch': 'amd64',
            'mac': first_nic['mac'],
            'pm_type': 'pxe_ravello',
            'pm_addr': '{0[id]}/{1[id]}'.format(app, vm),
        })
    with open(args.data_file, 'w') as fout:
        json.dump(desc, fout, indent=2, sort_keys=True)
    print('Updated {0.data_file} with {0.count} new node definitions.'.format(args))


# Main program entry point

def main():
    """Run the script: parse arguments, create the nodes, write the datafile.

    Any exception is reported as ``Error: <message>`` on standard output
    and the process exits with status 1.  With ``--debug`` the exception is
    re-raised instead, once the arguments have passed validation.
    """
    debug = False

    try:
        # Command line.
        args = create_parser().parse_args()
        check_arguments(args)
        debug = args.debug
        setup_logger(args)

        # Everything that can be derived without talking to the API.
        flavor = get_flavor(args)
        interfaces = get_interfaces(args)
        services = get_services(args)

        # API lookups.
        client = connect_api(args)
        image = get_disk_image(client, flavor, args)
        keypair = get_keypair(client, args)

        # Build the application design, publish it and record the nodes.
        app = create_or_get_application(client, args)
        template = create_vm_template(image, keypair, flavor, interfaces, args)
        app = add_vms_to_application(app, template, flavor, interfaces,
                                     services, args)
        app = update_and_publish(client, app, args)
        update_datafile(app, args)

    except Exception as e:
        if debug:
            raise
        sys.stdout.write('Error: {0!s}\n'.format(e))
        sys.exit(1)


# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
