#!/usr/bin/python
'''
Module with the shared caching objects used by GenericLdapDict objects.
These objects cache connection data and GenericLdapDict objects:
    LdapCache,
    LdapCacheWrite,
    ADSLdapCache,
    ADSLdapCacheWrite
'''

import sys
import os
import ldap
import ldap.sasl
import ldap.schema
import ldap.modlist
import ldapurl
from pprint import pprint as pp
from types import StringType, UnicodeType, ListType, NoneType

# Names exported by a star-import of this module.
# NOTE(review): GetRootDSE is defined in this module but is not listed here -
# confirm whether that is intentional.
__all__ = ['LdapCache', 'LdapCacheWrite','ADSLdapCache', 'ADSLdapCacheWrite', 'runtests']
# Version string of this module.
__version__ = '0.1.0'

# ads = False

class LdapCache(object):
    '''
    Shared object used by GenericLdapDict objects
    caches connection data and GenericLdapDict objects

    Public attributes:
        dnidx    : dictionary used as object index; printdnidx expects each
                   value to be an iterable of objects (presumably filled by
                   the GenericLdapDict objects - it is never written here)
        classdef : cache of objectClass definitions, filled by getclassdef
        schema   : parsed subschema, or None when it was not loaded
        root     : rootDSE object
        bdn      : default naming context (base DN) object, or None
    '''
    MIN_ATTRLIST = ['objectClass', 'cn', 'ou', 'namingContexts']
    schema = None

    def __init__(self, ldapclass, url = None, user = '', passwd = '', cacert = None, mycert = None, mykey = None, debug = False):
        '''
        __init__ LdapCache
        input:
            ldapclass    : GenericLdapDict class object
            url          : ldapurl; eg. ldap://localhost:636
                           (None or an empty string selects ldap://localhost:389/)
            user         : user to bind with to Directory
            passwd       : password
            cacert       : path to CA certificate file or directory for CA certificates
            mycert       : path to user certificate
            mykey        : path to user private key
            debug        : switches on the debug / trace output of the ldap module

        if  mycert and mykey are provided sasl/EXTERNAL bind will be attempted and user and password are ignored

        rootDSE and defaultNamingcontext (base DN) can be accessed through python attributes:
            root
            bdn

        The schema and the naming contexts are only read when user or mykey
        is given; otherwise schema and bdn stay None.

        Errors raised by connect() are propagated.
        '''
        # 'not url' also covers the empty string, which used to fail on url[-1]
        if not url:
            url = 'ldap://localhost:389/'

        if not url.endswith('/'):
            url = url + '/'

        self.dnidx = {}
        self.__con = None
        self.__url = url
        self.__user = user
        self.__passwd = passwd
        self.__cacert = cacert
        self.__mycert = mycert
        self.__mykey = mykey
        self.__debug = debug
        self.schema = None
        self.classdef = {}
        self.__ldapclass = ldapclass
        # always defined, so callers can test for None instead of hitting an
        # AttributeError when the naming contexts were not read
        self.bdn = None

        self.connect()

        # schema and naming contexts are only read for non-anonymous use
        read_tree = bool(user or mykey)

        if read_tree:
            self.schema = self.getschema()

        # get root DSE
        self.root = self.__ldapclass('', self)
        self.root.getattr(['+'])
        self.root.parents = [None]

        # get namingContexts
        if read_tree:
            nctx = [self.__ldapclass(dn, self) for dn in self.root['namingContexts']]

            if self.root.has_key('defaultNamingContext'):
                self.bdn = self.__ldapclass(self.root['defaultNamingContext'][0], self)
            elif nctx:
                # no explicit default: fall back to the first naming context;
                # with no naming context at all bdn stays None
                self.bdn = nctx[0]
            self.root.children = nctx

    def __str__(self):
        '''
        Displays url of LdapCache object
        '''
        return self.__url

    def _getcon(self):
        '''
        Returns the current connection object (None before connect() has run)
        '''
        return self.__con

    def disconnect(self):
        '''
        Unbinds from the Directory.
        Returns True on success and False when the unbind failed
        (this includes the case where there is no connection object).
        '''
        # best effort on purpose, but do not swallow KeyboardInterrupt / SystemExit
        try:
            self.__con.unbind_s()
        except Exception:
            return False
        return True

    def connect(self):
        '''
        Connects and binds to the Directory using the settings given to __init__.

        - when cacert is set, TLS is required and started (start_tls_s)
        - when mycert and mykey are set, a sasl/EXTERNAL bind is done,
          otherwise a simple bind with user and passwd

        Returns the connection object.
        Raises IOError when cacert is neither a file nor a directory;
        ldap errors are propagated to the caller.
        '''
        trace_level = 0
        if self.__debug:
            ldap.set_option(ldap.OPT_DEBUG_LEVEL, 255)
            trace_level = 1

        if self.__cacert is not None:
            if os.path.isfile(self.__cacert):
                ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, self.__cacert)
            elif os.path.isdir(self.__cacert):
                ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, self.__cacert)
            else:
                raise IOError("No valid certificate found: '%s'." % (self.__cacert))

        use_client_cert = self.__mycert is not None and self.__mykey is not None
        if use_client_cert:
            ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.__mycert)
            ldap.set_option(ldap.OPT_X_TLS_KEYFILE, self.__mykey)

        # never follow referrals: OPT_REFERRALS is an on/off switch, 0 is off
        ldap.set_option(ldap.OPT_REFERRALS, 0)

        self.__con = ldap.initialize(self.__url, trace_level=trace_level, trace_file=sys.stderr)
        if self.__cacert is not None:
            self.__con.set_option(ldap.OPT_X_TLS, ldap.OPT_X_TLS_DEMAND)
            self.__con.start_tls_s()
        if use_client_cert:
            self.__con.sasl_interactive_bind_s("", ldap.sasl.external())
        else:
            self.__con.simple_bind_s(self.__user, self.__passwd)
        # read the rootDSE once, so a broken connection shows up here
        self.__con.search_s('', ldap.SCOPE_BASE)

        return self.__con

    def whoami(self):
        '''
        Returns identity when binding with sasl

        The part after the first ':' of the whoami_s() answer is used as DN
        of the returned ldapclass object; an answer without ':' raises
        IndexError.
        '''
        dn = self.__con.whoami_s().split(':', 1)[1]
        return self.__ldapclass(dn, self)

    @classmethod
    def cmpdn(cls, dn1, dn2):
        '''
        Compares two DN's

        Both DN's are normalised with dnkey and compared as reversed strings.
        Returns -1, 0 or 1.
        '''
        key1 = cls.dnkey(dn1)[::-1]
        key2 = cls.dnkey(dn2)[::-1]
        # same result as the python 2 builtin cmp(key1, key2)
        return (key1 > key2) - (key1 < key2)

    @classmethod
    def explode_dn(cls, dn, notypes = False):
        '''
        Explodes DN into list
        Returns an empty list when the DN can not be exploded.
        '''
        # best effort on purpose, but do not swallow KeyboardInterrupt / SystemExit
        try:
            return ldap.explode_dn(dn, notypes)
        except Exception:
            return []

    def getattr(self, dn, attrlist = None):
        '''
        Gets attributes of ldap object from Directory
        input:
            dn           : distinguishedName; '' reads the rootDSE
            attrlist     : list of attributes to obtain; if 'None' will get all attributes

        result:
            attribute dictionary of the first search result

        ldap errors are propagated; an empty search result raises IndexError.
        '''
        # base search on the entry itself, with the match-all filter
        res = self.__con.search_s(dn,
                                  ldap.SCOPE_BASE,
                                  filterstr='(objectClass=*)',
                                  attrlist=attrlist,
                                  attrsonly=0)
        return res[0][1]

    def getclassdef(self, classlist):
        '''
        Gets class definition of objectClass

        For every objectClass name in classlist that is not cached yet, the
        'must' and 'may' attributes are stored in self.classdef. Names
        unknown to the schema are skipped.
        Needs self.schema to be loaded; with schema None an AttributeError
        is raised.
        '''
        for cls in classlist:
            if cls in self.classdef:
                continue
            clsdef = self.schema.get_obj(ldap.schema.ObjectClass, cls)
            if clsdef is not None:
                self.classdef[cls] = {'must': clsdef.must, 'may': clsdef.may}

    def getschema(self):
        '''
        Returns schema objects

        Returns the parsed subschema (ldap.schema.SubSchema), or None when
        no subschema subentry was found.
        '''
        ldap_url = ldapurl.LDAPUrl(self.__url)
        subschema_dn = self.__con.search_subschemasubentry_s(ldap_url.dn)
        if subschema_dn is None:
            return None

        if ldap_url.attrs is None:
            # no attributes in the url: ask for all schema attribute types
            schema_attrs = list(ldap.schema.subentry.SCHEMA_CLASS_MAPPING.keys())
        else:
            schema_attrs = ldap_url.attrs
        subschema_entry = self.__con.read_subschemasubentry_s(subschema_dn, attrs=schema_attrs)

        return ldap.schema.SubSchema(subschema_entry)

    @classmethod
    def __dnkey(cls, dn):
        '''
        Returns key used in the object index: dnidx
        (the lowercased DN, exploded and joined again with ',')
        '''
        return ','.join(ldap.explode_dn(dn.lower()))

    @classmethod
    def dnkey(cls, dn):
        '''
        Returns key used in the object index: dnidx
        '''
        return cls.__dnkey(dn)

    def search(self, dn , scope = None, classlist = None, attrlist = None, ldapfilter = None):
        '''
        Searches Directory
        input:
            dn           : distinguishedName
            scope        : ldap search scope; one of: "one", "sub", "base"
                           ('None' is the same as "one")
            classlist    : list of types of objects to search for
            attrlist     : list of attributes to obtain; if 'None' will get all attributes
                           (MIN_ATTRLIST is always added to a given list)
            ldapfilter   : plain ldapfilter for advanced searching, will overrule classlist

        result:
            ldapobject resultset

        Raises ValueError for an unknown scope; ldap errors are propagated.
        '''
        if scope is None or scope == "one":
            ldapscope = ldap.SCOPE_ONELEVEL
        elif scope == "sub":
            ldapscope = ldap.SCOPE_SUBTREE
        elif scope == "base":
            ldapscope = ldap.SCOPE_BASE
        else:
            raise ValueError(scope)

        if ldapfilter is not None:
            filterstr = ldapfilter
        elif not classlist:
            filterstr = '(objectClass=*)'
        else:
            # match objects of any of the given classes
            filterstr = "(|" + "".join(["(objectClass=" + classtype + ")" for classtype in classlist]) + ")"

        if attrlist is not None:
            attrlist = list(set(attrlist).union(self.MIN_ATTRLIST))

        return self.__con.search_s(dn
                                   , ldapscope
                                   , filterstr
                                   , attrlist
                                   , attrsonly=0)

    def printclassdef(self):
        '''
        Prints classdef cache
        '''
        for cls, attr in self.classdef.items():
            print("%20s : %r" % (cls, attr))

    def printdnidx(self):
        '''
        Prints object cache

        One line per key with the id() of every cached object, sorted on the
        reversed, lowercased key.
        '''
        for key, value in sorted(self.dnidx.items(), key = lambda item: item[0].lower()[::-1]):
            print("\'%45s\' : %r" % (key, [id(obj) for obj in value]))

