#!/usr/bin/python

import argparse
import collections
import re
import signal
import sys
import time
from functools import wraps
from multiprocessing import Pool

import boto.s3

def S3Path(p):
    """Split an ``s3://bucketname/prefix`` URL into ``(bucket, prefix)``.

    Used as an argparse ``type=`` converter for the ``path`` argument.

    Args:
        p: URL string; both the bucket name and the prefix must be non-empty.

    Returns:
        A 2-tuple ``(bucket_name, prefix)`` of strings.

    Raises:
        ValueError: if ``p`` does not start with ``s3://<bucket>/<prefix>``.
    """
    m = re.match(r's3://([^/]+)/(.+)', p)
    if not m:
        # Call-style raise: the old "raise E, msg" statement form is a
        # SyntaxError on Python 3, this form is valid on both 2 and 3.
        raise ValueError('invalid s3 url - should be "s3://bucketname/prefix"')
    return m.groups()
    
def Regex(r):
    """Compile ``r`` into a regular-expression object.

    Used as an argparse ``type=`` converter for the ``regex`` argument.

    Args:
        r: regular-expression source string.

    Returns:
        The compiled pattern object from ``re.compile``.

    Raises:
        ValueError: if ``r`` is not a valid regular expression; the message
            is the text of the underlying ``re.error``.
    """
    try:
        return re.compile(r)
    except re.error as ex:
        # "except E as ex" / call-style raise are valid on Python 2.6+ and 3,
        # unlike the comma forms which are Python-2-only syntax.
        raise ValueError(str(ex))
        
class Worker(object):
    """Empty placeholder class.

    The name ``Worker`` is rebound by the full ``Worker`` class defined
    later in this module, so this definition is never the one in use.
    """

class LineBuffer(object):
    """Add line-by-line reading to a file handle that only supports read().

    Lines are handed out without their trailing newline.  Iterating over the
    buffer yields every line (blank lines included) and stops only at the
    real end of input.  ``readline()`` returns ``''`` at end of input, so a
    blank line and EOF look the same through that method; use iteration when
    the distinction matters.
    """
    def __init__(self, fin):
        self.fin = fin
        # Pending lines in order.  Until EOF has been seen, the LAST element
        # is the text after the most recent newline, i.e. a fragment that may
        # still be extended by the next read().  A deque gives O(1) removal
        # from the front (list.pop(0) is O(n) per line).
        self.buf = collections.deque()
        # Set once read() has returned no data; we never read again after that.
        self._eof = False

    def __iter__(self):
        return self

    def _fill(self):
        """Pull one more chunk from the underlying handle into the buffer."""
        data = self.fin.read()
        if not data:
            self._eof = True
            return
        more = data.split('\n')
        if self.buf:
            # join across the boundary: the buffered tail fragment continues
            # with the first piece of the new chunk
            self.buf[-1] = self.buf[-1] + more.pop(0)
        self.buf.extend(more)

    def _next_line(self):
        """Return the next line, or None when the input is exhausted."""
        # Only the tail fragment (or nothing) is buffered: keep reading until
        # a complete line is available or the input ends.
        while len(self.buf) <= 1 and not self._eof:
            self._fill()
        if len(self.buf) > 1:
            return self.buf.popleft()
        if self.buf:
            # EOF reached: the tail is a final unterminated line, unless it
            # is empty (the input ended with a newline).
            tail = self.buf.popleft()
            if tail:
                return tail
        return None

    def readline(self):
        """Return the next line without its newline, or '' at end of input."""
        line = self._next_line()
        if line is None:
            return ''
        return line

    def next(self):
        """Iterator protocol: return the next line or raise StopIteration."""
        line = self._next_line()
        if line is None:
            raise StopIteration
        return line

    # Python 3 spelling of the iterator protocol method.
    __next__ = next
        
