#!/usr/bin/env python

# Copyright (c) 2009, Purdue University
# All rights reserved.
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 
# Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation and/or
# other materials provided with the distribution.
# 
# Neither the name of the Purdue University nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
# 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Update host tool for dns management"""


# Module metadata: copyright holder, license and version of this tool.
__copyright__ = 'Copyright (C) 2009, Purdue University'
__license__ = 'BSD'
__version__ = '0.17'


import os
import sys
import re
import getpass

from optparse import OptionParser

from roster_user_tools import cli_record_lib
from roster_user_tools import cli_common_lib
from roster_user_tools import roster_client_lib
from roster_user_tools.action_flags import Update
from roster_user_tools.data_flags import Hosts


class Args(Update, Hosts):
  """Command line argument handler for this tool.

  Adds nothing of its own: everything is inherited from the Update action
  flags class and the Hosts data flags class.
  """


def MakeHostsFile(options, cli_common_lib_instance):
  """Builds the text of an editable hosts file for a CIDR block

  Records for options.range are fetched from the server, then rendered by
  cli_common_lib.PrintHosts underneath a fixed header of instructions.

  Inputs:
    options: options object from optparse; username, credfile, server,
             range and view_name are read from it
    cli_common_lib_instance: not used by this function

  Outputs:
    string: string of hosts file
  """
  cidr_block = options.range
  view = options.view_name
  # Connection keywords shared by both server calls.
  server_kwargs = {'credfile': options.credfile,
                   'server_name': options.server}

  found_records = roster_client_lib.RunFunction(
      'ListRecordsByCIDRBlock', options.username, args=[cidr_block],
      kwargs={'view_name': view}, **server_kwargs)['core_return']
  if( found_records == {} ):
    cli_common_lib.DnsError('No records found.', 1)
  expanded_ips = roster_client_lib.RunFunction(
      'CIDRExpand', options.username, args=[cidr_block],
      **server_kwargs)['core_return']

  # Every view except 'any' (or no view at all) is written with a '_dep'
  # suffix in the header.
  if( view in ('any', None) ):
    view_dependency = view
  else:
    view_dependency = '%s_dep' % view

  hosts_listing = cli_common_lib.PrintHosts(found_records, expanded_ips, view)
  template = ('#:range:%s\n'
              '#:view_dependency:%s\n'
              '# Do not delete any lines in this file!\n'
              '# To remove a host, comment it out, to add a host,\n'
              '# uncomment the desired ip address and specify a\n'
              '# hostname. To change a hostname, edit the hostname\n'
              '# next to the desired ip address.\n'
              '#\n'
              '# The "@" symbol in the host column signifies inheritance\n'
              '# of the origin of the zone, this is just shorthand.\n'
              '# For example, @.university.edu. would be the same as\n'
              '# university.edu.\n'
              '#\n'
              '# Columns are arranged as so:\n'
              '# Ip_Address Fully_Qualified_Domain Hostname\n%s')
  return template % (cidr_block, view_dependency, hosts_listing)

def CheckIPV4(ip_address):
  """Checks if IP address is valid IPV4

  The whole string has to be a dotted quad of four octets in the range
  0-255. Strings that only contain an address, such as the header value
  ':range:192.168.1.0/24', a CIDR block like '192.168.1.4/30' or an address
  with a stray prefix or suffix, are rejected.

  Inputs:
    ip_address: string of ip address

  Outputs:
    boolean: whether or not ip address is valid
  """
  octet = r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)'
  # Anchored at both ends (like the patterns in CheckIPV6) so nothing may
  # precede or follow the four octets.
  ipv4_pattern = r'\A%s(\.%s){3}\Z' % (octet, octet)
  return re.match(ipv4_pattern, ip_address) is not None