class LdapCacheWrite(LdapCache):
    '''
    LdapCache with the write operations added: modify, add, delete and
    password change.
    '''
    def __init__(self, ldapclass, url = None, user = '', passwd = '', cacert = None, mycert = None, mykey = None, debug = False):
        '''
        __init__ LdapCacheWrite
        Takes the same input as LdapCache.__init__.
        '''
        # debug is handed on as well; it used to be dropped here
        LdapCache.__init__(self, ldapclass, url, user, passwd, cacert, mycert, mykey, debug)
        # kept for backward compatibility only: the methods below ask
        # _getcon() for the connection, so they keep working after connect()
        # has been called again
        self.__con = self._getcon()

    def setattr(self, dn, modlist):
        '''
        Sets / updates LDAP object attribute with value.
        input:
            dn           : distinguishedName of the object to modify
            modlist      : modification list, handed to modify_s unchanged

        ldap errors are propagated.
        '''
        self._getcon().modify_s(dn, modlist)

    def adddn(self, dn, modlist):
        '''
        Adds DN to Directory
        input:
            dn           : distinguishedName of the new object
            modlist      : attribute list, handed to add_s unchanged

        Returns dn; ldap errors are propagated.
        '''
        self._getcon().add_s(dn, modlist)
        return dn

    def deldn(self, dn):
        '''
        Deletes DN from Directory
        Returns dn; ldap errors are propagated.
        '''
        self._getcon().delete_s(dn)
        return dn

    def ChangeLdapPasswd(self, userdn, newpasswd):
        '''
        Changes a users password. Caller needs to have enough rights to do so.
        No old password is sent (None is passed to passwd_s).
        Returns True; ldap errors are propagated.
        '''
        self._getcon().passwd_s(str(userdn), None, newpasswd)
        return True

    def ChangeMyLdapPasswd(self, userdn, oldpasswd, newpasswd):
        '''
        Changes my password
        The old password is sent together with the new one.
        Returns True; ldap errors are propagated.
        '''
        self._getcon().passwd_s(str(userdn), oldpasswd, newpasswd)
        return True

