#!/usr/bin/env python
""" Split one or more fastq files based on barcode sequence.
"""
from __future__ import division
import gzip
import optparse
import os
import re
import sys
#from fastq_utils import fastq_utils

# Module metadata; __version__ is shown by the command line "--version" option.
__version__ = "0.1"
__author__ = "Lance Parsons"
__author_email__ = "lparsons@princeton.edu"
__copyright__ = "Copyright 2011, Lance Parsons"
__license__ = "BSD 2-Clause License http://www.opensource.org/licenses/BSD-2-Clause"

# Key used in the count and output-file tables (and in the output file names)
# for reads whose barcode was not assigned to exactly one sample.
UNMATCHED = 'unmatched'

def main(argv=None):
    '''Split one or more sequence files into per-sample files by barcode.

    Parses the command line, loads the barcode table, then walks the index
    read file record by record.  The barcode-length prefix (or suffix, with
    --barcodes_at_end) of each index read is matched against the known
    barcodes; the record from every input file is then written to the output
    file of the matched sample, or to the "unmatched" files when zero or more
    than one barcode matches.  A tab-delimited count summary is printed to
    stdout at the end.

    Args:
        argv: Full argument vector including the program name.  Defaults to
            sys.argv.

    Returns:
        0 on success; 1 if a non-index read file ends before the index read
        file; 2 if the barcodes file contains no barcodes; otherwise the exit
        status requested by the option parser (2 for usage errors, 0 for
        --help / --version).
    '''
    if argv is None:
        argv = sys.argv

    usage = "Usage: %prog [options] --bcfile barcodes.txt fastq_read1 [fastq_read2] [fastq_read3]"
    parser = optparse.OptionParser(usage=usage, version='%prog version ' + globals()['__version__'], description=globals()['__doc__'])

    required_group = optparse.OptionGroup(parser, "Required")
    required_group.add_option('--bcfile', metavar='FILE', help='Tab delimited file: "Sample_ID <tab> Barcode_Sequence" (REQUIRED)')
    required_group.add_option('--idxread', metavar='READNUM', type='int', help='Indicate in which read (1, 2, 3, etc.) to search for the barcode sequence (REQUIRED)')
    parser.add_option_group(required_group)

    output_group = optparse.OptionGroup(parser, "Output Options")
    output_group.add_option('--prefix', default='', help='Prefix for output files')
    output_group.add_option('--suffix', default=None, help='Suffix for output files (default based on --format)')
    output_group.add_option('--galaxy', action='store_true', default=False, help='Produce "Galaxy safe" filenames by removing underscores (default: %default)')
    output_group.add_option('-v', '--verbose', action='store_true', default=False, help='verbose output')
    parser.add_option_group(output_group)

    barcode_location_group = optparse.OptionGroup(parser, "Barcode Location")
    barcode_location_group.add_option('--barcodes_at_end', action='store_true', default=False, help='Barcodes are at the end of the index read (default is at the beginning)')
    parser.add_option_group(barcode_location_group)

    matching_group = optparse.OptionGroup(parser, "Matching")
    matching_group.add_option('--mismatches', default=0, type='int',  help='Number of mismatches allowed in barcode matching')
    parser.add_option_group(matching_group)

    input_group = optparse.OptionGroup(parser, "Input format")
    input_group.add_option('--format', default='fastq', help='Specify format for sequence files (fasta or fastq)')
    input_group.add_option('--gzip', action='store_true', default=False, help='Force gzip format for input and output (default is auto based on input file extension)')
    parser.add_option_group(input_group)

    try:
        (options, args) = parser.parse_args(argv[1:])
        if len(args) < 1:
            parser.error('Must specify at least one sequence file')
        if not options.bcfile:
            parser.error('Must specify a barcodes file with "--bcfile" option')
        # optparse already converted --idxread to int; None means it was omitted
        if options.idxread is None:
            parser.error('Must specify the index read number with "--idxread" option')
        if (options.idxread < 1) or (options.idxread > len(args)):
            parser.error('Invalid index read number ("--idxread") specified, must be one of the reads you specified (between 1 and %s)' % len(args))
    except SystemExit as exit_request:
        # Prevent exit when calling as function, but keep the parser's status
        # (2 for usage errors, 0 for --help / --version)
        return exit_request.code

    # Read barcodes file into dict (barcode sequence -> sample id)
    barcode_dict = read_barcodes(options.bcfile)
    if not barcode_dict:
        sys.stderr.write('No barcodes found in %s\n' % options.bcfile)
        return 2
    barcodes = list(barcode_dict)
    # The length of the first barcode decides how much of each index read is compared
    barcode_length = len(barcodes[0])
    total_read_count = 0
    counts = {UNMATCHED: 0}
    for barcode in barcodes:
        counts[barcode] = 0
    # TODO Verbose: print barcode_dict

    # Determine if we should use gzip for input/output
    if options.suffix is not None:
        suffix = options.suffix
    else:
        suffix = '.%s' % options.format
    extension = os.path.splitext(args[0])[1]
    if extension == '.gz':
        options.gzip = True
        if options.suffix is None:
            suffix = '%s.gz' % suffix
    if options.gzip:
        opener = gzip.open
    else:
        opener = open

    index_readnum = options.idxread - 1
    inputs = {}
    outputs = {}
    try:
        # Open filehandles for each read
        for i, filename in enumerate(args):
            inputs[i] = opener(filename, 'rb')
            first_line = inputs[i].readline().strip()
            # NOTE: the format detected for the last input file is the one
            # used for every id comparison below
            id_format = determine_id_format(first_line[1:])
            # TODO print "\nFile %s: %s\nId format: %s" % (i, args[i], id_format)
            inputs[i].seek(0)
        read_numbers = sorted(inputs)

        # Open filehandles for each barcode (and unmatched), for each read
        output_labels = []
        for barcode in barcodes:
            sample_id = barcode_dict[barcode]
            if options.galaxy:
                # Replace underscore to allow this to work with Galaxy
                sample_id = sample_id.replace("_", "-")
            output_labels.append((barcode, sample_id))
        output_labels.append((UNMATCHED, UNMATCHED))
        for key, label in output_labels:
            # Register the table first so a failure part-way still gets cleaned up
            outputs[key] = {}
            for i in read_numbers:
                outputs[key][i] = opener('%s%s-read-%s%s' % (options.prefix, label, i + 1, suffix), 'wb')

        # One record generator per input file, advanced in lock step
        readers = {}
        for i in read_numbers:
            readers[i] = read_fastq(inputs[i])

        # For each input line in index read, get index sequence
        for index_read in readers[index_readnum]:
            total_read_count += 1
            if options.barcodes_at_end:
                index_seq = index_read['seq'][-barcode_length:]
            else:
                index_seq = index_read['seq'][0:barcode_length]

            # Get matching barcode(s), if more than one, warn and set to unmatched
            best_match_barcodes = match_barcodes(index_seq, barcodes, options.mismatches)
            if len(best_match_barcodes) == 1:
                barcode_match = best_match_barcodes[0]
            else:
                barcode_match = UNMATCHED
                if len(best_match_barcodes) > 1:
                    sys.stderr.write('More than one barcode matches for %s, moving to %s category\n' % (index_read['seq_id'], UNMATCHED))
            counts[barcode_match] += 1

            # Get sequence record from each other read, check that the id matches
            for readnum in read_numbers:
                if readnum == index_readnum:
                    read = index_read
                else:
                    try:
                        read = next(readers[readnum])
                    except StopIteration:
                        sys.stderr.write('Read file %s ended before index read file %s\n' % (args[readnum], args[index_readnum]))
                        return 1
                    # Explicit test rather than assert, which is removed under -O
                    if not match_id(index_read['seq_id'], read['seq_id'], id_format):
                        sys.stderr.write("Id mismatch: %s does not match %s\n" % (index_read['seq_id'], read['seq_id']))
                # Output sequences into barcode/read file
                outputs[barcode_match][readnum].write(fastq_string(read))
    finally:
        # Always release the handles; closing also flushes pending output
        for handle in inputs.values():
            handle.close()
        for handles in outputs.values():
            for handle in handles.values():
                handle.close()

    def percent(count):
        # Avoid dividing by zero when the index file contained no reads
        if total_read_count == 0:
            return 0.0
        return (float(count) / total_read_count) * 100

    print("Sample\tBarcode\tCount\tPercent")
    for barcode in sorted(barcode_dict, key=barcode_dict.get):
        print("%s\t%s\t%s\t%.2f%%" % (barcode_dict[barcode], barcode, counts[barcode], percent(counts[barcode])))
    print("%s\t%s\t%s\t%.2f%%" % (UNMATCHED, None, counts[UNMATCHED], percent(counts[UNMATCHED])))
    return 0