def CheckIPV6(ip_address):
  """Checks if IP address is valid IPV6

  Accepts the full eight group form, forms compressed with a single '::',
  and forms whose lower 32 bits are written as a dotted IPV4 address.
  Matching is case insensitive.

  Inputs:
    ip_address: string of ip address

  Outputs:
    boolean: whether or not ip address is valid
  """
  # we will build an array of matches and then join them
  matches = []
  simplest_match = r'[0-9a-f]{1,4}'
  # full form with all eight groups written out and no '::'
  matches.append(r'\A(%s:){7}%s\Z' % (simplest_match, simplest_match))
  # '::' in the middle: i groups at most before it, 7-i at most after it
  for i in range(1,7):
    matches.append(r'\A(%s:){1,%d}(:%s){1,%d}\Z' % (simplest_match, i,
                                                    simplest_match, 7-i))
  # '::' at the very end (including a bare '::') or at the very beginning
  matches.append(r'\A((%s:){1,7}|:):\Z' % simplest_match)
  matches.append(r'\A:(:%s){1,7}\Z' % simplest_match)
  matches.append(r'\A((([0-9a-f]{1,4}:){6})(25[0-5]|2[0-4]\d|[0-1]?\d?\d)'
                 r'(\.(25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3})\Z')
  # support for embedded ipv4 addresses in the lower 32 bits
  ipv4 = r'(25[0-5]|2[0-4]\d|[0-1]?\d?\d)(\.(25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}'
  matches.append(r'\A((%s:){5}%s:%s)\Z' % (simplest_match, simplest_match,
                                           ipv4))
  matches.append(r'\A(%s:){5}:%s:%s\Z' % (simplest_match, simplest_match, ipv4))
  for i in range(1,5):
    matches.append(r'\A(%s:){1,%d}(:%s){1,%d}:%s\Z' % (
        simplest_match, i, simplest_match, 5-i, ipv4))
  matches.append(r'\A((%s:){1,5}|:):%s\Z' % (simplest_match, ipv4))
  matches.append(r'\A:(:%s){1,5}:%s\Z' % (simplest_match, ipv4))
  bigre = "("+")|(".join(matches)+")"
  bigre = re.compile(bigre, re.I)
  # bool() so that a failed match yields False rather than None
  return bool(bigre.search(ip_address))

def ReadHostsFile(options, hosts_file_contents, cli_common_lib_instance):
  """Reads a hosts file

  The CIDR block is taken from the '#:range:' header line and the view from
  the '#:view_dependency:' header line (a trailing '_dep' is dropped).
  options.view_name is overwritten with the view read from the file.

  Every remaining non-empty line falls into one of three cases:
    - a commented out line whose ip address has both forward and reverse
      records on the server: stored with host and alias of None
    - any other commented out line: only checked for ordering, not stored
    - an uncommented line: must be "ip fqdn alias" and is stored as given

  Inputs:
    options: options object from optparse; username, credfile, server and
             file are read, view_name is written
    hosts_file_contents: string of contents of hosts file
    cli_common_lib_instance: not used by this function

  Outputs:
    dict: dictionary with the CIDR block and the parsed hosts ex:
          {'range': '192.168.1.4/30',
           'hosts': {'192.168.1.5': [{'host': 'host1.university.edu',
                                      'alias': 'host1'}],
                     '192.168.1.6': [{'host': None, 'alias': None}]}}
  """
  range_line = None
  view_dependency_line = None
  hosts_file_lines = hosts_file_contents.split('\n')
  for line_number, line in enumerate(hosts_file_lines):
    if( line.startswith('#:range:') ):
      range_line = line_number
    if( line.startswith('#:view_dependency:') ):
      view_dependency_line = line_number
    # Compare against None: line number 0 is a valid (and the usual)
    # position for the range header.
    if( range_line is not None and view_dependency_line is not None ):
      break
  # Without these checks a missing header ends in a TypeError when the
  # line list is indexed with None.
  if( range_line is None ):
    cli_common_lib.DnsError(
        'No "#:range:" line found in "%s"' % options.file, 1)
  if( view_dependency_line is None ):
    cli_common_lib.DnsError(
        'No "#:view_dependency:" line found in "%s"' % options.file, 1)
  cidr_block = hosts_file_lines[range_line].split('#:range:', 1)[1].lstrip()
  options.view_name = hosts_file_lines[view_dependency_line].split(
      '#:view_dependency:', 1)[1].strip().rsplit('_dep', 1)[0]
  ip_address_list = roster_client_lib.RunFunction('CIDRExpand',
                                                  options.username,
                                                  credfile=options.credfile,
                                                  server_name=options.server,
                                                  args=[cidr_block])[
                                                      'core_return']
  records_dictionary = roster_client_lib.RunFunction(
      'ListRecordsByCIDRBlock', options.username, credfile=options.credfile,
      server_name=options.server, args=[cidr_block],
      kwargs={'view_name': options.view_name})['core_return']
  sorted_records = cli_common_lib.SortRecordsDict(
      records_dictionary, options.view_name)
  hosts_dict = {}
  # Position in ip_address_list of the address currently being read; it
  # advances each time a line carries a valid address different from the
  # first token of the previous line.
  ip_list_index = -1
  ip_address = None
  for line in hosts_file_contents.split('\n'):
    line = line.strip()
    if( line == '' ):
      continue
    last_ip_address = ip_address
    try:
      if( line.startswith('#') ):
        ip_address = line.split('#')[1].split()[0].strip()
      else:
        ip_address = line.split()[0].strip()
      if( CheckIPV4(ip_address) or CheckIPV6(ip_address)):
        if( ip_address != last_ip_address and last_ip_address is not None ):
          ip_list_index = ip_list_index + 1
    except IndexError:
      # Comment lines with no text after the first '#'
      continue
    ip_address = line.lstrip('#').split()[0]
    if( line.startswith('#') and ip_address in sorted_records
          and len(sorted_records[ip_address]['forward']) > 0
          and len(sorted_records[ip_address]['reverse']) > 0 ):
      # An existing host has been commented out: record it with an empty
      # host and alias.
      line_array = line.lstrip('#').split()
      if( not CheckIPV4(line_array[0]) and not CheckIPV6(line_array[0]) ):
        cli_common_lib.DnsError('Invalid ip address "%s" in file "%s"' % (
            ip_address, options.file), 1)
      else:
        if( ip_address != ip_address_list[ip_list_index] and 
            ip_address != last_ip_address ):
          cli_common_lib.DnsError(
                'IP Address %s is out of order.' % ip_address, 1)
        if( ip_address not in hosts_dict ):
          hosts_dict[line_array[0]] = []
      hosts_dict[ip_address].append({'host': None,
                                     'alias': None})
    elif( line.startswith('#') ):
      # Any other comment: nothing is stored, but a commented out address
      # still has to appear in order.
      try:
        if( CheckIPV4(ip_address) or CheckIPV6(ip_address)):
          if( ip_address != ip_address_list[ip_list_index] and
              ip_address != last_ip_address ):
            cli_common_lib.DnsError(
                'IP Address %s is out of order.' % ip_address, 1)
        else:
          continue
      except IndexError:
        # ip_list_index is outside of ip_address_list; skip the line
        continue
    else:
      # Active host line; anything after a '#' is ignored.
      line_array = line.split('#')[0].split()
      if( len(line_array) != 3 ):
        cli_common_lib.DnsError(
            'Line "%s" is incorrectly formatted in "%s"' % (
                line, options.file), 1)
      if( not CheckIPV4(line_array[0]) and not CheckIPV6(line_array[0]) ):
        cli_common_lib.DnsError('Invalid ip address "%s" in file "%s"' % (
            ip_address, options.file), 1)
      else:
        if( ip_address != ip_address_list[ip_list_index] and 
            ip_address != last_ip_address ):
          cli_common_lib.DnsError(
                'IP Address %s is out of order.' % ip_address, 1)
        if( ip_address not in hosts_dict ):
          hosts_dict[line_array[0]] = []
      hosts_dict[ip_address].append({'host': line_array[1],
                                     'alias': line_array[2]})
  return {'range': cidr_block, 'hosts':  hosts_dict}

