#!/usr/bin/env python
# Module metadata.  __version__ and __date__ are pulled out of the expanded
# revision-control keyword strings by taking the second whitespace-separated
# token (the revision number and the yyyy-mm-dd date respectively).
__author__    = 'Kurt Schwehr'
__version__   = '$Revision: 11972 $'.split()[1]
__date__      = '$Date: 2009-05-26 14:11:37 -0400 (Tue, 26 May 2009) $'.split()[1]
__revision__  = __version__
# NOTE(review): this repeats the __date__ assignment above with the same value.
__date__      = '$Date: 2009-05-26 14:11:37 -0400 (Tue, 26 May 2009) $'.split()[1]
__copyright__ = '2008'
__license__   = 'GPL v3'
__doc__='''
Listen for connections.  When it receives data, shove the data into a postgis server.

Expects normalized AIS messages.

nohup ./ais-net-to-postgis -a 10.0.3.5 -a 208.77.188.166 -a 127.0.0.1 -I 0.0.0.0 -s "3 hours ago" -S "4 hours ago" -c 600  &

@var __date__: Date of last svn commit
@status: under development
@undocumented: __doc__ parser

@todo: make this work with any database
@todo: detect when one of the child threads dies from an exception/traceback

@todo: allow the db to get synced on startup, so we do not have to wait for the first vessel traffic.
@todo: perhaps also periodically rebuild the cache if nothing going on if there are cache entries
so that they can expire.
'''

import sys

import time
import datetime

import socket
import thread
import select

import traceback
import exceptions

#import psycopg2
import magicdate

import ais
import ais.sqlhelp
import aisutils.database
import aisutils.uscg

######################################################################

def dateStr():
    '''
    Build a string for the current UTC day that sorts chronologically.

    @return: the current UTC date formatted as yyyy-mm-dd
    @rtype: str
    '''
    year, month, day = time.gmtime()[:3]
    return '%04d-%02d-%02d' % (year, month, day)

######################################################################