class ADSLdapCache(LdapCache):
    '''
    LdapCache variant for Active Directory.
    '''
    def whoami(self):
        '''
        Returns the identity as reported by LdapCache.whoami.

        NOTE(review): this method used to be documented as "Not implemented
        in Active Directory", but the call is simply delegated to the parent
        class - confirm which of the two is intended.
        '''
        return LdapCache.whoami(self)
    
    


class ADSLdapCacheWrite(ADSLdapCache, LdapCacheWrite):
    '''
    Writable cache for Active Directory; adds password changes through the
    unicodePwd attribute.
    '''
    def ChangePasswd(self, userdn, newpasswd):
        '''
        Changes ADS password of userdn if caller has enough rights.
        Else will raise:
        ldap.INSUFFICIENT_ACCESS.

        The unicodePwd attribute is replaced with the new password.

        Based on:
        http://support.microsoft.com/kb/269190
        '''
        upasswd = self.unicodePwd(newpasswd)
        # plain string attribute name, the same as in ChangeMyPasswd
        modlist = [(ldap.MOD_REPLACE|ldap.MOD_BVALUES, 'unicodePwd', upasswd)]

        self.setattr(userdn, modlist)

    def ChangeMyPasswd(self, userdn, oldpasswd, newpasswd):
        '''
        Changes ADS password of caller. Can raise the following errors:
        ldap.CONSTRAINT_VIOLATION: Bad password, it does not follow the password policy. Check password policy.

        The old unicodePwd value is deleted and the new one is added in one
        modify operation.

        Based on:
        http://support.microsoft.com/kb/269190
        '''
        uoldpasswd = self.unicodePwd(oldpasswd)
        unewpasswd = self.unicodePwd(newpasswd)

        modlist = [
            (ldap.MOD_DELETE|ldap.MOD_BVALUES, 'unicodePwd', uoldpasswd),
            (ldap.MOD_ADD|ldap.MOD_BVALUES, 'unicodePwd', unewpasswd),
        ]

        self.setattr(userdn, modlist)

    @classmethod
    def unicodePwd(cls, passwd):
        '''
        converts password (or any string) to unicodePwd.
        Needed to set password in Active Directory.

        input:
            passwd       : byte string (decoded as iso-8859-1) or unicode string

        result:
            the password between double quotes, encoded as utf-16-le
        '''
        # byte strings are decoded as before; unicode strings are used as
        # they are (they used to raise TypeError)
        if isinstance(passwd, bytes):
            passwd = passwd.decode("iso-8859-1")
        quoted = u'"' + passwd + u'"'
        return quoted.encode("utf-16-le")


def GetRootDSE(url):
    '''
    Gets RootDSE entries.
    input:
        url          : ldapurl of the Directory to query

    result:
        dictionary with a copy of the rootDSE data

    Binds with an empty user and password and always disconnects again,
    also when reading the data fails.
    '''
    # imported at call time, as before
    from ldapdict import GenericLdapDict

    root = {}
    lc = LdapCache(GenericLdapDict, url , "", "")
    try:
        root.update(lc.root.data)
    finally:
        # do not leave the connection open when the read above raises
        lc.disconnect()

    return root

def runtests():
    '''
    Runs the tests of the ldapdict module
    '''
    # the import is done at call time, not at module load time
    import ldapdict as testmodule
    testmodule.runtests()


def _test():
    '''
    Execute doctest
    '''
    import doctest
    doctest.testmod(verbose = True)

if __name__ == "__main__":
    # run as a script: execute the test suite (the doctests of _test() are not run)
    runtests()

