#!/usr/bin/env python
import sys

import pysam
import concurrent.futures
import multiprocessing
import time
import operator
import itertools
from fpkem.bitset import BitSet
from fpkem.intersection import IntervalTree
import os
import argparse
import gzip
import pickle
import traceback

# Re-entrant lock held around the print() calls in this module (see the
# `with print_lock:` blocks and ElapsedTimeReporter). Presumably meant to keep
# multi-line reports from concurrent workers from interleaving.
print_lock = multiprocessing.RLock()

class ElapsedTimeReporter:
	"""Context manager that announces a task and reports how long it took.

	On entry `message` (if any) is printed; on exit either the elapsed time or,
	when an exception is propagating and `failed_message` is set,
	`failed_message` followed by the traceback on stderr. Exceptions are never
	suppressed.

	message        -- text printed on entry; nothing is printed when falsy
	same_line      -- when true the entry message is printed without a newline
	seconds_format -- format string for the exit line; it receives the elapsed
	                  seconds, the whole minutes and the seconds remainder as
	                  positional arguments {0}, {1} and {2}
	failed_message -- text printed when the block raised; when falsy the
	                  regular timing line is printed even on failure
	"""
	__print_lock = print_lock

	def __init__(self, message, same_line = True, seconds_format = '{0:.2f} s', failed_message = 'failed'):
		self.message = message
		self.same_line = same_line
		self.seconds_format = seconds_format
		self.failed_message = failed_message

	def __enter__(self):
		"""Print the entry message, start the clock and return the start value."""
		if self.message:
			line_end = '' if self.same_line else '\n'
			with ElapsedTimeReporter.__print_lock:
				print(self.message, end=line_end, flush=True)
		self.time_counter = time.perf_counter()
		return self.time_counter

	def __exit__(self, exception_type, exception_value, tb):
		"""Print the failure report or the elapsed time; the exception propagates."""
		if exception_type and self.failed_message:
			with ElapsedTimeReporter.__print_lock:
				print(self.failed_message, flush=True)
				traceback.print_exception(exception_type, exception_value, tb, file=sys.stderr)
			return

		if not self.seconds_format:
			return

		elapsed = time.perf_counter() - self.time_counter
		whole_minutes = int(elapsed / 60)
		seconds_remainder = elapsed % 60
		report = self.seconds_format.format(elapsed, whole_minutes, seconds_remainder)
		with ElapsedTimeReporter.__print_lock:
			print(report, flush=True)

class Range:
	"""Interval described by `start` and `end`, ordered by (start, end).

	`len()` is `end - start`, i.e. `end` is treated as exclusive. All six
	comparison operators use the same lexicographic (start, end) order, so
	`a <= b` is always equivalent to `a < b or a == b`.
	"""
	# NOTE: defining __eq__ without __hash__ leaves instances unhashable, which
	# matches the previous behaviour of this class.

	def __init__(self, start, end):
		self.start = start
		self.end = end
	def __len__(self):
		return self.end - self.start
	def __eq__(self, other):
		return self.start == other.start and self.end == other.end
	def __ne__(self, other):
		return self.start != other.start or self.end != other.end
	def __gt__(self, other):
		return (self.start, self.end) > (other.start, other.end)
	def __ge__(self, other):
		# Previously only `start` was compared, which disagreed with __gt__/__eq__.
		return (self.start, self.end) >= (other.start, other.end)
	def __lt__(self, other):
		return (self.start, self.end) < (other.start, other.end)
	def __le__(self, other):
		# Previously only `start` was compared, which disagreed with __lt__/__eq__.
		return (self.start, self.end) <= (other.start, other.end)
	def __str__(self):
		return '({0}, {1})'.format(self.start, self.end)

class Exon(Range):
	"""A Range that also records its owning gene and a transcript identifier."""

	def __init__(self, gene, transcript_id, start, end):
		# Coordinates are handled by Range; only the annotation is stored here.
		super().__init__(start, end)
		self.gene = gene
		self.transcript_id = transcript_id

class Gene:
	"""A gene on one chromosome strand together with its list of exons.

	`len(gene)` is the sum of the exon lengths, so a gene without exons is
	falsy in a boolean context.
	"""

	def __init__(self, chromosome, strand, gene_id, gene_name):
		self.chromosome = chromosome
		self.strand = strand
		self.gene_id = gene_id
		self.gene_name = gene_name
		self.exons = []

	def add_exon(self, transcript_id, start, end):
		"""Append a new Exon that refers back to this gene."""
		exon = Exon(self, transcript_id, start, end)
		self.exons.append(exon)

	def start(self):
		"""Return the smallest exon start (raises ValueError without exons)."""
		return min(exon.start for exon in self.exons)

	def end(self):
		"""Return the largest exon end (raises ValueError without exons)."""
		return max(exon.end for exon in self.exons)

	def __len__(self):
		return sum(len(exon) for exon in self.exons)