class DatabaseHandler:
    '''
    Queue handling for the database.

    Producers put C{(sqlStr, vessel)} tuples on C{self.q}.  L{handler} is the
    thread body: it sleeps, commits whatever is queued (L{commit}) and every
    C{cleanTime} seconds re-runs the track / last-position rebuilds so that
    old entries can age out (L{clean}).
    '''
    def __init__(self, connection, dbUpdateInterval=5., threshold=1, verbose=False, skipDB=False,
                 limitPoints=10
                 , track_start_time_limit=None #'1 hour ago'
                 , last_position_time_limit=None #'6 hours ago'
                 , cleanTime=30
                 ):
        '''
        @param connection: database connection (ignored when skipDB is set)
        @param dbUpdateInterval: how many seconds between database updates
        @param threshold: How many messages in the queue before they go to the db?
        @param verbose: write progress messages to stderr
        @param skipDB: Do not actually talk to the database.  For debugging.
        @param limitPoints: max number of points in a ship track.  None for no limit.
               Stored on the instance; nothing in this class passes it on yet.
        @param track_start_time_limit: a string limiting the length of ship tracks wrt time.  For example: "6 hours ago".  None for no limit
        @param last_position_time_limit: a string limiting the age of last positions wrt time.  For example: "6 hours ago".  None for no limit
        @param cleanTime: seconds between running the database cleanup for the track lines
        '''
        if not skipDB:
            self.cx = connection
            self.cu = connection.cursor()
        else:
            self.cx = None
            self.cu = None
        try:
            import Queue
        except ImportError:
            # Same module under its Python 3 name
            import queue as Queue
        self.q = Queue.Queue()
        self.running = True
        self.dbUpdateInterval = dbUpdateInterval
        self.threshold = threshold
        self.v = verbose
        self.stopped = False  # Set to true when this class is all done and committed
        self.skipDB = skipDB

        self.limitPoints = limitPoints
        self.track_start_time_limit = track_start_time_limit
        self.last_position_time_limit = last_position_time_limit
        self.cleanTime = cleanTime
        if self.v:
            sys.stderr.write('Database handler init\n')
            sys.stderr.write('  track_start: %s\n' % str(self.track_start_time_limit))
            sys.stderr.write('  last_position: %s\n' % str(self.last_position_time_limit))

    def stop(self):
        '''
        Ask the handler loop to finish.  The loop does one last commit and
        then sets C{self.stopped}.
        '''
        self.running = False
        if self.v:
            sys.stderr.write('Database handler stop scheduled\n')

    def _logException(self, header, e):
        '''
        Write an exception report (header, type, args, traceback) to stderr.
        Must be called from inside the except block so the traceback is available.
        '''
        sys.stderr.write(header)
        sys.stderr.write('   Exception:' + str(type(e))+'\n')
        sys.stderr.write('   Exception args:'+ str(e)+'\n')
        traceback.print_exc(file=sys.stderr)

    def _localToUtcOffset(self):
        '''
        Offset that turns a naive local time into UTC right now.

        Evaluated on every call so that a daylight savings switch while the
        server is running is picked up.  time.timezone alone is the non-DST
        offset, so time.altzone has to be used while DST is in effect.

        @rtype: datetime.timedelta
        '''
        if time.daylight and time.localtime().tm_isdst > 0:
            seconds = time.altzone
        else:
            seconds = time.timezone
        return datetime.timedelta(seconds=seconds)

    def _startTime(self, limit):
        '''
        Convert a magicdate limit string into a UTC start time.

        @param limit: magicdate string such as "6 hours ago", or None/empty for no limit
        @return: the limit shifted to UTC, or None when there is no limit
        '''
        if not limit:
            return None
        # Always work in utc.  Magicdate is local
        return magicdate.magicdate(limit) + self._localToUtcOffset()

    def commit(self):
        '''
        Execute and commit the queued SQL, then rebuild the tracks and last
        positions of the vessels that had new position reports.

        Does nothing until at least C{threshold} entries are queued.  Only the
        entries present when the call starts are pulled; anything that arrives
        meanwhile waits for the next call.

        @todo: Build the lines/shiptracks from the position reports for the ships with new positions.
        @todo: Where do I decimate reports received from multiple receivers?
        @todo: Need to track the ships in the put.
        '''
        q = self.q
        cx = self.cx
        cu = self.cu
        skipDB = self.skipDB

        size = q.qsize()
        if size < self.threshold:
            return

        vesselsSeen = set()

        # Don't try to flush incoming messages
        if self.v:
            sys.stderr.write('pulling from queue '+str(size)+'\n')
        for i in range(size):
            sqlStr, vessel = q.get()
            if vessel is not None:
                vesselsSeen.add(vessel)
            if self.v:
                sys.stderr.write('exec in commit: '+str(i)+' '+sqlStr+'\n')
            if skipDB:
                continue
            try:
                cu.execute(sqlStr)
            except Exception as e:
                # NOTE(review): some databases refuse further statements in a
                # transaction after an error; no rollback is attempted here.
                self._logException('*** exception on sql execute.\n   '+sqlStr+'\n', e)

        if self.v:
            sys.stderr.write('committing...\n')
        if skipDB:
            sys.stderr.write(' skipping commitment\n')
            return

        cx.commit()
        if self.v:
            sys.stderr.write('recalculate ship tracks vessels seen\n')
            sys.stderr.write('  '+str(vesselsSeen)+'\n')
        if len(vesselsSeen) == 0:
            return

        # NOTE(review): when a limit is None the startTime argument is left
        # out, assuming the aisutils.database functions then apply no time
        # limit -- confirm against aisutils.database.
        trackArgs = {'vessels': vesselsSeen, 'verbose': self.v}
        track_start_time = self._startTime(self.track_start_time_limit)
        if track_start_time is not None:
            trackArgs['startTime'] = track_start_time
        aisutils.database.rebuild_track_lines(cx, **trackArgs)

        positionArgs = {'vesselsClassA': vesselsSeen, 'verbose': self.v}
        last_pos_start_time = self._startTime(self.last_position_time_limit)
        if last_pos_start_time is not None:
            positionArgs['startTime'] = last_pos_start_time
        aisutils.database.rebuild_last_position(cx, **positionArgs)

    def clean(self):
        '''
        Run through the vessel list so that aging works.

        Re-runs the rebuilds for every vessel using the current time limits.
        Does nothing when no limit is configured or when running without a
        database.
        '''
        if self.track_start_time_limit is None and self.last_position_time_limit is None:
            # no time expirations, so don't clean
            return
        if self.skipDB:
            # there is no connection to clean
            return

        if self.v:
            sys.stderr.write('\n\nCLEANING...\n\n')
            sys.stderr.write('  utcnow ... %s\n' % str(datetime.datetime.utcnow()))

        if self.track_start_time_limit:
            startTime = self._startTime(self.track_start_time_limit)
            if self.v:
                sys.stderr.write('  clean track_start %s (%s)\n' % (str(startTime),self.track_start_time_limit))
            aisutils.database.rebuild_track_lines  (self.cx, startTime=startTime, verbose=self.v)
            if self.v:
                sys.stderr.write('DONE track_start\n')

        if self.last_position_time_limit:
            startTime = self._startTime(self.last_position_time_limit)
            if self.v:
                sys.stderr.write('  clean last_position %s (%s)\n' % (str(startTime),self.last_position_time_limit))
            aisutils.database.rebuild_last_position(self.cx, startTime = startTime, verbose=self.v)
        if self.v:
            sys.stderr.write('DONE CLEANING...\n\n')

    def handler(self, unused=None):
        '''
        top level loop thread

        Loops until L{stop} is called: sleep, commit, and clean when the clean
        interval has passed.  Failures in commit or clean are logged and the
        loop keeps going.  C{self.stopped} is set on the way out no matter
        what, so whoever waits for it is never left hanging.

        @param unused: placeholder for the thread argument tuple
        '''
        nextClean = datetime.datetime.utcnow() # Do an immediate clean
        try:
            while self.running:
                time.sleep(self.dbUpdateInterval)

                try:
                    self.commit()
                except Exception as e:
                    self._logException("*** exception on commit in handler. \n", e)
                    continue

                if datetime.datetime.utcnow() > nextClean:
                    try:
                        self.clean()
                    except Exception as e:
                        self._logException("*** exception on clean in handler. \n", e)
                    nextClean = datetime.datetime.utcnow() + datetime.timedelta(seconds=self.cleanTime)

            sys.stderr.write('Database handler stopping... begin final commit\n')
            try:
                self.commit()
            except Exception as e:
                self._logException("*** exception on final commit in handler. \n", e)
        finally:
            self.stopped = True
        
                

