#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
ssh-ldap-pubkey - Utility to manage SSH public keys stored in LDAP.

Usage:
  ssh-ldap-pubkey list [options]
  ssh-ldap-pubkey add [options] FILE
  ssh-ldap-pubkey del [options] PATTERN
  ssh-ldap-pubkey --help

  -                      Read public key from stdin.
  FILE                   Path to the public key file to add.
  PATTERN                Pattern that specifies public key(s) to delete, i.e.
                         a complete key or just a part of it.

Options:
  -b DN --base=DN        Base DN where to search for the users' entry. If not
                         provided, then it's read from the config file.
  -c FILE --conf=FILE    Path of the ldap.conf (default is /etc/ldap.conf).
                         The ldap.conf is not required when at least --base is
                         provided.
  -h URI --uri=URI       URI of the LDAP server to connect; loaded from the
                         config file by default. If not defined even there,
                         then it defaults to ldap://localhost.
  -q --quiet             Be quiet.
  -u LOGIN --user=LOGIN  Login of the user to bind as and change public key(s)
                         (default is the current user).
  -v --version           Show version information.
  --help                 Show this message.
"""

from __future__ import print_function

import base64
import ldap
import re
import struct
import sys

from docopt import docopt
from getpass import getpass, getuser
from os import access, R_OK

# Version string reported by --version.
__version__ = '0.2.3'


#######################  Constants  #######################

# Config file that is read when --conf is not given.
DEFAULT_CONFIG_PATH = '/etc/ldap.conf'
# Host and port used to build the server URI when the config has no 'uri'.
DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 389
# Fallback for both 'bind_timelimit' and 'timelimit' config options.
DEFAULT_TIMEOUT = 10
# Fallback for 'pam_login_attribute': the attribute matched against the login.
DEFAULT_LOGIN_ATTR = 'uid'
# Fallback for 'pam_filter': filter combined with the login when searching.
DEFAULT_FILTER = 'objectclass=posixAccount'
# Fallback for 'scope'; one of the keys 'base', 'one', 'sub'.
DEFAULT_SCOPE = 'sub'

# Object class that is added to / removed from the user's entry when the
# server reports an object class violation while modifying the keys.
LDAP_PUBKEY_CLASS = 'ldapPublicKey'
# Attribute of the user's entry that holds the public keys.
LDAP_PUBKEY_ATTR = 'sshPublicKey'


########################  Classes  ########################

class LdapSSH(object):
    """ Read and modify the SSH public keys stored in users' LDAP entries.

    Wraps a single connection created with the `ldap` module. Call `connect`
    before any other operation and `close` when finished. Failures are
    reported as `Error` subclasses carrying a process exit code.
    """

    def __init__(self, conf):
        """ Remember the configuration; no connection is opened yet.

        Arguments:
            conf: Object with the attributes read by this class: uri, base,
                scope, filter, login_attr, ldap_version, bind_timeout,
                search_timeout, bind_dn, bind_pass and cacert_dir.
        """
        self.conf = conf
        self._conn = None

    def connect(self):
        """ Initialize the connection and apply the configured settings.

        Binds as `conf.bind_dn` when both the bind DN and its password are
        configured.

        Raises:
            ConfigError: If the LDAP URI or the base DN is not set.
            InvalidCredentialsError: If the configured bind is rejected.
            LDAPConnectionError: If the server can't be contacted.
        """
        conf = self.conf

        if not conf.uri or not conf.base:
            raise ConfigError("Base DN and LDAP URI must be provided.", 1)

        if conf.cacert_dir:
            # this is a global option!
            ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, conf.cacert_dir)

        self._conn = conn = ldap.initialize(conf.uri)
        try:
            conn.protocol_version = conf.ldap_version
            conn.network_timeout = conf.bind_timeout
            conn.timeout = conf.search_timeout

            if conf.bind_dn and conf.bind_pass:
                self._bind(conf.bind_dn, conf.bind_pass)
        except ldap.SERVER_DOWN:
            raise LDAPConnectionError("Can't contact LDAP server.", 3)

    def close(self):
        """ Unbind from the server if a connection exists.

        The handle is dropped first, so calling this more than once (or
        without a prior `connect`) is harmless.
        """
        conn, self._conn = self._conn, None
        if conn is not None:
            conn.unbind_s()

    def add_pubkey(self, login, password, pubkey):
        """ Add `pubkey` to the entry of the user `login`.

        Arguments:
            login: Login of the user whose entry is modified.
            password: Password used to bind as that user.
            pubkey: Public key in OpenSSH format.

        Raises:
            InvalidPubKeyError: If `pubkey` is not a valid OpenSSH key.
            UserEntryNotFoundError: If no entry matches `login`.
            InvalidCredentialsError: If `password` is rejected.
            PubKeyAlreadyExistsError: If the entry already holds this key.
            ConfigError: If the server doesn't know the key attribute.
        """
        if not is_valid_openssh_pubkey(pubkey):
            raise InvalidPubKeyError("Invalid key, not in OpenSSH Public Key format.", 1)

        dn = self.find_dn_by_login(login)
        self._bind(dn, password)

        if self._has_pubkey(dn, pubkey):
            raise PubKeyAlreadyExistsError(
                "Public key %s already exists." % keyname(pubkey), 1)

        modlist = [(ldap.MOD_ADD, LDAP_PUBKEY_ATTR, pubkey)]
        try:
            self._conn.modify_s(dn, modlist)

        except ldap.OBJECT_CLASS_VIOLATION:
            # Retry with the public key object class added in the same
            # request (presumably the entry doesn't have it yet).
            modlist += [(ldap.MOD_ADD, 'objectClass', LDAP_PUBKEY_CLASS)]
            self._conn.modify_s(dn, modlist)

        except ldap.UNDEFINED_TYPE:
            raise ConfigError(
                "LDAP server doesn't define schema for attribute: %s" % LDAP_PUBKEY_ATTR, 1)

    def find_and_remove_pubkeys(self, login, password, pattern):
        """ Remove every key of the user `login` that contains `pattern`.

        Arguments:
            login: Login of the user whose entry is modified.
            password: Password used to bind as that user.
            pattern: Substring that selects the keys to remove.

        Returns:
            List of the keys that were removed (possibly empty).
        """
        dn = self.find_dn_by_login(login)
        self._bind(dn, password)

        pubkeys = [key for key in self._find_pubkeys(dn) if pattern in key]
        for key in pubkeys:
            self._remove_pubkey(dn, key)

        return pubkeys

    def find_pubkeys(self, login):
        """ Return the list of public keys stored for the user `login`. """
        return self._find_pubkeys(self.find_dn_by_login(login))

    def find_dn_by_login(self, login):
        """ Return the DN of the first entry that matches `login`.

        The search combines the configured filter with
        ``<login_attr>=<login>`` under the configured base and scope.

        Raises:
            UserEntryNotFoundError: If the search returns nothing.
        """
        # Imported locally: ``import ldap`` alone is not relied upon to load
        # the ``ldap.filter`` submodule.
        from ldap.filter import escape_filter_chars

        conf = self.conf
        # The login comes from the command line; escape it so that characters
        # with a special meaning in search filters (wildcards, parentheses,
        # backslashes) are matched literally instead of altering the filter.
        filter_s = "(&(%s)(%s=%s))" % (
            conf.filter, conf.login_attr, escape_filter_chars(login))

        result = self._conn.search_s(conf.base, conf.scope, filter_s, ['dn'])
        if not result:
            raise UserEntryNotFoundError("No user with login '%s' found." % login, 2)

        return result[0][0]

    def _bind(self, dn, password):
        """ Bind as `dn`; raise InvalidCredentialsError when rejected. """
        try:
            self._conn.simple_bind_s(dn, password)
        except ldap.INVALID_CREDENTIALS:
            raise InvalidCredentialsError("Invalid credentials for %s." % dn, 2)

    def _find_pubkeys(self, dn):
        """ Return the values of the public key attribute of the entry `dn`. """
        result = self._conn.search_s(
            dn, ldap.SCOPE_BASE, attrlist=[LDAP_PUBKEY_ATTR])

        return result[0][1].get(LDAP_PUBKEY_ATTR, [])

    @staticmethod
    def _key_data(pubkey):
        """ Return the second whitespace-separated field of `pubkey`, or None. """
        fields = pubkey.split()
        return fields[1] if len(fields) > 1 else None

    def _has_pubkey(self, dn, pubkey):
        """ Tell whether the entry `dn` already holds the key `pubkey`.

        Keys are compared by their second field only (the encoded key data),
        so a different key type label or comment doesn't make a key new.
        Stored values with fewer than two fields never match; they are
        skipped instead of aborting the comparison.
        """
        wanted = self._key_data(pubkey)
        if wanted is None:
            return False

        return any(self._key_data(key) == wanted
                   for key in self._find_pubkeys(dn))

    def _remove_pubkey(self, dn, pubkey):
        """ Delete the exact value `pubkey` from the entry `dn`.

        Raises:
            NoPubKeyFoundError: If the entry doesn't hold that value.
        """
        modlist = [(ldap.MOD_DELETE, LDAP_PUBKEY_ATTR, pubkey)]
        try:
            self._conn.modify_s(dn, modlist)

        except ldap.OBJECT_CLASS_VIOLATION:
            # Retry with the public key object class removed in the same
            # request (presumably the class requires at least one key).
            modlist += [(ldap.MOD_DELETE, 'objectClass', LDAP_PUBKEY_CLASS)]
            self._conn.modify_s(dn, modlist)

        except ldap.NO_SUCH_ATTRIBUTE:
            raise NoPubKeyFoundError("No such public key exists: %s." % keyname(pubkey), 1)


class LdapConfig(object):
    """ Connection settings loaded from an ldap.conf-style file.

    Attributes set by the constructor (config option in parentheses):
        uri: First address of 'uri', or ldap://<host>:<port> built from
            'host' and 'port'.
        base: Part of 'nss_base_passwd' before the first '?', else 'base',
            else None.
        bind_dn, bind_pass: 'binddn' and 'bindpw', or None.
        ldap_version: 'ldap_version' as int (default ldap.VERSION3).
        bind_timeout, search_timeout: 'bind_timelimit' and 'timelimit' as
            int (default DEFAULT_TIMEOUT).
        login_attr: 'pam_login_attribute' (default DEFAULT_LOGIN_ATTR).
        filter: 'pam_filter' (default DEFAULT_FILTER).
        cacert_dir: 'tls_cacertdir', or None.
        scope: ldap.SCOPE_* constant selected by 'scope' ('base', 'one' or
            'sub'; default DEFAULT_SCOPE).
    """

    # "<option> <value>"; the value ends at the first '#' or at end of line.
    _LINE_RE = re.compile(r'^(\w+)\s+([^#]+)')

    def __init__(self, path, quiet=False):
        """ Load settings from `path`, falling back to the defaults.

        Arguments:
            path: Path of the config file. When it is not readable, the
                defaults are used.
            quiet: Suppress the notice printed when the file is unreadable.

        Raises:
            ValueError: If a numeric option is not an integer.
            KeyError: If 'scope' is not one of 'base', 'one', 'sub'.
        """
        conf = {}
        if access(path, R_OK):
            conf = self._parse_file(path)
        elif not quiet:
            print("Notice: Could not read config: %s; running with defaults." % path)

        if 'uri' in conf:
            self.uri = conf['uri'].split()[0]  # use just first address for now
        else:
            host = conf.get('host', DEFAULT_HOST)
            port = conf.get('port', DEFAULT_PORT)
            self.uri = "ldap://%s:%s" % (host, port)

        self.base = conf.get('nss_base_passwd', '').split('?')[0] or conf.get('base', None)
        self.bind_dn = conf.get('binddn', None)
        self.bind_pass = conf.get('bindpw', None)
        self.ldap_version = int(conf.get('ldap_version', ldap.VERSION3))
        self.bind_timeout = int(conf.get('bind_timelimit', DEFAULT_TIMEOUT))
        self.search_timeout = int(conf.get('timelimit', DEFAULT_TIMEOUT))
        self.login_attr = conf.get('pam_login_attribute', DEFAULT_LOGIN_ATTR)
        self.filter = conf.get('pam_filter', DEFAULT_FILTER)
        self.cacert_dir = conf.get('tls_cacertdir', None)

        self.scope = {
            'base': ldap.SCOPE_BASE,
            'one': ldap.SCOPE_ONELEVEL,
            'sub': ldap.SCOPE_SUBTREE
        }[conf.get('scope', DEFAULT_SCOPE)]

    def __str__(self):
        # NOTE(review): this includes bind_pass in clear text.
        return str(self.__dict__)

    def _parse_file(self, path):
        """ Parse the file at `path` into a dict of lower-cased option names.

        Only lines that start with an option name followed by whitespace and
        a value are used; everything from the first '#' on is ignored.
        Surrounding whitespace is stripped from the value, but trailing
        punctuation (e.g. a closing parenthesis of a filter, the last
        character of a password) is kept. Options without a value are
        skipped.
        """
        conf = {}
        with open(path, 'r') as f:
            for line in f:
                m = self._LINE_RE.match(line)
                if not m:
                    continue
                value = m.group(2).strip()
                if value:
                    conf[m.group(1).lower()] = value

        return conf


class Error(Exception):
    """ Base class of the errors raised by this tool.

    Attributes:
        msg: Description of the problem; also the result of str(error).
        code: Exit code associated with the error (1 unless given).
    """

    def __init__(self, msg, code=1):
        self.msg, self.code = msg, code

    def __str__(self):
        message = self.msg
        return message


class ConfigError(Error):
    """ Required settings are missing or the server lacks the key schema. """


class InvalidCredentialsError(Error):
    """ The LDAP server rejected the bind DN / password. """


class InvalidPubKeyError(Error):
    """ The given key is not in OpenSSH public key format. """


class LDAPConnectionError(Error):
    """ The LDAP server can't be contacted. """