def open_gzip_or_raw(filename, mode = 'rt'):
	"""Open `filename`, transparently decompressing it when it is gzipped.

	For modes that only read (no 'w', 'a' or '+'), the first two bytes are
	compared with the gzip magic number; on a match the file is opened with
	gzip.open(), otherwise (or for any writing mode) with the builtin open().

	Returns the opened file object. Errors from the final open() or gzip.open()
	call (e.g. FileNotFoundError) propagate to the caller.
	"""
	read_only = not any(flag in mode for flag in 'wa+')
	if read_only:
		try:
			with open(filename, 'rb', buffering = 0) as file:
				is_gzip = file.read(2) == b'\x1f\x8b'
		except OSError:
			# The probe is best-effort: let the regular open() below report
			# a missing or unreadable file with its usual exception.
			is_gzip = False
		if is_gzip:
			return gzip.open(filename, mode)

	return open(filename, mode)

def load_gtf(gtf):
	"""Parse the exon records of a (possibly gzipped) GTF file into Gene objects.

	gtf -- path of the annotation file.

	Only lines whose third column is 'exon' are used. Each exon is added to the
	gene named by its `gene_id` attribute, with the start/end columns converted
	to `start - 1, end`. The strand is stored as 0 for '+' and 1 for anything
	else; `gene_name` falls back to the gene id when absent.

	Returns a list of Gene objects in order of first appearance.
	Raises KeyError when an exon line lacks `gene_id` or `transcript_id`.
	"""
	genes = {}

	with open_gzip_or_raw(gtf, 'rt') as gtf_file:
		for line in gtf_file:
			if line.startswith('#'): continue
			fields = line.rstrip().split('\t')
			if len(fields) < 9 or fields[2] != 'exon': continue

			# Column 9 looks like: key "value"; key "value"; ...
			annotations = {}
			for attribute in fields[8].split(';'):
				key, _, value = attribute.strip(' ').partition(' ')
				if not key: continue	# empty token, e.g. after the final ';'
				annotations[key.strip('"')] = value.strip(' ').strip('"')

			gene_id = annotations['gene_id']
			# Compare with None explicitly: a Gene without exons is falsy.
			gene = genes.get(gene_id)
			if gene is None:
				chromosome = fields[0]
				strand = int(fields[6] != '+')
				gene_name = annotations.get('gene_name', gene_id)
				gene = Gene(chromosome, strand, gene_id, gene_name)
				genes[gene_id] = gene

			ex_start = int(fields[3])-1
			ex_end = int(fields[4])
			gene.add_exon(annotations['transcript_id'], ex_start, ex_end)

	return list(genes.values())

def load_genes(gtf):
	"""Load genes from `gtf`, using a pickle cache stored next to it.

	The cache file is `gtf + '.pickled'`. It is used only when it is newer than
	the GTF file and can be read completely; otherwise the GTF is parsed with
	load_gtf() and the cache is (re)written.

	Returns a list of Gene objects.
	"""
	pickle_filename = gtf + '.pickled'

	try:
		cache_is_fresh = os.path.getmtime(gtf) < os.path.getmtime(pickle_filename)
	except OSError:
		# Missing cache (or missing GTF, which load_gtf reports below).
		cache_is_fresh = False

	if cache_is_fresh:
		# NOTE: unpickling can execute arbitrary code; the cache file must only
		# ever be one written by this function.
		try:
			with open(pickle_filename, 'rb') as file:
				unpickler = pickle.Unpickler(file)
				ngenes = unpickler.load()
				genes = []
				for _ in range(ngenes):
					(chromosome, strand, gene_id, gene_name, exons) = unpickler.load()
					gene = Gene(chromosome, strand, gene_id, gene_name)
					for (transcript_id, start, end) in exons:
						gene.add_exon(transcript_id, start, end)
					genes.append(gene)
			return genes
		except Exception:
			# A truncated or incompatible cache is not fatal: rebuild it from
			# the GTF below. (Unpickling may raise many exception types.)
			pass

	genes = load_gtf(gtf)
	with open(pickle_filename, 'wb') as file:
		pickler = pickle.Pickler(file)
		pickler.dump(len(genes))
		for gene in genes:
			pickler.dump((gene.chromosome, gene.strand, gene.gene_id, gene.gene_name, [ (exon.transcript_id, exon.start, exon.end) for exon in gene.exons ]))

	return genes

def interval_iter(bits, start=0, end=None):
	"""Yield (run_start, run_end) for each run of set bits within [start, end).

	bits  -- object providing `size`, `next_set(i)` and `next_clear(i)`;
	         presumably these return the index of the first set / clear bit at
	         or after `i` (TODO confirm against fpkem.bitset.BitSet)
	start -- first position to examine
	end   -- position at which to stop (exclusive); defaults to `bits.size`

	A run that extends past `end` is truncated to `end`.
	"""
	if end is None: end = bits.size
	position = start
	while position < end:
		run_start = bits.next_set(position)
		if run_start >= end: break
		position = bits.next_clear(run_start)
		yield (run_start, min(position, end))