def read_barcodes(filename):
    '''Read a tab-delimited barcodes file into a dictionary.

    Each data line must hold exactly two tab-separated fields:
    "Sample_ID <tab> Barcode_Sequence".  Blank lines and lines starting with
    '#' are skipped.

    Args:
        filename: Path of the barcodes file.

    Returns:
        Dictionary mapping barcode sequence to sample id.  If a barcode
        appears more than once, the last sample id wins.

    Raises:
        ValueError: If a data line does not contain exactly two fields.
    '''
    barcode_dict = {}
    # Text mode so lines are str on both Python 2 and Python 3
    with open(filename, 'r') as filehandle:
        for linenum, line in enumerate(filehandle, 1):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError("Unable to read barcode from line %s: '%s'" % (linenum, line))
            sample_id, barcode_sequence = fields
            barcode_dict[barcode_sequence] = sample_id
    return barcode_dict

def match_barcodes(sequence, barcodes, mismatches):
    '''Find the closest barcode(s) to a sequence within a mismatch limit.

    Args:
        sequence: Sequence to look up.
        barcodes: Iterable of candidate barcode sequences.
        mismatches: Maximum number of differing positions allowed.

    Returns:
        List of the barcodes with the smallest number of mismatches, provided
        that number does not exceed the limit.  The list is empty when nothing
        matches and holds several barcodes only when they tie for closest.
        Barcodes whose length differs from the sequence never match.
    '''
    if mismatches == 0:
        # Exact matching needs no distance computation
        return [barcode for barcode in barcodes if barcode == sequence]

    best_distance = mismatches
    results = []
    for barcode in barcodes:
        # Mismatch counting is only defined for equal lengths; computed inline
        # so a length difference is a non-match rather than an assertion failure
        if len(barcode) != len(sequence):
            continue
        distance = sum(ch1 != ch2 for ch1, ch2 in zip(sequence, barcode))
        if distance < best_distance:
            # Strictly closer than anything seen so far: drop the weaker candidates
            best_distance = distance
            results = [barcode]
        elif distance == best_distance:
            results.append(barcode)
    return results
    