class HandleAisConnection:
    '''
    Handles the incoming socket.

    Reads newline separated NMEA text from one client socket, decodes the
    single-sentence AIVDM messages and puts C{(sqlInsertString, vessel)}
    tuples on the database queue.
    '''
    def __init__(self, dataSocket, dbQueue, options, dbType='postgres'):
        '''
        @param dataSocket: connected socket to read NMEA text from
        @param dbQueue: Queue object to push SQL messages onto
        @param options: needs a verbose attribute; timeout (float seconds for select) is optional and defaults to 1
        @param dbType: database flavor handed to the message sqlInsert call
        '''
        self.options = options
        sys.stderr.write('hac options in __init__ '+str(self.options)+'\n')
        self.dataSocket = dataSocket
        self.running = True
        self.dbQueue = dbQueue
        self.dbType = dbType
        self.timeout = getattr(options, 'timeout', 1.)
        self.buf = None  # trailing partial line waiting for the rest of its text

    def _splitMessages(self, data, v):
        '''
        Split received text into complete lines, keeping a trailing partial
        line in self.buf for the next read.

        @param data: non-empty text just read from the socket
        @param v: verbose flag
        @return: list of complete lines (may contain empty strings)
        '''
        if self.buf is not None:
            data = self.buf+data
            if v:
                sys.stderr.write('new buf: '+data+'\n')
            self.buf = None
        nmeaStrs = data.split('\n')
        if data[-1] not in ('\n','\r'):
            # Last entry is incomplete, so hold it back
            self.buf = nmeaStrs.pop()
        return nmeaStrs

    def _queueMessage(self, msg, v):
        '''
        Decode one NMEA line and queue its SQL insert.  Lines that are not
        normalized AIS messages of a known type are skipped.

        @param msg: one non-empty line of text
        @param v: verbose flag
        '''
        if 'AIVDM'!= msg[1:6]:
            if v:
                sys.stderr.write('Skipping non ais message "'+msg+'"\n')
            return
        if v:
            sys.stderr.write('processing ais message ... '+msg+'\n')

        uscgMsg = aisutils.uscg.UscgNmea(msg)

        if uscgMsg.totalSentences != 1:
            if v:
                sys.stderr.write('Can not handle unnormalized messages. Tot sentences: '+str(uscgMsg.totalSentences)+'\n')
                sys.stderr.write('  msg:'+msg+'\n')
            return

        if uscgMsg.msgTypeChar not in ais.msgModByFirstChar:
            if v:
                sys.stderr.write('Can not process message starting with '+uscgMsg.msgTypeChar+'\n')
            return
        aismsg = ais.msgModByFirstChar[uscgMsg.msgTypeChar]
        bv = ais.binary.ais6tobitvec(uscgMsg.contents)
        try:
            msgDict = aismsg.decode(bv)
        except Exception:
            sys.stderr.write('   Dropping bad msg and calling continue\n')
            return

        ins = aismsg.sqlInsert(msgDict, dbType=self.dbType)

        cg_sec = uscgMsg.cg_sec
        cg_timestamp = uscgMsg.sqlTimestampStr
        cg_station = uscgMsg.station

        if cg_sec is not None:       ins.add('cg_sec',       cg_sec)
        if cg_timestamp is not None: ins.add('cg_timestamp', cg_timestamp)
        if cg_station is not None:   ins.add('cg_r',         cg_station)

        # Figure out if the vessel track needs to be updated
        vessel = None
        if aismsg.dbTableName in ('position',):
            vessel = msgDict['UserID']
        self.dbQueue.put((str(ins), vessel))

    def handler(self, unused=None):
        '''
        Thread body: read from the socket until the peer closes it, a socket
        error occurs, or self.running is cleared.  The socket is closed on the
        way out.

        @param unused: placeholder for the thread argument tuple
        '''
        sys.stderr.write('hac handler\n')
        v = self.options.verbose
        if v:
            sys.stderr.write('handler started.  VERBOSE\n')

        try:
            while self.running:
                try:
                    readersready, outputready, exceptready = select.select([self.dataSocket,],[],[],self.timeout)
                    if not readersready:
                        continue  # timed out; go around and recheck self.running
                    data = self.dataSocket.recv(1024)
                except (socket.error, select.error) as e:
                    sys.stderr.write('*** socket error in ais connection handler: '+str(e)+'\n')
                    break

                if len(data) == 0:
                    # Means we didn't time out and the socket is closed
                    # FIX: save the connection info and say who left
                    if v:
                        sys.stderr.write('shutting down ais connection handler\n')
                    break

                for msg in self._splitMessages(data, v):
                    if len(msg) == 0:
                        continue
                    try:
                        self._queueMessage(msg, v)
                    except Exception as e:
                        # One malformed line must not take down the connection
                        sys.stderr.write('*** exception handling ais message.  Skipping: '+msg+'\n')
                        sys.stderr.write('   Exception args:'+ str(e)+'\n')
                        traceback.print_exc(file=sys.stderr)
        finally:
            self.running = False
            try:
                self.dataSocket.close()
            except socket.error:
                pass  # nothing more can be done with a socket that will not close

        