def bitsets2intervals(bitsets):
	"""Convert nested bitsets into nested tuples of (start, end) intervals.

	`bitsets` is iterated as a sequence of per-chromosome groups, each of which
	is iterated as a sequence of bitsets; the nesting of the result mirrors it.
	"""
	genome_intervals = []
	for chr_bitsets in bitsets:
		per_strand = [ tuple(interval_iter(strand_bitset)) for strand_bitset in chr_bitsets ]
		genome_intervals.append(tuple(per_strand))
	return tuple(genome_intervals)

def add_intervals2bitsets(bitsets, intervals):
	"""Set every (start, end) interval in the matching bitset, in place.

	`intervals[chr][strand]` holds the intervals for `bitsets[chr][strand]`,
	where strand is 0 or 1. Each interval is applied as
	`set_range(start, end - start)`.
	"""
	for (chr_idx, chr_intervals) in enumerate(intervals):
		for strand in (0, 1):
			for (start, end) in chr_intervals[strand]:
				bitsets[chr_idx][strand].set_range(start, end - start)

def intervals2bitsets(intervals, chr_lengths):
	"""Build a pair of BitSets per chromosome and set the given intervals.

	Each BitSet is created with `length + 1024` as its constructor argument.
	Returns a tuple with one (BitSet, BitSet) pair per entry of `chr_lengths`.
	"""
	pairs = []
	for length in chr_lengths:
		capacity = length + 1024
		pairs.append((BitSet(capacity), BitSet(capacity)))
	bitsets = tuple(pairs)
	add_intervals2bitsets(bitsets, intervals)
	return bitsets

def cigar_matches(aligned_read):
	"""Yield (reference_start, length) for every aligned block of a read.

	aligned_read -- object with `pos` (reference start) and `cigar` (sequence
	                of (operation, length) pairs using the numeric BAM codes)

	Blocks are yielded for M (0), = (7) and X (8). Insertions (1), soft/hard
	clipping (4, 5) and padding (6) do not consume the reference, so they leave
	the position unchanged. Every other operation (e.g. deletion 2, skipped
	region 3) only advances the position. A missing CIGAR yields nothing.
	"""
	aligned_ops = (0, 7, 8)
	non_reference_ops = (1, 4, 5, 6)

	position = aligned_read.pos
	for (operation, length) in aligned_read.cigar or ():
		if operation in non_reference_ops:
			continue
		if operation in aligned_ops:
			yield (position, length)
		position += length

def get_bam_coverage(bam_filename, reverse, pickleSuffix):
	"""Compute the covered bases of one BAM file and cache them on disk.

	bam_filename -- path of the BAM file to scan
	reverse      -- for each mapped read the bitset index is
	                `read.is_reverse != reverse`
	pickleSuffix -- appended to `bam_filename` to name the cache file

	For every @SQ header entry a pair of BitSets of size LN + 1024 is created;
	each aligned block reported by cigar_matches() is set in the selected
	bitset. The tuple (chr_names, chr_lengths, bitsets, bam_filename,
	read_count) is pickled, gzip-compressed, to the cache file.

	Returns (bam_filename, pickle_filename).
	"""
	pickle_filename = bam_filename + pickleSuffix
	bitsets = []
	chr_names = []
	chr_lengths = []

	with pysam.Samfile(bam_filename, "rb") as bam:
		for ref in bam.header['SQ']:
			ln = ref['LN']
			sn = ref['SN']
			bitsets.append((BitSet(ln+1024), BitSet(ln+1024)))
			chr_names.append(sn)
			chr_lengths.append(ln)

		with ElapsedTimeReporter('Calculating covered bases in {}...'.format(bam_filename), False, 'Time spent on analyzing {}: {{:.0f}} s'.format(bam_filename), 'Failed!') as time_counter:
			read_idx = 0
			for read in bam.fetch(until_eof = True):
				if not read.is_unmapped:
					bitset = bitsets[read.tid][read.is_reverse != reverse]
					size = bitset.size
					for (match_start, match_length) in cigar_matches(read):
						# Keep the range inside the bitset: the block must END
						# within it, so clamp the length rather than comparing
						# the length alone against the size.
						if match_start >= size: continue
						match_length = min(match_length, size - match_start)
						if match_length > 0:
							bitset.set_range(match_start, match_length)

				read_idx += 1
				if read_idx % 10000000 == 0:
					with print_lock:
						print("{0:12d} reads in {1:.2f}s ({2})".format(read_idx, time.perf_counter() - time_counter, bam_filename), flush=True)

			with print_lock:
				print("{0:12d} reads in {1:.2f}s ({2})".format(read_idx, time.perf_counter() - time_counter, bam_filename), flush=True)
			result = (chr_names, chr_lengths, bitsets, bam_filename, read_idx)
			# Write to a temporary name first so an interrupted run never
			# leaves a truncated cache file that looks up to date.
			temporary_filename = pickle_filename + '.tmp'
			with gzip.open(temporary_filename, 'wb') as f: pickle.dump(result, f)
			os.replace(temporary_filename, pickle_filename)
			return (bam_filename, pickle_filename)