class NoPubKeyFoundError(Error):
    """ The key to remove is not stored in the user's entry. """


class PubKeyAlreadyExistsError(Error):
    """ The key to add is already stored in the user's entry. """


class UserEntryNotFoundError(Error):
    """ No LDAP entry matches the given login. """


#######################  Functions  #######################

def read_file(path):
    """ Return the content of the file at `path` without surrounding whitespace.

    Leading and trailing whitespace (including newlines) is removed, which
    forgives keys that were copied together with stray blank lines.
    """
    with open(path, 'r') as source:
        content = source.read()

    return content.strip()


def read_stdin():
    """ Return everything read from standard input without surrounding whitespace.

    Leading and trailing whitespace (including newlines) is removed, which
    forgives keys that were pasted together with stray blank lines.
    """
    content = sys.stdin.read()
    return content.strip()


def keyname(pubkey):
    """ Return the last whitespace-separated field of `pubkey`. """
    fields = pubkey.split()
    return fields[len(fields) - 1]


def is_valid_openssh_pubkey(pubkey):
    """ Tell whether `pubkey` looks like a key in OpenSSH public key format.

    Validation based on http://stackoverflow.com/a/2494645/2217862: the key
    must consist of at least two whitespace-separated fields, a key type and
    Base64 data, and the decoded data must start with a 4-byte big-endian
    length followed by that same key type.

    Arguments:
        pubkey: The key line to check; empty or None is simply invalid.

    Returns:
        True if the key passes the check, False otherwise. Malformed input
        never raises.
    """
    fields = pubkey.split() if pubkey else []
    if len(fields) < 2:
        return False

    key_type, data64 = fields[:2]
    try:
        # b64decode() is available on both Python 2 and 3, unlike
        # decodestring(), which was removed in Python 3.9.
        data = base64.b64decode(data64)
    except (TypeError, ValueError, base64.binascii.Error):
        # Python 2 reports bad input as TypeError, Python 3 as binascii.Error
        # (a ValueError subclass).
        return False

    int_len = 4
    if len(data) < int_len:
        # too short to even hold the length prefix
        return False
    str_len = struct.unpack('>I', data[:int_len])[0]

    # The decoded data is a byte string; compare it with bytes.
    if not isinstance(key_type, bytes):
        key_type = key_type.encode('utf-8')

    return data[int_len:(int_len + str_len)] == key_type