class Worker(object):
    """Grep worker that scans one S3 key per call.

    An instance opens its own S3 connection and bucket handle on
    construction and is then called once per key name.
    """
    def __init__(self, bucket_name, regex, args):
        """Connect to S3 and remember the search settings.

        Args:
            bucket_name: name of the bucket to read keys from.
            regex: compiled regular expression to search each line with.
            args: parsed options; ``verbose``, ``invert`` and
                ``with_filename`` are read.
        """
        s3 = boto.connect_s3()
        self.bucket = s3.get_bucket(bucket_name)
        self.regex = regex
        self.args = args

    def __call__(self, key_name):
        """Scan the key called ``key_name`` and print selected lines to stdout.

        Returns:
            ``(nbytes, nlines, nmatches)`` for this key.  All three are 0
            when the key is missing or when the whole-file shortcut is taken
            (the content is streamed without being split into lines).
        """
        key = self.bucket.get_key(key_name)
        if key is None:
            # get_key yields None when the key does not exist (e.g. removed
            # after it was listed); skip it instead of failing on key.name.
            sys.stderr.write('Skipping missing key: %s\n' % (key_name,))
            return 0, 0, 0

        nbytes = nlines = nmatches = 0
        if self.args.verbose:
            sys.stderr.write('Scanning: %s\n' % (key.name,))

        invert = bool(self.args.invert)
        with_filename = self.args.with_filename

        # optimise special case - handy for using conventional grep, for
        # example.  Only valid for a plain dump: with --invert or
        # --with-filename the output differs from the raw content, so those
        # must go through the per-line path.
        if self.regex.pattern == '.' and not invert and not with_filename:
            key.get_contents_to_file(sys.stdout)
            return nbytes, nlines, nmatches

        search = self.regex.search
        write = sys.stdout.write
        for line in LineBuffer(key):
            nlines += 1
            nbytes += len(line) + 1  # \n
            if invert != bool(search(line)):  # XOR
                # Emit line and newline in a single write call.
                if with_filename:
                    write('%s:%s\n' % (key.name, line))
                else:
                    write(line + '\n')
                nmatches += 1

        return nbytes, nlines, nmatches

def initialize_worker(*args):
    """Pool initializer: create the module-global ``worker`` for this process.

    ``args`` are forwarded unchanged to ``Worker``.  If a KeyboardInterrupt
    arrives while the worker is being constructed it is swallowed and the
    global ``worker`` is left unassigned.
    """
    global worker
    try:
        instance = Worker(*args)
    except KeyboardInterrupt:
        # leave cleanly on keyboard interrupt
        pass
    else:
        worker = instance
    
def call_worker(args):
    """Run the module-global ``worker`` on one job and return its result.

    A KeyboardInterrupt raised during the call is converted into a plain
    ``Exception``.
    """
    try:
        result = worker(args)
    except KeyboardInterrupt:
        # multiprocessing hangs on KeyboardInterrupt - workaround
        raise Exception()
    return result

def main():
    """Command-line entry point: grep every key under an S3 prefix.

    Parses the command line, starts a pool of worker processes (each
    initialised through ``initialize_worker``), feeds them the names of all
    keys under the given prefix, and accumulates the per-key counts returned
    by ``call_worker``.  A one-line summary is written to stderr at the end,
    including after a keyboard interrupt.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('path', nargs=1, type=S3Path, help="eg. s3://bucketname/prefix")
    parser.add_argument('regex', nargs=1, type=Regex, help="regular expression to search for")
    parser.add_argument('--verbose', action='store_true', help="verbose output")
    parser.add_argument('-p', '--processes', type=int, default=6, help="number of processes to run")
    # grep options
    parser.add_argument('-v', '--invert', action='store_true', help="invert match (ie. return lines that do not match)")
    parser.add_argument('-H', '--with-filename', action='store_true', help="Print the file name for each match.")

    args = parser.parse_args()

    # setup
    nfiles = nlines = nbytes = nmatches = 0
    # nargs=1 wraps each positional in a one-element list
    bucket_name, prefix = args.path[0]
    regex = args.regex[0]
    pool = Pool(processes=args.processes,
                initializer=initialize_worker,
                initargs=(bucket_name, regex, args))
    start = time.time()
    try:
        if args.verbose:
            sys.stderr.write('Searching bucket: %s\n' % (bucket_name,))
        s3 = boto.connect_s3()
        bucket = s3.get_bucket(bucket_name)

        # lazily feed key names to the pool as the listing is consumed
        jobs = (k.name for k in bucket.list(prefix=prefix))
        # do the actual work
        for b, l, m in pool.imap(call_worker, jobs):
            nfiles += 1
            nlines += l
            nbytes += b
            nmatches += m
    except KeyboardInterrupt:
        # Ctrl-C: drop outstanding work, then fall through to the summary.
        pool.terminate()
    except Exception:
        # Any other failure: stop the worker processes before propagating,
        # rather than leaving the pool neither closed nor terminated.
        pool.terminate()
        pool.join()
        raise
    else:
        pool.close()
    pool.join()

    end = time.time()
    sys.stderr.write(
        'Scanned: %d files, %d lines, %d bytes, %d matches, took %dms\n'
        % (nfiles, nlines, nbytes, nmatches, (end - start) * 1000))

if __name__ == '__main__':
    main()