def get_pooled_coverage(bam_filenames, executor, reverse):
	"""Pool the covered bases of several BAM files.

	bam_filenames -- paths of the BAM files
	executor      -- concurrent.futures executor used to run get_bam_coverage()
	reverse       -- forwarded to get_bam_coverage()

	A cache file (`<bam>.coverageData.pickled.gz`) that exists and is newer than
	its BAM file is loaded directly; the remaining BAM files are scanned through
	the executor, which writes their cache files, and then loaded the same way.
	The bitsets of all files are OR-ed together.

	Returns (chr_names, chr_lengths, intervals, read_counts) where `intervals`
	is bitsets2intervals() of the pooled bitsets and `read_counts` maps each
	entry of `bam_filenames` to its number of reads.
	Raises Exception when the BAM files disagree on chromosome names/lengths
	and ValueError when no BAM file was given.
	"""
	chr_names = None
	chr_lengths = None
	bitsets = None
	read_counts = dict()

	bam_filenames_to_measure = []
	pickleSuffix = '.coverageData.pickled.gz'

	def processPickle(bam_filename, pickle_filename, chr_names, chr_lengths, bitsets, read_counts):
		# NOTE: unpickling can execute arbitrary code; only cache files written
		# by get_bam_coverage() should ever be found under these names.
		with gzip.open(pickle_filename, 'rb') as f:
			# The path stored in the cache may differ from the one used in this
			# run (e.g. relative vs. absolute), so it must not replace the key.
			(bam_chr_names, bam_chr_lengths, bam_bitsets, _cached_bam_filename, read_count) = pickle.load(f)
		if not bitsets:
			chr_names = bam_chr_names
			chr_lengths = bam_chr_lengths
			bitsets = bam_bitsets
		else:
			if chr_names != bam_chr_names or chr_lengths != bam_chr_lengths:
				raise Exception('Bam files are based on different chromosome sets')
			for (pooled, added) in zip(bitsets, bam_bitsets):
				pooled[0].ior(added[0])
				pooled[1].ior(added[1])
		read_counts[bam_filename] = read_count
		return (chr_names, chr_lengths, bitsets, read_counts)

	with ElapsedTimeReporter('Looking for pre-calculated coverage files...', False, 'Pre-calculated coverage loading time: {:.0f} s', None):
		for bam_filename in bam_filenames:
			pickle_filename = bam_filename + pickleSuffix
			if os.path.exists(pickle_filename) and os.path.getmtime(bam_filename) < os.path.getmtime(pickle_filename):
				with ElapsedTimeReporter('Loading {} '.format(pickle_filename), failed_message='failed!'):
					chr_names, chr_lengths, bitsets, read_counts = processPickle(bam_filename, pickle_filename, chr_names, chr_lengths, bitsets, read_counts)
			else:
				bam_filenames_to_measure.append(bam_filename)

	with ElapsedTimeReporter('Calculating pooled coverage...', False, 'Total coverage calculation time: {:.0f} s', 'Failed!'):
		futures = [ executor.submit(get_bam_coverage, bam_filename, reverse, pickleSuffix) for bam_filename in bam_filenames_to_measure ]
		for future in concurrent.futures.as_completed(futures):
			bam_filename, pickle_filename = future.result()
			with ElapsedTimeReporter('Integrating data from {}...'.format(pickle_filename)):
				chr_names, chr_lengths, bitsets, read_counts = processPickle(bam_filename, pickle_filename, chr_names, chr_lengths, bitsets, read_counts)

	if bitsets is None:
		raise ValueError('No BAM files to calculate coverage from')

	return (chr_names, chr_lengths, bitsets2intervals(bitsets), read_counts)