def halt(msg, code=1):
    """ Print `msg` to stderr and terminate the process with exit code `code`.

    Arguments:
        msg: Message to print.
        code: Process exit code (default 1).

    Raises:
        SystemExit: Always.
    """
    print(msg, file=sys.stderr)
    # sys.exit() instead of the exit() builtin: the latter is installed by
    # the site module for interactive use and isn't guaranteed to exist.
    sys.exit(code)


##########################  Main  #########################

def main(**kwargs):
    """ Run the command selected by the parsed command-line arguments.

    Arguments:
        **kwargs: Mapping produced by docopt; the keys read here are
            '--conf', '--quiet', '--user', '--uri', '--base', 'add', 'del',
            'FILE' and 'PATTERN'. Neither 'add' nor 'del' means 'list'.

    `Error` and IOError are reported through `halt`, i.e. on stderr and by
    exiting with the matching code. The connection is closed in all cases.
    """
    config_path = kwargs['--conf'] or DEFAULT_CONFIG_PATH
    quiet = kwargs['--quiet']
    login = kwargs['--user'] or getuser()

    # Command-line options take precedence over the config file.
    conf = LdapConfig(config_path, quiet)
    for option, attr in (('--uri', 'uri'), ('--base', 'base')):
        if kwargs[option]:
            setattr(conf, attr, kwargs[option])

    # Only the modifying commands bind as the user, so only they need
    # the password.
    if kwargs['add'] or kwargs['del']:
        passw = getpass("Enter login (LDAP) password for user '%s': " % login)

    def info(msg):
        """ Print `msg` unless --quiet was given. """
        if not quiet:
            print(msg)

    ls = LdapSSH(conf)

    try:
        ls.connect()

        if kwargs['add']:
            source = kwargs['FILE']
            if source and source != '-':
                pubkey = read_file(source)
            else:
                pubkey = read_stdin()

            ls.add_pubkey(login, passw, pubkey)
            info("Key has been stored: %s" % keyname(pubkey))

        elif kwargs['del']:
            removed = ls.find_and_remove_pubkeys(login, passw, kwargs['PATTERN'])
            if not removed:
                info("No keys found to delete.")
            else:
                info("Deleted keys: \n%s" % '\n'.join(removed))

        else:  # list
            found = ls.find_pubkeys(login)
            if not found:
                info("No public keys found.")
            else:
                # the listing itself is printed even with --quiet
                print('\n'.join(found))

    except Error as err:
        halt("Error: %s" % err, err.code)

    except IOError as err:
        halt("Error: %s: %s" % (err.strerror, err.filename), 1)

    finally:
        ls.close()


if __name__ == '__main__':
    # Parse the command line only when executed as a script, so that merely
    # importing this module neither reads sys.argv nor exits on bad arguments.
    main(**docopt(__doc__, version="ssh-ldap-pubkey %s" % __version__))