class PassThroughServer:
    '''Listen for incoming connections and hand each accepted client
    socket to its own HandleAisConnection thread, which feeds the
    database handler queue.  start() launches the listener thread and
    returns to the caller.
    '''
    nmeaInputs = {} # indexed by socket... handles if we have partial text
    def __init__(self, options, dbHandler):
        '''
        @param options: understands timeout (float in seconds) and hosts_allow
               (list of client addresses allowed to connect; missing or None
               allows everyone).  connection_handler also reads inHost,
               inPort and dbType from it.
        @param dbHandler: object whose q attribute is the queue given to each connection
        '''
        self.clients = []
        self.options = options
        self.count = 0
        self.timeout = getattr(options, 'timeout', 1.)
        self.hosts_allow = getattr(options, 'hosts_allow', None)
        self.running = True
        self.hacs = []
        self.dbHandler = dbHandler

    def start(self):
        '''
        start the thread
        FIX: cleanup and use verbose flag
        '''
        sys.stderr.write('starting threads\n')
        thread.start_new_thread(self.connection_handler, (self,))
        sys.stderr.write('connection_handler started\n')
        return


    def connection_handler(self, unused=None):
        '''Do not use this.  Call start() instead.  This listens for
        connections and starts a HandleAisConnection thread for each
        allowed client.  Runs until stop() is called; the listening
        socket is closed and the connection handlers are told to stop on
        the way out.

        @bug: how can I get rid of unused?
        '''
        inHost = self.options.inHost
        inPort = self.options.inPort
        sys.stderr.write('starting incoming connection receiver: '+str(inHost)+':'+str(inPort)+'\n')

        serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            serversocket.bind((inHost, inPort))
            serversocket.listen(5)

            while self.running:
                # Wait with a timeout rather than blocking in accept() so
                # that a stop() request gets noticed.
                ready = select.select([serversocket,],[],[],self.timeout)[0]
                if not ready:
                    continue
                (clientsocket, address) = serversocket.accept()
                sys.stderr.write('connect from '+str(clientsocket)+' '+str(address)+'\n')
                if self.hosts_allow is not None and address[0] not in self.hosts_allow:
                    sys.stderr.write('not allowing connection from ... '+str(address)+'\n')
                    clientsocket.close()  # do not leak the rejected connection
                    continue

                # Forget handlers whose connection has already ended
                self.hacs = [hac for hac in self.hacs if hac.running]

                sys.stderr.write('creating hac\n')
                hac = HandleAisConnection(clientsocket
                                          ,self.dbHandler.q
                                          ,self.options
                                          ,self.options.dbType)
                sys.stderr.write('creating thread\n')
                thread.start_new_thread(hac.handler, (self,))
                self.hacs.append(hac)
        finally:
            serversocket.close()
            # Ask the connection threads to finish too
            for hac in self.hacs:
                hac.running = False

    def stop(self):
        '''
        Ask the listener loop to finish.  It notices within one timeout period.
        '''
        self.running = False