def filter_gene_by_bitset(gene, bitset, min_exon_coverage):
	"""Reduce a gene to the covered parts of its sufficiently covered exons.

	gene              -- Gene whose exons are examined
	bitset            -- coverage bitset for the gene's chromosome strand
	min_exon_coverage -- minimum fraction of an exon's length that must be set
	                     in `bitset` for the exon to be kept

	The covered runs of all kept exons are merged (overlapping exons collapse)
	and become the exons of a new Gene, with transcript id None.

	Returns the new Gene, or None when no covered base is kept. An exon with
	no covered base is never kept, even when `min_exon_coverage` is 0.
	"""
	gene_start = gene.start()
	gene_bits = BitSet(gene.end()-gene_start)

	for exon in gene.exons:
		covered_runs = list(interval_iter(bitset, exon.start, exon.end))
		covered_bases = sum(end - start for (start, end) in covered_runs)
		# `covered_runs` must be non-empty: with a threshold of 0 the length
		# comparison alone would accept exons without a single covered base.
		if covered_runs and covered_bases >= len(exon) * min_exon_coverage:
			for (start, end) in covered_runs:
				gene_bits.set_range(start-gene_start, end-start)

	merged_runs = list(interval_iter(gene_bits))
	if not merged_runs:
		return None

	new_gene = Gene(gene.chromosome, gene.strand, gene.gene_id, gene.gene_name)
	for (start, end) in merged_runs:
		new_gene.add_exon(None, start+gene_start, end+gene_start)
	return new_gene

def filter_genes_by_coverage_intervals(genes, coverage_intervals, min_exon_coverage, coverage_intervals_from_opposite_strand = None):
	"""Filter the genes of one chromosome strand by coverage intervals.

	genes              -- Gene objects; the first one supplies the chromosome
	                      and strand shown in the progress messages
	coverage_intervals -- sorted (start, end) intervals of covered bases
	min_exon_coverage  -- forwarded to filter_gene_by_bitset()
	coverage_intervals_from_opposite_strand
	                   -- optional second set of intervals that is merged with
	                      the first before filtering

	Returns the list of filtered genes that still have exons; an empty `genes`
	list gives an empty result.
	"""
	if not genes:
		return []

	interval_sets = [ intervals for intervals in (coverage_intervals, coverage_intervals_from_opposite_strand) if intervals ]

	# The bitset must reach the last covered base and the end of the last gene.
	# The intervals are sorted, so the last one has the largest end.
	size = max([0, max(gene.end() for gene in genes)] + [ intervals[-1][1] for intervals in interval_sets ])
	bitset = BitSet(size)
	for intervals in interval_sets:
		for (s, e) in intervals: bitset.set_range(s, e-s)

	chromosome = genes[0].chromosome
	strand_symbol = '+-'[genes[0].strand]

	new_genes = []
	idx = 0
	for (idx, gene) in enumerate(genes, 1):
		new_gene = filter_gene_by_bitset(gene, bitset, min_exon_coverage)
		# Truthiness on purpose: drops None as well as a gene without exons
		# (len(gene) == 0).
		if new_gene:
			new_genes.append(new_gene)
		if idx % 100 == 0:
			with print_lock:
				print('Chromosome {0}({1}): {2:5d} genes processsed ({3:.0%})'.format(chromosome, strand_symbol, idx, idx/len(genes)), flush=True)
	with print_lock:
		print('Chromosome {0}({1}): {2:5d} genes processsed (100%)'.format(chromosome, strand_symbol, idx), flush=True)

	return new_genes

def filter_genes(genes, chr_names, genome_coverage_intervals, executor, min_exon_coverage, merge_both_strand_intervals):
	"""Filter all genes by coverage, one task per chromosome strand.

	genes                       -- Gene objects; those on chromosomes missing
	                               from `chr_names` are ignored
	chr_names                   -- chromosome names, index-aligned with
	                               `genome_coverage_intervals`
	genome_coverage_intervals   -- `[chromosome][strand]` -> (start, end) tuples
	executor                    -- concurrent.futures executor that runs
	                               filter_genes_by_coverage_intervals()
	min_exon_coverage           -- forwarded to the per-strand filter
	merge_both_strand_intervals -- when true the intervals of the opposite
	                               strand are passed along as well

	Returns the filtered genes in task completion order.
	"""
	split_genes = tuple(([],[]) for _ in chr_names)
	new_genes = []

	with ElapsedTimeReporter('Sorting genes... '):
		# One dictionary lookup per gene instead of two linear scans of
		# chr_names; setdefault keeps the first index, like list.index().
		chr_index = {}
		for (idx, chr_name) in enumerate(chr_names):
			chr_index.setdefault(chr_name, idx)

		for gene in genes:
			idx = chr_index.get(gene.chromosome)
			if idx is not None:
				split_genes[idx][gene.strand].append(gene)

	with ElapsedTimeReporter('Processing genes on individual strands... ', False, 'Total elapsed time: {0:.0f} s', 'Failed!'):
		futures = []
		for i in range(len(chr_names)):
			for strand in (0, 1):
				g = split_genes[i][strand]
				ci = genome_coverage_intervals[i][strand]
				ci_rev = None
				if merge_both_strand_intervals:
					ci_rev = genome_coverage_intervals[i][1 - strand]
				# Nothing to do without genes or without any coverage.
				if g and (ci or ci_rev):
					futures.append(executor.submit(filter_genes_by_coverage_intervals, g, ci, min_exon_coverage, ci_rev))

		for future in concurrent.futures.as_completed(futures):
			new_genes.extend(future.result())

	return new_genes

