#!/usr/bin/env python
"""
Investigate user activities on OSG sites, for security auditing and
incident investigation.
"""

__rcsid__ = "$Id$"
__author__ = "Dan Gunter (dkgunter (at) lbl.gov)"

## Imports

import getpass
import socket
import sys
import time
#
from netlogger import nldate
from netlogger import util
from netlogger.analysis import loader
from netlogger.nllog import DoesLogging, get_logger, OptionParser

## Exceptions

class OptionError(Exception):
    """Raised when command-line options are invalid."""

## Global variables

# Standard indentation prefix used for all table output (three spaces)
INDENT = "   "


## Signal handlers

def killHandler(signo, frame):
    """Signal handler: log which signal arrived, then exit cleanly."""
    log = get_logger(__file__)
    log.warn("killed", msg="killed by signal", signo=signo)
    sys.exit(0)

## Classes

class Host(object):
    """Hold both IP and hostname, look up one from the other lazily if
    needed. Use this to make the process of going back and forth
    between the two for display and processing as painless as possible.

    NOTE(review): changed to a new-style class (object base); the
    original was an old-style class, on which Python 2's @property
    support is incomplete.
    """
    def __init__(self, ip=None, hostname=None):
        """Create from an IP address and/or a hostname.

        Raises:
            ValueError - if neither `ip` nor `hostname` is given.
        """
        if ip is None and hostname is None:
            raise ValueError("ip or hostname must be given")
        self._ip, self._hostname = ip, hostname

    @property
    def hostname(self):
        """Hostname, resolved from the IP on first access if needed."""
        if self._hostname is None:
            self._forwardLookup()
        return self._hostname

    @property
    def ip(self):
        """IP address, resolved from the hostname on first access if
        needed. Resolution may raise socket.gaierror on failure.
        """
        if self._ip is None:
            self._reverseLookup()
        return self._ip

    def _forwardLookup(self):
        # getfqdn() returns its argument unchanged on lookup failure,
        # so this never raises
        self._hostname = socket.getfqdn(self._ip)

    def _reverseLookup(self):
        self._ip = socket.gethostbyname(self._hostname)

    def __repr__(self):
        return "Host(ip=%r, hostname=%r)" % (self._ip, self._hostname)

class Query(DoesLogging):
    """Base class for database queries against the netlogger event store.

    Subclasses implement query functions taking a DB cursor, and pass
    them to doQueries(); results accumulate in the self.results dict.
    """
    RESULT_BATCH = 5000 # number of results per fetch

    def __init__(self, conn, begin=None, end=None, dn=None,
                 hostnames=False):
        """Create a query over [begin, end) for an optional user DN.

        Args:
            conn - Open DB-API connection object.
            begin, end - Numeric timestamps bounding the query (required).
            dn - User distinguished name, or None.
            hostnames - If True, display hostnames instead of IPs.

        Raises:
            ValueError - if begin or end is missing.
        """
        DoesLogging.__init__(self)
        if begin is None or end is None:
            raise ValueError("Bad input to Query constructor")
        self.conn = conn
        self._param = dict(begin=begin, end=end, dn=dn)
        self.localtime = True
        self.tz = ""
        self.results = { }
        self._show_hostnames = hostnames

    def doQueries(self, *query_functions):
        """Run each query function in order with a shared cursor.

        Returns True on success; on failure stores the exception in
        self.error and returns False. cleanupQueries() always runs.
        """
        self.error = "OK"
        self.log.info("queries.start")
        cursor = self.conn.cursor()
        cursor.arraysize = self.RESULT_BATCH
        self.results, status = { }, True
        self.initQueries(cursor)
        try:
            for q in query_functions:
                q(cursor)
        except Exception as err:
            # _q may not be set yet if the failure happened before the
            # first statement was built
            self.log.error("query.failed", value=getattr(self, "_q", None))
            self.error = err
            status = False
        self.cleanupQueries(cursor)
        cursor.close()  # release cursor resources explicitly
        # log 0 for success, 1 for failure
        self.log.info("queries.end", status=0 if status else 1)
        return status

    def initQueries(self, cursor):
        """Hook run before the query functions; default is a no-op."""
        pass

    def cleanupQueries(self, cursor):
        """Hook run after the query functions (even on error); no-op."""
        pass

    def formatTime(self, t, strip=True):
        """Format time as a string.

        Side effect: sets self.tz from the first formatted date.
        """
        if self.localtime:
            date = nldate.localtimeFormatISO(t)
            if not self.tz:
                self.tz = date[-6:]
        else:
            date = nldate.utcFormatISO(t)
            if not self.tz:
                self.tz = 'UTC'
        if strip: # strip subsec precision and tz from date
            date = date[:19]
        return date

    def noResults(self):
        """Return whether all result sets are empty
        (True) or not (False).
        """
        if not self.results:
            return True
        for value in self.results.values():
            if value:
                return False
        return True