def _requireMagicDate(optionName, text):
    '''
    Exit the program unless text parses as a magic date.

    @param optionName: name used in the error message
    @param text: the magicdate string from the command line
    '''
    try:
        m = magicdate.magicdate(text)
    except Exception:
        sys.exit('PARSE ERROR: '+optionName+' is not a valid magic date ... '+text)
    if m is None:
        sys.exit(optionName+' is not a valid magic date ... '+text)


def main():
    '''
    Command line entry point.

    Parses the options, connects to the database (unless --dummy-db),
    starts the listener and database threads and then idles until
    interrupted from the keyboard, at which point the threads are asked to
    stop and the final database commit is waited for.
    '''
    from optparse import OptionParser

    # FIX: is importing __init__ safe?
    parser = OptionParser(usage="%prog [options]",
                          version="%prog "+__version__ + " ("+__date__+")")

    parser.add_option('-i', '--in-port', dest='inPort', type='int', default=31402
                      , help='Where the data comes from [default: %default]')
    parser.add_option('-I', '--in-host', dest='inHost', type='string', default='localhost'
                      , help='What host to read data from [default: %default]')
    parser.add_option('--in-gethostname', dest='inHostname', action='store_true', default=False
                      , help='Use the hostname of this machine as the in-host [default: %default]')

    parser.add_option('-a', '--allow', action='append', dest='hosts_allow'
                      , help='Add hosts to a list that are allowed to connect [default: all]')

    parser.add_option('-t', '--timeout', dest='timeout', type='float', default='300',
                      help='Number of seconds to timeout after if no data [default: %default]')

    aisutils.database.stdCmdlineOptions(parser, 'postgres')

    parser.add_option('-v', '--verbose', dest='verbose', default=False, action='store_true'
                      , help='Make the test output verbose')

    parser.add_option('--dummy-db', dest='skipDB', default=False, action='store_true'
                      , help='Do not actually talk to database')

    parser.add_option('--time-limit-all', dest='timeLimitAll', type='str' #type='magicdate'
                      , default=None
                      , help='Limit all caches by one magic date time (e.g. "1 hour ago").'
                      +'  -s and -S override this'
                      +' [default %default]')

    # These will get evaluated each time around
    parser.add_option('-s', '--track-start-time', dest='track_start', type='str' #type='magicdate'
                      , default=None
                      , help='magicdate - Oldest allowable time for a track line [default %default]')

    parser.add_option('-S', '--last-position-start-time', dest='last_position_start', type='str' #, type='magicdate'
                      , default=None
                      , help='magicdate - Oldest allowable time for a last position [default %default]')

    parser.add_option('-c', '--clean-time', dest='cleanTime', type='float' #, type='magicdate'
                      , default=30
                      , help='Time in seconds between database cleanup of the track lines [default %default]')

    (options, args) = parser.parse_args()

    v = options.verbose
    options.dbType = 'postgres' # FIX: don't force this

    # --time-limit-all only fills in the limits that were not given explicitly
    if options.timeLimitAll is not None:
        if options.track_start is None:
            options.track_start = options.timeLimitAll
        if options.last_position_start is None:
            options.last_position_start = options.timeLimitAll

    # Fail now rather than later inside the database thread
    if options.track_start is not None:
        _requireMagicDate('track_start', options.track_start)
    if options.last_position_start is not None:
        _requireMagicDate('last_position_start', options.last_position_start)

    cx = None
    if not options.skipDB:
        cx = aisutils.database.connect(options, dbType=options.dbType)

    sys.stderr.write('Creating dbHandler verbosity = %s\n' %v)
    dbHandler = DatabaseHandler(cx
                                , track_start_time_limit=options.track_start
                                , last_position_time_limit=options.last_position_start
                                , verbose=v
                                , skipDB=options.skipDB
                                , cleanTime=options.cleanTime
                                )

    if v:
        sys.stderr.write('hosts allowed: '+str(options.hosts_allow)+'\n')

    if options.inHostname:
        options.inHost = socket.gethostname()

    pts = PassThroughServer(options, dbHandler)
    pts.start()

    # Now start up the thread to send the messages to
    thread.start_new_thread(dbHandler.handler, (None, ))

    timeout = options.timeout
    del(options) # remove global to force self.options
    i = 0

    # Nothing to do in this thread except wait for a keyboard interrupt
    try:
        while True:
            i += 1
            time.sleep(timeout)
            if v:
                sys.stderr.write('ping '+str(i)+'\n')
    except KeyboardInterrupt:
        if v:
            sys.stderr.write('\bshutting down...\n')

    # FIX: nuke all the connections?
    pts.stop()

    dbHandler.stop()
    while not dbHandler.stopped:
        time.sleep(.1)
    if cx is not None:
        # The database thread has done its final commit by now
        cx.close()
    if v:
        sys.stderr.write('Finished cleaning up ... goodbye\n')
    

######################################################################
# Run the server only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