def get_counts_from_bam(bam_filename, genes, chr_names, reverse, ignore_strand, expected_read_count):
	"""Count the reads of one BAM file per gene.

	bam_filename        -- path of the BAM file
	genes               -- Gene objects; counts are reported in this order
	chr_names           -- chromosome names, indexed by the reads' `tid`
	reverse             -- a read is looked up on strand index
	                       `read.is_reverse != reverse`
	ignore_strand       -- when true the opposite strand is searched as well
	expected_read_count -- total used for the progress percentage only

	A mapped read is counted for a gene when all exons hit by its aligned
	blocks belong to that single gene, is 'ambiguous' when several genes are
	hit and 'skipped' when none is.

	Returns (counts, (counted, skipped, ambiguous, unmapped)).
	"""
	def fraction(part, whole):
		# An empty BAM file or an unknown expected total must not raise
		# ZeroDivisionError in the progress and summary messages.
		return part / whole if whole else 0.0

	# setdefault keeps the first index of a name, like list.index().
	chr_index = {}
	for (idx, chr_name) in enumerate(chr_names):
		chr_index.setdefault(chr_name, idx)

	trees = [ (IntervalTree(), IntervalTree()) for _ in chr_names ]
	for (gene_idx, gene) in enumerate(genes):
		tree = trees[chr_index[gene.chromosome]][gene.strand]
		for exon in gene.exons:
			tree.insert(exon.start, exon.end, gene_idx)

	with pysam.Samfile(bam_filename, "rb") as bam:
		with ElapsedTimeReporter('Calculating read counts from {}'.format(bam_filename), False, 'Time spent on analyzing {}: {{:.0f}} s'.format(bam_filename), 'Failed!') as time_counter:
			counts = [0] * len(genes)
			counted = 0
			skipped = 0
			ambiguous = 0
			unmapped = 0
			total = 0
			for aligned_read in bam.fetch(until_eof = True):
				if aligned_read.is_unmapped:
					unmapped += 1
				else:
					own_strand = aligned_read.is_reverse != reverse
					strands = (own_strand, not own_strand) if ignore_strand else (own_strand,)
					# Parse the CIGAR once, even when both strands are searched.
					matches = list(cigar_matches(aligned_read))
					gene_idxs = [ gene_idx
						for strand in strands
						for (match_start, match_len) in matches
						for gene_idx in trees[aligned_read.tid][strand].find(match_start, match_start+match_len) ]
					if not gene_idxs:
						skipped += 1
					elif all(idx == gene_idxs[0] for idx in gene_idxs):
						counts[gene_idxs[0]] += 1
						counted += 1
					else:
						ambiguous += 1

				total += 1
				if total % 10000000 == 0:
					with print_lock:
						print("{0:12d} reads ({1:.0%}) in {2:.2f}s ({3})".format(total, fraction(total, expected_read_count), time.perf_counter() - time_counter, bam_filename), flush=True)
			with print_lock:
				print("{0:12d} reads (100%) in {1:.2f}s ({2})".format(total, time.perf_counter() - time_counter, bam_filename), flush=True)
				print('Final stats from {}:'.format(bam_filename), flush=True)
				print('{0:12d} counted reads ({1:.2%})'.format(counted, fraction(counted, total)), flush=True)
				print('{0:12d} skipped reads ({1:.2%})'.format(skipped, fraction(skipped, total)), flush=True)
				print('{0:12d} ambiguous reads ({1:.2%})'.format(ambiguous, fraction(ambiguous, total)), flush=True)
				print('{0:12d} unmapped reads ({1:.2%})'.format(unmapped, fraction(unmapped, total)), flush=True)

			return (counts, (counted, skipped, ambiguous, unmapped))


def get_counts(executor, bam_filenames, genes, chr_names, reverse, ignore_strand, expected_read_counts):
	"""Run get_counts_from_bam() for every BAM file through `executor`.

	`expected_read_counts` maps each BAM filename to its expected number of
	reads. Returns (counts, stats): two lists holding, per BAM file and in the
	order of `bam_filenames`, the per-gene counts and the statistics tuple.
	"""
	with ElapsedTimeReporter('Calculating gene coverage...', False, 'Total elapsed time: {0:.0f} s', 'Failed!'):
		pending = []
		for bam_filename in bam_filenames:
			job = executor.submit(get_counts_from_bam, bam_filename, genes, chr_names, reverse, ignore_strand, expected_read_counts[bam_filename])
			pending.append(job)

		# Collect in submission order so the results line up with bam_filenames.
		counts = []
		stats = []
		for job in pending:
			bam_counts, bam_stats = job.result()
			counts.append(bam_counts)
			stats.append(bam_stats)

		return (counts, stats)