class UserQuery(Query):
    """Gather all recorded activity for one user DN -- batch jobs,
    GridFTP transfers, SRM transfers, and network connections -- and
    display the results as text tables.
    """

    # Standard field positions in result rows
    TIME1_FIELD = 0
    TIME2_FIELD = 1
    HOST_FIELD = 2

    # Database event names
    BRO_EVENT = "conn"

    def initQueries(self, cursor):
        """Reset names of scratch tables created by _gridftp(), so
        cleanupQueries() knows whether anything needs dropping."""
        self._gridftp_table = None
        self._gridftp_pid_table = None

    def run(self):
        """Run all the sub-queries; return True on success."""
        return self.doQueries(self._jobs, self._gridftp, self._conn)

    def cleanupQueries(self, cursor):
        """Drop the non-temporary scratch tables made by _gridftp()."""
        if self._gridftp_pid_table:
            cursor.execute("drop table if exists %s" % self._gridftp_pid_table)
        if self._gridftp_table:
            cursor.execute("drop table if exists %s" % self._gridftp_table)

    def displayText(self, ofile=sys.stdout):
        """Write all result tables as formatted text to `ofile`."""
        if self.noResults():
            ofile.write("Result set is empty\n")
            return
        ofile.write("\n* Jobs:\n")
        self.job_table(ofile)
        ofile.write("\n* GridFTP Transfers:\n")
        self.gridftp_table(ofile)
        ofile.write("\n* SRM transfers:\n")
        self.srm_table(ofile)
        ofile.write("\n* Connections:\n")
        self.conn_table(ofile)

    def job_table(self, ofile):
        """Print table of job results"""
        r = self.results.get('jobs', None)
        if not r:
            ofile.write("%sNo jobs\n" % INDENT)
            return
        titles = ("Start time (%s)" % self.tz, "End time", "Host",
                  "User (UID,GID)",
                  "Scheduler", "Executable")
        widths = (20, 20, 20, 15, 10, 20)
        ofile.write(_header(widths, titles))
        fmt = _format(widths)
        for row in r:
            ofile.write(INDENT)
            if self._show_hostnames:
                host = row[self.HOST_FIELD].hostname
            else:
                host = row[self.HOST_FIELD].ip
            ofile.write(fmt % (row[self.TIME1_FIELD],
                               row[self.TIME2_FIELD],
                               host,
                               row[3] + ' (' + row[4] + ',' + row[5] + ')',
                               row[6], "unknown"))

    def gridftp_table(self, ofile):
        """Print table of gridftp results"""
        r = self.results.get('gridftp', [ ])
        if not r:
            ofile.write("%sNone\n"  % INDENT)
        else:
            titles = ("Start time (%s)" % self.tz,
                      "End time", "Host", "Port", "User", "Filename")
            widths = (20, 20, 32, 6, 10, 10)
            ofile.write(_header(widths, titles))
            fmt = _format(widths)
            for row in r:
                ofile.write(INDENT)
                if self._show_hostnames:
                    host = row[self.HOST_FIELD].hostname
                else:
                    host = row[self.HOST_FIELD].ip
                ofile.write(fmt % (row[self.TIME1_FIELD],
                                   row[self.TIME2_FIELD],
                                   host,
                                   row[4], row[5], row[6]))

    def conn_table(self, ofile):
        """Print table of connections"""
        r = self.results.get('conn', None)
        if not r:
            ofile.write("%sNone\n" % INDENT)
        else:
            # hostnames need wider columns than dotted-quad IPs
            if self._show_hostnames:
                src_title = "Source host"
                dst_title = "Dest host"
                h_width = 30
            else:
                src_title = "Source IP"
                dst_title = "Dest IP"
                h_width = 16
            titles = ("First time (%s)" % self.tz,
                      "Last time", src_title, dst_title, "Port", "Count")
            widths = (20, 20, h_width, h_width, 8, 8)
            ofile.write(_header(widths, titles))
            fmt = _format(widths)
            for row in r:
                ofile.write(INDENT)
                if self._show_hostnames:
                    src, dst = row[2].hostname, row[3].hostname
                else:
                    src, dst = row[2].ip, row[3].ip
                ofile.write(fmt % (row[self.TIME1_FIELD],
                                   row[self.TIME2_FIELD],
                                   src, dst, row[4], row[5]))

    def srm_table(self, ofile):
        """Print table of SRM results"""
        r = self.results.get('srm', None)
        if not r:
            ofile.write("%sNone\n" % INDENT)
        else:
            ofile.write("%sTBD\n\n" % INDENT)

    def _conn(self, cursor):
        """Get connection info.
        Precondition: This comes AFTER all sections with hosts
        to contribute, such as _jobs(), _srm() and _gridftp()
        Return: First-time, last-time, source-ip, dest-ip, #conn
        """

        # Make a table of distinct hosts used in other sections
        host_tbl, tbl_name = { }, "host_tmp"
        self._q = "CREATE temporary table %s (host varchar(32))" % tbl_name
        self.log.debug("conn.query", value=self._q)
        cursor.execute(self._q)

        # Get unique hosts (dict keys used as a set of IP strings)
        for section in 'gridftp', 'jobs', 'srm':
            for rec in self.results.get(section, ()):
                host = rec[self.HOST_FIELD].ip
                host_tbl[host] = 1
        if not host_tbl:
            return () # no hosts

        # build insert statement, and execute it
        insert_stmt = "INSERT INTO %s values " % tbl_name
        insert_stmt += ', '.join(["('" + s + "')" for s in host_tbl.keys()])
        self._q = insert_stmt
        self.log.debug("conn.query", value=self._q)
        cursor.execute(self._q)

        # Join hosts with Bro logs to get connections
        kw = self._param.copy()
        r = [ ]
        # Do a union of the select with the host in the source or dest
        for direction in ('sip','dip'), ('dip','sip'):
            kw['ip1'] = direction[0]
            kw['ip2'] = direction[1]
            self._q = """select min(e.time) as first, max(e.time) as last,
                     sip.value sip, dip.value dip, dp.value dp, count(*) as total
                     from host_tmp h
                     join attr %(ip1)s on h.host = %(ip1)s.value
                     join event e on e.id = %(ip1)s.e_id
                     join attr %(ip2)s on e.id = %(ip2)s.e_id
                     join attr dp on e.id = dp.e_id
                     where
                     sip.name = 'sip' and dip.name = 'dip' and dp.name = 'dp' and
                     (e.name = 'conn' and e.startend  = 0)  and
                     e.time >= %(begin)lf and e.time < %(end)lf
                     group by sip, dip, dp
                     order by min(e.time)
                     limit 10""" % kw
            self.log.debug("conn.query", value=self._q)
            cursor.execute(self._q)
            rows = cursor.fetchall()
            for row in rows:
                result = list(row)
                result[0] = self.formatTime(row[0])
                result[1] = self.formatTime(row[1])
                result[2] = Host(ip=result[2])
                result[3] = Host(ip=result[3])
                r.append(result)

        self.results['conn'] = r

    def _jobs(self, cursor):
        """Get job info"""
        # NOTE(review): dn/begin/end are interpolated directly into the
        # SQL string; dn comes from the command line, so parameterized
        # queries would be safer here.
        # Make a temporary table with info from
        # the gatekeeper
        self._q = """create temporary table gk_tmp
         select e.time, -1 end_time, h.value host, j.value JM_ID,
         p.value PID
        from event e
        join attr h on e.id = h.e_id
        join dn d on e.id = d.e_id
        join ident j on e.id = j.e_id
        join ident p on e.id = p.e_id
        where
        e.name = 'globus.gk.info' and
        h.name = 'host' and
        j.name = 'jm' and
        p.name = 'process' and
        e.time >= %(begin)lf and e.time < %(end)lf and
        d.value = '%(dn)s'
        """ % self._param
        self.log.debug("jobs.query", value=self._q)
        cursor.execute(self._q)
        # Make a temporary table with end times
        self._q = """create temporary table gkend_tmp
                     select gk_tmp.JM_ID, e.time
                     from event e
                     join ident p on e.id = p.e_id
                     join gk_tmp on p.value = gk_tmp.PID
                     where
                     p.name ='process' and
                     e.name = 'globus.gk' and e.startend = 1"""
        self.log.debug("jobs.query", value=self._q)
        cursor.execute(self._q)
        # Join with info from the globus accounting logs
        self._q = """select gk_tmp.time start_time, gkend_tmp.time end_time,
                            gk_tmp.host host, u.value,
                            uu.value, ug.value, st.value
                     from event e
                     join ident jmid on e.id = jmid.e_id
                     join gk_tmp on jmid.value = gk_tmp.JM_ID
                     join gkend_tmp on jmid.value = gkend_tmp.JM_ID
                     join attr u on u.e_id = e.id
                     join ident uu on uu.e_id = e.id
                     join ident ug on ug.e_id = e.id
                     join attr st on st.e_id = e.id
                     where
                     e.name = 'globus.acct.job' and
                     jmid.name = 'jm' and
                     u.name = 'user' and
                     uu.name = 'user' and
                     ug.name = 'group' and
                     st.name = 'sched.type'"""
        self.log.debug("jobs.query", value=self._q)
        cursor.execute(self._q)
        r = [ ]
        while 1:
            rows = cursor.fetchmany()
            if not rows:
                break
            for row in rows:
                result = list(row)
                result[0] = self.formatTime(row[0])
                result[1] = self.formatTime(row[1])
                result[self.HOST_FIELD] = Host(ip=row[2])
                r.append(result)
        self.results['jobs'] = r

    def _gridftp(self, cursor):
        """Create a table of gridftp PID's.
        Then join this with selected gridftp_auth events
        Return: Start time, End_time, Host, PID, Port, User, Filename
        """
        # Create table (cannot be temporary because of re-use)
        # Should drop it before exit; time-based suffix avoids collisions
        self._gridftp_pid_table = "gridftp_pid_%d" % int(time.time() * 1000)
        s = ("select p.value as pid from event e"
             " join dn on e.id = dn.e_id join"
             " attr p using(e_id)"
             " where p.name = 'PID' and dn.value = '%(dn)s'"
             " and e.time >= %(begin)lf and e.time < %(end)lf" % self._param)
        cursor.execute("create table %s (pid char(8)) %s" % (
                self._gridftp_pid_table, s))
        self._param['tmp_table'] = self._gridftp_pid_table
        # Select each value
        def attr_select(attr, event, startend, kw):
            """Build a select statement for an attribute.
            startend 0/1 restricts to start/end events; 2 means either."""
            kw['attr'], kw['event'] = attr, event
            if startend < 2:
                kw['startend'] = " and event.startend = %d" % startend
            else:
                kw['startend'] = ""
            q = ("(select x.value as %(attr)s, y.value pid"
                 " from attr y join attr x using (e_id)"
                 " join event on event.id = x.e_id"
                 " where x.name = '%(attr)s' and y.name ="
                 " 'PID' and event.name = '%(event)s'%(startend)s and"
                 " y.value in (select pid from %(tmp_table)s))" % kw)
            return q
        # Get start/end times
        s = ("(select p1.value pid, e2.time time from event e1 "
             " join dn on e1.id = dn.e_id"
             " join attr p1 on p1.e_id = dn.e_id "
             " join attr p2 on p2.value = p1.value "
             " join event e2 on e2.id = p2.e_id "
             " where dn.value = '%(dn)s' and"
             " (e1.time >= %(begin)lf and e1.time < %(end)lf) and"
             " e1.name = 'gridftp_auth.conn.auth.dn' and "
             " p1.name = 'PID' and p2.name = 'PID' and"
             " e2.name = 'gridftp_auth' and e2.startend = %%d)"
             " as e%%d" % self._param)
        start_time, end_time = s % (0, 1), s % (1, 2)
        # Build select statement
        select_list = ['e1.time start_time', 'e2.time end_time', 'a.pid']
        from_list = [ start_time, end_time ]
        pfx = 'gridftp_auth.'
        for alias, attr, event, startend in (
            ('a','host', pfx + "conn", 0),
            ('b','port', pfx + "conn", 0),
            ('c', 'user', pfx + "conn.auth.user", 2),
            ('d', 'filename', pfx + "conn.transfer", 0)):
            select_list.append("%s.%s" % (alias, attr))
            from_list.append(attr_select(
                    attr, event, startend, self._param) + " as %s" % alias)
        select_str = ", ".join(select_list)
        join_str = from_list[0]
        for clause in from_list[1:]:
            join_str += " join " + clause + " using(pid)"
        self._gridftp_table = "gridftp_%d" %  int(time.time() * 1000)
        self._q =  "create table %s select %s from %s" % (
            self._gridftp_table, select_str, join_str)
        # Get results (also save table)
        cursor.execute(self._q)
        cursor.execute("select * from %s" % self._gridftp_table)
        r = [ ]
        while 1:
            rows = cursor.fetchmany()
            if not rows:
                break
            for row in rows:
                result = list(row)
                result[0] = self.formatTime(result[0])
                result[1] = self.formatTime(result[1])
                # swap host and PID field positions
                result[self.HOST_FIELD], result[3] = \
                    Host(hostname=result[3]), result[self.HOST_FIELD]
                r.append(result)
        self.results['gridftp'] = r

    def old_table(self, ofile):
        """Legacy display of the accounting rows stored by _acct_select()
        under self.results['acct']. Not called from displayText().
        """
        titles = ("Time (%s)" % self.tz, "User", "Group",
            "GRAM ID", "JobManager ID", "Scheduler ID")
        widths = (20, 6, 6, 16, 24, 8)
        ofile.write(_header(widths, titles))
        fmt = _format(widths)
        # NOTE(review): original iterated self.results (a dict) directly;
        # the rows live in the 'acct' list. Field indices kept as in the
        # original (row[1] skipped) -- presumably from an older schema;
        # verify against _acct_select's column order.
        for row in self.results.get('acct', []):
            ofile.write(fmt % (row[0], row[2], row[3],
                               row[4], row[5], row[6]))

    def _accessTmp(self, cursor):
        """Build temporary table access_tmp of (event id, time) pairs for
        the user's accounting events in the time range.

        Legacy companion to _acct_select(); not called from run().
        """
        self._q = """create temporary table access_tmp (id integer(12),
                                   time double)
        select e.id, e.time
        from event e
          join dn on e.id = dn.e_id
        where dn.value = '%(dn)s' and
              e.name = 'globus.acct.job' and
              e.time >= %(begin)lf and
              e.time < %(end)lf""" % self._param
        self.log.debug("access.query", value=self._q)
        cursor.execute(self._q)

    def _acct_select(self, cursor):
        """Select job-accounting rows joined against access_tmp (built by
        _accessTmp()), format the timestamp, and store the rows under
        self.results['acct'].

        Legacy companion to old_table(); not called from run().
        NOTE(review): fixed two SQL bugs from the original -- a missing
        comma after 'sched_id' and a join on s.e_id instead of t.e_id --
        and the append onto self.results, which is a dict.
        """
        self._q = """select a.time as time,
          u.value as user,
          g.value as user_group,
          r.value as gram_id,
          j.value as jm_id,
          s.value as sched_id,
          t.value as sched_type
        from access_tmp a
          join ident u on a.id = u.e_id
          join ident g on a.id = g.e_id
          join ident r on a.id = r.e_id
          join ident j on a.id = j.e_id
          join ident s on a.id = s.e_id
          join attr t on  a.id = t.e_id
        where u.name = 'user'
          and g.name = 'group'
          and r.name = 'gram'
          and j.name = 'jm'
          and s.name = 'sched'
          and t.name = 'sched_type'"""
        self.log.debug("acct.query", value=self._q)
        cursor.execute(self._q)
        acct = self.results.setdefault('acct', [ ])
        while 1:
            rows = cursor.fetchmany()
            if not rows:
                break
            for row in rows:
                result = list(row)
                # format time as a string, stripping subsec precision and tz
                result[0] = self.formatTime(row[0])
                # add to results
                acct.append(result)