def hamming_distance(s1, s2):
    '''Count the positions at which two equal-length sequences differ.

    An assertion fails if the two sequences differ in length.
    '''
    assert len(s1) == len(s2)
    differences = 0
    for left, right in zip(s1, s2):
        if left != right:
            differences += 1
    return differences

''' Supported types of Fastq IDs '''
ILLUMINA = 'illumina' # CASAVA 1.8+, match up to space
STRIPONE = 'stripone' # Illumina CASAVA 1.7 and lower (/1 /2) and NCBI SRA (/f /r), match all but last character
OTHER = 'other'       # Other, match exactly

def determine_id_format(seq_id):
    '''Classify a sequence id as ILLUMINA, STRIPONE or OTHER.

    The id is ILLUMINA when it starts with a CASAVA 1.8+ style header,
    STRIPONE when it ends in /1, /2, /3, /f or /r, and OTHER otherwise.
    '''
    # Illumina CASAVA 1.8+ fastq headers use new format
    casava18_header = re.compile(r'(?P<instrument>[a-zA-Z0-9_-]+):(?P<run_number>[0-9]+):(?P<flowcell_id>[a-zA-Z0-9]+):(?P<lane>[0-9]+):(?P<tile>[0-9]+):(?P<x_pos>[0-9]+):(?P<y_pos>[0-9]+) (?P<read>[0-9]+):(?P<is_filtered>[YN]):(?P<control_number>[0-9]+):(?P<index_sequence>[ACGT]+){0,1}')
    if casava18_header.match(seq_id):
        return ILLUMINA

    # Old illumina and sanger reads use /1 /2 or /f /r
    if seq_id.endswith(('/1', '/2', '/3', '/f', '/r')):
        return STRIPONE

    return OTHER