def merge_results_by_sample(samples, counts_per_bam, stats_per_bam):
	"""Sum the per-BAM counts and statistics of BAM files sharing a sample name.

	The three arguments are parallel sequences with one entry per BAM file.
	Returns (sample_names, stats_per_sample, counts_per_sample); the sample
	names are sorted and the other two lists hold one tuple of element-wise
	sums per sample, in that order.
	"""
	with ElapsedTimeReporter('Merging results by sample... '):
		sample_names = sorted(set(samples))

		grouped_counts = { name: [] for name in sample_names }
		grouped_stats = { name: [] for name in sample_names }
		for (sample, bam_counts, bam_stats) in zip(samples, counts_per_bam, stats_per_bam):
			grouped_counts[sample].append(bam_counts)
			grouped_stats[sample].append(bam_stats)

		def column_sums(rows):
			# Element-wise sum over equally long rows.
			return tuple(sum(column) for column in zip(*rows))

		stats_per_sample = [ column_sums(grouped_stats[name]) for name in sample_names ]
		counts_per_sample = [ column_sums(grouped_counts[name]) for name in sample_names ]

		return (sample_names, stats_per_sample, counts_per_sample)

def fpkem(bam_filenames, samples, gtf_filename, genome_coverage_bed_filename, filtered_genes_bed_filename, counts_filename, fpkem_filename, stats_filename, reverse, ignore_strand, min_exon_coverage, np):
	"""Run the whole pipeline and write the result files.

	bam_filenames / samples      -- parallel lists: BAM path and sample name
	gtf_filename                 -- gene annotation file
	genome_coverage_bed_filename -- accepted for interface compatibility; this
	                                function does not write it
	filtered_genes_bed_filename  -- output: filtered genes, one BED line each
	counts_filename              -- output: read counts per gene and sample
	fpkem_filename               -- output: counts divided by
	                                (counted reads of the sample * gene length / 1e9)
	stats_filename               -- output: counted/skipped/ambiguous/unmapped
	                                reads per sample
	reverse, ignore_strand       -- strand handling, see get_counts_from_bam()
	min_exon_coverage            -- see filter_gene_by_bitset()
	np                           -- number of worker processes
	"""
	with concurrent.futures.ProcessPoolExecutor(np) as executor, open(filtered_genes_bed_filename, 'wt') as filtered_genes_bed, open(counts_filename, 'wt') as counts_file, open(fpkem_filename, 'wt') as fpkem_file, open(stats_filename, 'wt') as stats_file:

		(chr_names, chr_lengths, intervals, read_counts) = get_pooled_coverage(bam_filenames, executor, reverse)

		with ElapsedTimeReporter('Loading genes... '):
			genes = load_genes(gtf_filename)

		filtered_genes = filter_genes(genes, chr_names, intervals, executor, min_exon_coverage, merge_both_strand_intervals = ignore_strand)
		(counts_per_bam, stats_per_bam) = get_counts(executor, bam_filenames, filtered_genes, chr_names, reverse, ignore_strand, read_counts)

		(sample_names, stats_per_sample, counts_per_sample) = merge_results_by_sample(samples, counts_per_bam, stats_per_bam)

		with ElapsedTimeReporter('Writing output files... '):
			print('Sample\tCounted\tSkipped\tAmbiguous\tUnmapped', file=stats_file)
			for (sample_name, stats) in zip(sample_names, stats_per_sample):
				print('{0}\t{1}'.format(sample_name, '\t'.join(map(str, stats))), file=stats_file)

			counted_reads_per_sample = tuple(map(operator.itemgetter(0), stats_per_sample))
			header = '\t'.join(['Gene ID', 'Gene Name'] + sample_names)
			print(header, file=counts_file)
			print(header, file=fpkem_file)
			for gene, *counts in zip(filtered_genes, *counts_per_sample):
				exons = gene.exons
				gene_length = sum(map(len, exons))
				# A sample without any counted read (or a zero-length gene)
				# would divide by zero; its count is 0 as well, so report 0.0.
				normalized = [ c/(t*gene_length/1e9) if t and gene_length else 0.0 for (c, t) in zip(counts, counted_reads_per_sample) ]
				print('\t'.join([gene.gene_id, gene.gene_name] + list(map(str, counts))), file=counts_file)
				print('\t'.join([gene.gene_id, gene.gene_name] + list(map(str, normalized))), file=fpkem_file)
				print('\t'.join([
					gene.chromosome,
					str(exons[0].start),
					str(exons[-1].end),
					gene.gene_id,
					str(gene_length),
					'+-'[gene.strand],
					str(exons[0].start),
					str(exons[-1].end),
					'0',
					str(len(exons)),
					','.join(str(len(exon)) for exon in exons),
					','.join(str(exon.start-exons[0].start) for exon in exons)]), file=filtered_genes_bed)