def main(args):
  """Collects command line arguments, checks ip addresses and adds records.

  Inputs:
    args: list of arguments from the command line
  """
  command = None
  if( args and not args[0].startswith('-') ):
    command = args.pop(0)
  usage = ('\n'
           '\n'
           'To dump a text hosts file of a cidr-block:\n'
           '\t%s dump -r <cidr-block> [-f <file-name>] [-v <view-name>]\n'
           '\t[-z <zone-name>]\n'
           '\n'
           'To update a cidr block at once:\n'
           '\t%s update [-f <file-name>] [-v <view-name]\n'
           '\t[-z <zone-name>] [--commit|--no-commit]\n'
           '\n'
           'To dump and update a cidr block:\n'
           '\t%s edit -r <cidr_block> [-f <file-name>] [-v <view-name>]\n'
           '\t[-z <zone-name>] [--commit|--no-commit]\n' % tuple(
             [sys.argv[0] for x in range(3)]))
  args_instance = Args(command,
      ['update', 'dump', 'edit'], args, usage)
  options = args_instance.options

  try:
    cli_common_lib_instance = cli_common_lib.CliCommonLib(options)
  except cli_common_lib.ArgumentError, e:
    print 'ERROR: %s' % e
    sys.exit(1)
  cli_record_lib_instance = cli_record_lib.CliRecordLib(cli_common_lib_instance)
  write_to_db = False
  remove_file = False

  if( command in ['dump', 'edit'] ):
    file_contents = MakeHostsFile(options, cli_common_lib_instance)
    handle = open(options.file, 'w')
    try:
      handle.writelines(file_contents)
      handle.flush()
    finally:
      handle.close()

  if( command == 'edit' ):
    remove_file = True
    return_code = cli_common_lib.EditFile(options.file)
    if( return_code == 0 ):
      write_to_db = True

  if( write_to_db or command == 'update' ):
    new_hosts_file_handle = open(options.file, 'r')
    try:
      new_hosts_file_contents = new_hosts_file_handle.read()
    finally:
      new_hosts_file_handle.close()
    new_hosts_dict = ReadHostsFile(options, new_hosts_file_contents,
                                   cli_common_lib_instance)
    options.range = new_hosts_dict['range']
    new_hosts_dict = new_hosts_dict['hosts']
    original_hosts_dict = {}
    original_hosts_file_contents = MakeHostsFile(options,
                                                 cli_common_lib_instance)
    original_hosts_dict = ReadHostsFile(options, original_hosts_file_contents,
                                        cli_common_lib_instance)['hosts']
    zone_info = roster_client_lib.RunFunction(
        'ListZones', options.username, credfile=options.credfile,
        server_name=options.server,
        kwargs={'view_name': options.view_name})['core_return']
    zone_origin_dict = {}
    for zone in zone_info:
      if( options.view_name in zone_info[zone] ):
        if( zone_info[zone][options.view_name]['zone_origin'] not in
                zone_origin_dict ):
          zone_origin_dict[zone_info[zone][options.view_name][
              'zone_origin']] = zone
    update_hosts_dict = {}
    for ip_address in new_hosts_dict:
      try:
        if( ip_address not in original_hosts_dict or
            original_hosts_dict[ip_address] != new_hosts_dict[ip_address] ):
          update_hosts_dict[ip_address] = new_hosts_dict[ip_address]
      except KeyError:
        cli_common_lib.DnsError('IP Address %s is not in CIDR block %s.' % (
            ip_address, options.range), 1)
    delete_list = []
    add_list = []
    records_dictionary = roster_client_lib.RunFunction(
        'ListRecordsByCIDRBlock', options.username, credfile=options.credfile,
        server_name=options.server, args=[options.range],
        kwargs={'view_name': options.view_name})['core_return']
    sorted_records = cli_common_lib.SortRecordsDict(
        records_dictionary, options.view_name)
    for host in update_hosts_dict:
      added_ptr = False
      for record_number, record in enumerate(update_hosts_dict[host]):
        # Check if new host and alias match
        if( record['host'] is not None and record['alias'] is not None
            and not record['host'].startswith(record['alias'])
            and not record['alias'].endswith(u'@') ):
          cli_common_lib.DnsError(
              'Fully qualified domain name "%s" and alias "%s" do not match.' % (
                  record['host'],
                  record['alias']), 1)
        reverse_ip = roster_client_lib.RunFunction(
            'ReverseIP', options.username, credfile=options.credfile,
            server_name=options.server, args=[host])[
                'core_return']
        reverse_zone_name = roster_client_lib.RunFunction(
            'ListZoneByIPAddress', options.username, credfile=options.credfile,
            server_name=options.server, args=[host])[
                'core_return']
        # Only checking type of IP here, not validity
        if( CheckIPV4(host) ):
          record_type = u'a'
        else:
          record_type = u'aaaa'
        try:
          for record_enum in enumerate(records_dictionary[options.view_name][host]):
            if record_enum[1]['forward'] == True:
              records_id = record_enum[1]['records_id']
              record_ttl = record_enum[1]['record_ttl']
              record_last_user = record_enum[1]['record_last_user']
            elif record_enum[1]['forward'] == False:
              ptr_id = record_enum[1]['records_id']
        except KeyError:
          pass
        if( record['host'] is None ):
          print 'Host: %s with ip address %s will be REMOVED' % (
              original_hosts_dict[host][0]['host'], host)
          zone_origin = '%s.' % original_hosts_dict[host][0][
              'host'].split('%s.' % (original_hosts_dict[host][0][
                  'alias']), 1)[1]
          delete_list.append({'record_type': record_type,
                              'records_id': records_id,
                              'record_ttl': record_ttl,
                              'record_last_user': record_last_user,
                              'record_target': original_hosts_dict[
                                  host][record_number]['alias'],
                              'record_view_dependency': options.view_name,
                              'record_zone_name': zone_origin_dict[zone_origin],
                              'record_arguments': {'assignment_ip': host}})
          delete_list.append({'record_type': u'ptr',
                              'records_id': ptr_id,
                              'record_ttl': record_ttl,
                              'record_last_user': record_last_user,
                              'record_target': reverse_ip[:-len(zone_info[
                                  reverse_zone_name][options.view_name][
                                      'zone_origin']) - 1:],
                              'record_view_dependency': options.view_name,
                              'record_zone_name': reverse_zone_name,
                              'record_arguments': {'assignment_host':
                                  '%s.' % original_hosts_dict[
                                      host][0]['host']}})
          break
        else:
          if( host in original_hosts_dict
              and len(original_hosts_dict[host]) <= record_number ):
            pass
          elif( host in original_hosts_dict and
                len(sorted_records[host]['reverse']) == 1 and
                len(sorted_records[host]['forward']) >= 1 ):
            print 'Host: %s with ip address %s will be REMOVED' % (
                original_hosts_dict[host][record_number]['host'], host)
            try:
              zone_origin = '%s.' % original_hosts_dict[host][record_number][
                  'host'].split('%s.' % (
                      original_hosts_dict[host][record_number]['alias']), 1)[1]
            except:
              print original_hosts_dict[host]
            delete_list.append({'record_type': record_type,
                                'records_id': records_id,
                                'record_ttl': record_ttl,
                                'record_last_user': record_last_user,
                                'record_target': original_hosts_dict[
                                    host][record_number]['alias'],
                                'record_view_dependency': options.view_name,
                                'record_zone_name':
                                    zone_origin_dict[zone_origin],
                                'record_arguments': {'assignment_ip': host}})
            delete_list.append({'record_type': u'ptr',
                                'records_id': records_id,
                                'record_ttl': record_ttl,
                                'record_last_user': record_last_user,
                                'record_target': reverse_ip[:-len(zone_info[
                                    reverse_zone_name][options.view_name][
                                        'zone_origin']) - 1:],
                                'record_view_dependency': options.view_name,
                                'record_zone_name': reverse_zone_name,
                                'record_arguments': {'assignment_host':
                                    '%s.' % original_hosts_dict[host][
                                        record_number]['host']}})
          else:
            if( host in sorted_records and
                (len(sorted_records[host]['forward']) > 0 or
                 len(sorted_records[host]['reverse']) > 0) ):
                cli_common_lib.DnsError(
                    'Host with IP %s cannot be uncommented' % host, 1)

          print 'Host: %s with ip address %s will be ADDED' % (
              record['host'], host)
          if( record['alias'] == u'@' ):
            zone_origin = '%s.' % record['host']
          elif( record['alias'].endswith(u'@') ):
            try:
              prepend = record['alias'].split('@')[0]
              zone_origin = '%s.' % record['host'].split(prepend)[1]
            except:
              raise Exception('%s %s %s' % (record['host'], record['alias'], prepend))
          else:
            try:
              zone_origin = '%s.' % record['host'].split(
                  '%s.' % (record['alias']), 1)[1]
            except:
              raise Exception('%s %s' % (record['host'], record['alias']))
          add_list.append({'record_type': record_type,
                           'record_target': record['alias'],
                           'record_view_dependency': options.view_name,
                           'record_zone_name': zone_origin_dict[zone_origin],
                           'record_arguments': {'assignment_ip': host}})
          if( not added_ptr ):
            add_list.append({'record_type': u'ptr',
                             'record_target': reverse_ip[:-len(zone_info[
                                 reverse_zone_name][options.view_name][
                                     'zone_origin']) - 1:],
                             'record_view_dependency': options.view_name,
                             'record_zone_name': reverse_zone_name,
                             'record_arguments': {'assignment_host':
                                 '%s.' % record['host']}})
            added_ptr = True
          break
    while( not options.commit and not options.no_commit ):
      yes_no = raw_input('Do you want to commit these changes? (y/N): ')
      if( yes_no.lower() not in ['y', 'yes', 'n', 'no', ''] ):
        continue
      if( yes_no.lower() in ['n', 'no', ''] ):
        options.no_commit = True
      if( yes_no.lower() in ['y', 'yes'] ):
        options.commit = True
    if( options.no_commit ):
      print 'No changes made.'
      sys.exit(0)
    rows = roster_client_lib.RunFunction(
        'ProcessRecordsBatch', options.username, credfile=options.credfile,
        server_name=options.server,
        kwargs={'delete_records': delete_list,
                'add_records': add_list})['core_return']
    if( remove_file and not options.keep_output ):
      os.remove(options.file)

  elif( command in ['dump', 'edit'] ):
    pass

  else:
    cli_common_lib.DnsError(
        'Command %s exists, but codepath does not.', 1)

# Run the tool with the command line arguments (program name excluded) when
# this file is executed as a script.
if( __name__ == '__main__' ):
  main(sys.argv[1:])