class ListQuery(Query):
    """List the distinct user DNs active in the time range.

    NOTE(review): the original also listed DoesLogging as a base class,
    but Query already derives from it, so the extra base was redundant
    and has been removed (isinstance checks are unaffected).
    """

    def run(self):
        """Run the user-listing query; return True on success."""
        return self.doQueries(self._listUsers)

    def _listUsers(self, cursor):
        """Fetch (first-time, last-time, DN) per distinct DN and store
        the formatted rows under self.results[''].
        """
        self._q = """select min(time) as first, max(time) as last, dn.value
        from event e join dn on e.id = dn.e_id
        where e.time >= %(begin)lf and e.time < %(end)lf
        group by dn.value order by time""" % self._param
        self.log.debug("list.query", value=self._q)
        cursor.execute(self._q)
        self.results[''] = [ ]
        while 1:
            rows = cursor.fetchmany()
            if not rows: break
            for row in rows:
                result = list(row)
                # convert epoch timestamps to display strings
                result[0] = self.formatTime(row[0])
                result[1] = self.formatTime(row[1])
                self.results[''].append(result)

    def displayText(self, ofile=sys.stdout):
        """Write the (first, last, DN) rows as a text table to `ofile`."""
        if self.noResults():
            ofile.write("Result set is empty\n")
            return
        titles = ("First time (%s)" % self.tz, "Last time", "DN")
        widths = (20, 20, 16)
        ofile.write(_header(widths, titles))
        fmt = _format(widths)
        for row in self.results['']:
            ofile.write(fmt % (row[0], row[1], row[2]))