def strip_read_from_id(seq_id, id_format=None):
    '''Return the sequence id with its read-specific part removed.

    Args:
        seq_id: Sequence id to strip.
        id_format: One of ILLUMINA, STRIPONE or OTHER.  When omitted (or
            otherwise false) the format is detected from seq_id itself.

    Returns:
        For STRIPONE ids, the id without its last character; for ILLUMINA
        ids, the text before the first space; otherwise seq_id unchanged.
    '''
    if not id_format:
        id_format = determine_id_format(seq_id)
    # Separate "if": a detected format must be applied too, not only a supplied one
    if id_format == STRIPONE:
        return strip_read_from_id_stripone(seq_id)
    if id_format == ILLUMINA:
        return strip_read_from_id_illumina(seq_id)
    return seq_id

def strip_read_from_id_stripone(seq_id):
    '''Return seq_id without its final character.'''
    return seq_id[:-1]

def strip_read_from_id_illumina(seq_id):
    '''Return the part of seq_id that precedes the first space.'''
    before_space, _, _ = seq_id.partition(' ')
    return before_space

def match_id(id1, id2, id_format=OTHER):
    '''Tell whether two sequence ids refer to the same read.

    STRIPONE ids are compared without their last character, ILLUMINA ids
    are compared up to the first space, and any other format is compared
    exactly.
    '''
    if id_format == STRIPONE:
        return id1[:-1] == id2[:-1]
    if id_format == ILLUMINA:
        first_name = id1.split(' ')[0]
        second_name = id2.split(' ')[0]
        return first_name == second_name
    return id1 == id2

def read_fastq(filehandle):
    '''Iterate over the records of a FASTQ file.

    Lines are consumed four at a time and stripped of surrounding whitespace.

    Args:
        filehandle: Iterable of lines, such as an open file.

    Yields:
        A new dictionary per record with the keys 'seq_id', 'seq', 'qual_id'
        and 'qual'.  A trailing group of fewer than four lines is not yielded.
    '''
    field_names = ('seq_id', 'seq', 'qual_id', 'qual')
    fastq_record = {}
    for line_index, line in enumerate(filehandle):
        field = field_names[line_index % 4]
        fastq_record[field] = line.strip()
        if field == 'qual':
            yield fastq_record
            # Start a fresh dict so records already handed out are never overwritten
            fastq_record = {}
            

def fastq_string(record):
    '''Format a record dictionary as four newline-terminated FASTQ lines.

    The record must provide the keys 'seq_id', 'seq', 'qual_id' and 'qual'.
    '''
    line_values = tuple(record[key] for key in ('seq_id', 'seq', 'qual_id', 'qual'))
    return "%s\n%s\n%s\n%s\n" % line_values

# When run as a script, use the value returned by main() as the exit status.
if __name__ == '__main__':
    sys.exit(main())