def test():
	"""Ad-hoc developer run against hard-coded local data paths.

	Reads sample names and file names from a tab-separated sample sheet (the
	first line is read and discarded, presumably a header) and runs fpkem()
	with the same settings the command line uses by default: first-strand
	library, minimum exon coverage 0.8, one process.
	"""
	files = []
	samples = []
	with open('/data/maciejp/Aylin/RNASeq_VariantCalling/samples.txt', 'rt') as infile:
		next(infile, None)
		for line in infile:
			fields = line.rstrip().split('\t')
			samples.append(fields[0])
			files.append('/data/maciejp/Aylin/RNASeq_VariantCalling/' + fields[1].replace('.fastq.gz', '.STAR.bam'))

	# fpkem() takes twelve arguments; the strand, coverage and process settings
	# are passed by keyword so they cannot be mixed up.
	fpkem(files, samples, '/data/db/Homo_sapiens/Ensembl/GRCh37.73/Homo_sapiens.GRCh37.73.selected.gtf',
		'genome_coverage.bed', 'filtered_genes.bed', 'counts.txt', 'fpkem.txt', 'stats.txt',
		reverse = True, ignore_strand = False, min_exon_coverage = 0.8, np = 1)

def prepare_argparser():
	"""Build and return the argparse parser of the command-line interface.

	Options: output prefix, library type, minimum exon coverage and number of
	processes. Positional arguments: the GTF file, a comma-separated list of
	BAM files and a comma-separated list of sample names (both parsed to lists).
	"""
	def comma_separated(value):
		return value.split(',')

	parser = argparse.ArgumentParser(description = "Estimation of gene expression based on RNA-Seq data")

	# Optional arguments.
	parser.add_argument(
		'-o', "--output-prefix",
		dest = "output_prefix",
		default = 'FPKEMOutput_',
		metavar = "PREFIX",
		help = 'Prefix for generated output files. Default: FPKEMOutput_')
	parser.add_argument(
		'--library-type',
		dest = "library_type",
		choices = ('first-strand', 'second-strand', 'non-stranded'),
		default = 'first-strand',
		metavar = 'LIBRARY_TYPE',
		help = 'Library type ("first-strand", "second-strand", "non-stranded"). Default: first-strand"')
	parser.add_argument(
		'-mc', '--minimum-coverage',
		dest = 'minimum_coverage',
		type = float,
		default = 0.8,
		metavar = "FRACTION",
		help = "Minimum exon coverage. Default: 0.8")
	parser.add_argument(
		'-p', '--processes',
		dest = "number_of_processes",
		type = int,
		default = 1,
		metavar = "CPU_COUNT",
		help = "Maximum number of processes that can be used. Default: 1")

	# Positional arguments, in command-line order.
	parser.add_argument(
		"gtf_filename",
		metavar = "GTF_FILE",
		help = "GTF file with gene annotations (must contain gene_id and gene_name fields).")
	parser.add_argument(
		"bam_filenames",
		type = comma_separated,
		metavar = "BAM_FILES",
		help = "A comma-separated list of BAM files")
	parser.add_argument(
		"samples",
		type = comma_separated,
		metavar = "SAMPLE_NAMES",
		help = "A comma-separated list of sample names (one for each BAM file).")

	return parser

def main():
	"""Command-line entry point: parse and validate arguments, then run fpkem().

	Every validation failure is reported through `argparser.error()`.
	"""
	argparser = prepare_argparser()
	args = argparser.parse_args()

	bam_filenames = args.bam_filenames
	if len(bam_filenames) != len(args.samples):
		argparser.error("The number of sample names differ from the number of BAM files.")
	for filename in bam_filenames:
		if os.path.exists(filename):
			continue
		argparser.error("BAM file {} not found.".format(filename))
	if not os.path.exists(args.gtf_filename):
		argparser.error("GTF file {} not found.".format(args.gtf_filename))
	if args.minimum_coverage < 0 or args.minimum_coverage > 1:
		argparser.error("The minimum coverage parameter is outside the [0, 1] range.")
	if args.number_of_processes < 1:
		argparser.error("The number of processes cannot be less than 1.")

	# The library type determines both strand-handling flags.
	library_type = args.library_type
	reverse = library_type == 'first-strand'
	ignore_strand = library_type == 'non-stranded'

	prefix = args.output_prefix
	genome_coverage_bed = prefix + "GenomeCoverage.bed"
	filtered_genes_bed = prefix + "FilteredGenes.bed"
	counts_txt = prefix + "Counts.txt"
	fpkem_txt = prefix + "FPKEM.txt"
	stats_txt = prefix + "Stats.txt"

	fpkem(bam_filenames, args.samples, args.gtf_filename,
		genome_coverage_bed, filtered_genes_bed, counts_txt, fpkem_txt, stats_txt,
		reverse, ignore_strand,
		args.minimum_coverage, args.number_of_processes)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
	main()