## Functions

def _header(widths, titles, indent=INDENT):
    """Return a two-line formatted table header: the titles row followed
    by a row of dashes underlining each title, both indented."""
    fmt = _format(widths)
    dashes = tuple('-' * len(title) for title in titles)
    return indent + (fmt % tuple(titles)) + indent + (fmt % dashes)

def _format(widths):
    return ' '.join(["%%-%ds" % n for n in widths]) + '\n'

def _formatDate(ts, localtime=True):
    """Render timestamp `ts` as an ISO date string, in local time or UTC,
    removing the sub-second part when it is exactly zero."""
    text = nldate.localtimeFormatISO(ts) if localtime else nldate.utcFormatISO(ts)
    # if sub-seconds are zero, splice them out
    cut = text.find(".000000")
    if cut > 0:
        text = text[:cut] + text[cut + 7:]
    return text

def main():
    """Parse options, connect to the database, run the requested query
    (user activity or user listing), and print the results.

    Returns the process exit status (0 OK, 1 connect failure, -1 query
    error); parser.error() exits directly on bad arguments.
    """
    usage = "%prog [options] UserDN"
    desc = ' '.join(__doc__.split())
    parser = OptionParser(usage=usage, description=desc)
    avail_url_str = ", ".join(getAvailURL())
    parser.add_option('-d', '--db', default=None,
                action='store', dest='db', metavar='DBNAME',
                help="Database to use (required, except for sqlite)")
    parser.add_option('-l', '--list', action='store_true', default=False,
                      dest="list_users",
                      help="List all distinct users in time range")
    parser.add_option('-n', '--numeric', action="store_true", default=False,
                      dest="numeric", help="Show numeric IP addresses. "
                      "No attempt will be made to look up hostnames, "
                      "and existing hostnames will be reverse-mapped "
                      "to IP addresses")
    parser.add_option('-p', '--password', action='store_true', dest="password",
                      help="Prompt for database password")
    parser.add_option('-t', '--timerange', default=None,
                      action='store', dest='timerange', metavar="TIME-RANGE",
                      help="Time range for query, as START::END. "
                      "Format for START, END is ISO8601, numeric or English "
                      "like 'yesterday' or '2 weeks ago'. "
                      "See nl_date for details (required)")
    parser.add_option('-u', '--url', default=None,
                 action='store', dest='url', metavar='URL',
                 # NOTE(review): original help text was missing the
                 # closing paren after "required"
                 help="Connect to database server at URL. "
                 "One of: %s (required)" %  avail_url_str)
    options, args = parser.parse_args()
    log = get_logger(__file__)  # get the logger just after parsing args
    if not options.list_users:
        if len(args) != 1:
            parser.error("UserDN is required")
    missing = checkRequiredOptions(options)
    if missing:
        arg_list = ", ".join(missing)
        parser.error("Missing required arguments: %s" % arg_list)
    # Parse timerange (before trying to connect)
    try:
        start_date, end_date = options.timerange.split('::', 1)
    except ValueError:
        parser.error("Time range must be in format START::END, e.g. "
                     "2009-01::2009-03")
    try:
        fmt, start_ts = nldate.guess(start_date)
    except ValueError as E:
        parser.error("Bad format '%s' for START date: %s" % (start_date, E))
    if fmt == nldate.UNKNOWN:
        parser.error("Unknown format for START date '%s'" % start_date)
    try:
        fmt, end_ts = nldate.guess(end_date)
    except ValueError as E:
        parser.error("Bad format '%s' for END date: %s" % (end_date, E))
    if fmt == nldate.UNKNOWN:
        parser.error("Unknown format for END date '%s'" % end_date)
    # Set up signal handlers for a graceful exit
    util.handleSignals((killHandler, ('SIGTERM', 'SIGINT', 'SIGUSR1',
                                     'SIGUSR2', 'SIGHUP')))
    # Connect to database
    log.debug("connect.start", url=options.url)
    try:
        dbmod = loader.getDbForURL(options.url)
    except ValueError:
        log.debug("connect.end", status=-1, url=options.url)
        parser.error("Bad URL format for '%s'" % options.url)
    if dbmod is None:
        log.debug("connect.end", status=-1, url=options.url)
        parser.error("Cannot load DB module for URL '%s'" % options.url)
    dsn, conn_kw = loader.extractConnKeywords(options.url)
    if options.password:
        passwd = getpass.getpass()
        conn_kw['passwd'] = passwd
    try:
        log.info("connect.db.start", db=options.db, url=options.url)
        conn = loader.connect(dbmod, conn_kw=conn_kw,
                              dsn=dsn, dbname=options.db)
        log.info("connect.db.end", status=0)
    except util.DBConnectError as err:
        log.info("connect.db.end", status=-1, msg=err)
        return 1 # exit
    # Create and run query
    log.info("run.start")
    status = 0
    ofile = sys.stdout
    kw = dict(begin=start_ts, end=end_ts, hostnames=not options.numeric)
    bd, ed = _formatDate(start_ts), _formatDate(end_ts)
    if options.list_users:
        ofile.write("\nUsers between %s and %s\n" % (bd, ed))
        query = ListQuery(conn, **kw)
        success = query.run()
        if success:
            query.displayText(ofile=ofile)
    else:
        dn = kw['dn'] = args[0]
        ofile.write("\nActivity for '%s'\nbetween %s and %s\n" % (
                    dn, bd, ed))
        query = UserQuery(conn, **kw)
        success = query.run()
        if success:
            query.displayText(ofile=ofile)
    if not success:
        log.error("query.error", msg=query.error)
        status = -1
    log.info("run.end", status=status)
    return status

def checkRequiredOptions(options):
    """Return a list of missing required options.

    The list will be empty if everything is OK.

    Args:
        options - Parsed option object with `url`, `db`, and `timerange`
                  attributes (None when not given on the command line).

    Returns:
        List of human-readable strings, one per missing option.
    """
    result = [ ] # OK
    if options.url is None:
        result.append("-u URL")
    # sqlite embeds its database path in the URL, so -d is optional there
    if options.db is None and (options.url is None or not
                               options.url.strip().startswith("sqlite://")):
        result.append("-d DBNAME")
    if options.timerange is None:
        # NOTE(review): use the START::END separator that main() actually
        # parses (the original message said START:END)
        result.append("-t START::END")
    return result

def getAvailURL():
    """Get a list of the URL patterns for available database modules.

    Returns a list of zero or more items; the internal 'test' module is
    skipped, and sqlite gets a file-path pattern instead of host:port.
    """
    patterns = [ ]
    for name in loader.AVAIL_DB:
        if name == 'test':
            continue
        if name == 'sqlite':
            patterns.append("%s:///path/to/file" % name)
        else:
            patterns.append("%s://[user@]host[:port]" % name)
    return patterns

# Script entry point: propagate main()'s return value as the exit status
if __name__ == '__main__':
    sys.exit(main())
