#!/usr/bin/env python

import optparse, sys, os, re, time, pprint

# Command-line interface definition.
# NOTE: after parse_args() at the bottom of this block, the name 'p' no longer
# refers to the parser but to the parsed option values; the rest of the script
# reads options as p.<dest>. Positional arguments ('a') are ignored.
p = optparse.OptionParser(description = """Part of the MetagenomeDB toolkit.
Imports XML-formatted NCBI BLAST alignments into the database.""")

g = optparse.OptionGroup(p, "Input")

g.add_option("-i", "--input", dest = "input_fn", metavar = "FILENAME",
	help = "XML-formatted output of a NCBI BLAST sequence alignment (mandatory).")

g.add_option("-Q", "--query-collection", dest = "queries_collection", metavar = "STRING",
	help = "Name of the collection the query sequences belong to (mandatory).")

g.add_option("-H", "--hit-collection", dest = "hits_collection", metavar = "STRING",
	help = """Name of the collection the hit sequences belong to (optional). If not
provided, the hit sequences are assumed to be external to the database, and only
a summary of those hits will be stored: hit identifier, description and E-value.""")

g.add_option("--date", dest = "date", nargs = 3, type = "int", metavar = "YEAR MONTH DAY",
	help = "Date of the BLAST run (optional). By default, creation date of the input file.")

g.add_option("--query-id-getter", dest = "query_id_getter", metavar = "PYTHON CODE", default = "%.split()[0]",
	help = "Python code to reformat query identifiers (optional); '%' will be replaced by the query identifier. Default: %default")

g.add_option("--hit-id-getter", dest = "hit_id_getter", metavar = "PYTHON CODE", default = "%.split()[0]",
	help = "Python code to reformat hit identifiers (optional); '%' will be replaced by the hit identifier. Default: %default")

g.add_option("--no-check", dest = "check", action = "store_false", default = True,
	help = "If set, bypass the query and hit sequences identifier check (not recommended).")

p.add_option_group(g)

g = optparse.OptionGroup(p, "Input filtering")

g.add_option("--max-E-value", dest = "max_e_value", type = "float", metavar = "FLOAT",
	help = "If set, filter out all hits with a E-value above the provided cut-off.")

g.add_option("--min-identity", dest = "min_identity", type = "int", metavar = "INTEGER",
	help = "If set, filter out all hits with a percent of identity below the provided cut-off.")

g.add_option("--max-hits", dest = "max_hits", type = "int", metavar = "INTEGER",
	help = "If set, keep only the first '--max-hits' hits for each query.")

g.add_option("--ignore-alignment", dest = "include_alignment", action = "store_false", default = True,
	help = "If set, will not store information about the sequence alignment (HSP coordinates and sequences).")

p.add_option_group(g)

p.add_option("-v", "--verbose", dest = "verbose", action = "store_true", default = False,
	help = "If set, raise the verbosity of the MetagenomeDB library.")

p.add_option("--dry-run", dest = "dry_run", action = "store_true", default = False,
	help = "If set, print the HSPs that would be imported instead of writing them to the database.")

g = optparse.OptionGroup(p, "Connection")

g.add_option("--host", dest = "connection_host", metavar = "HOSTNAME", default = "localhost",
	help = "Host name or IP address of the MongoDB server (optional). Default: %default")

# type="int" so that a port given on the command line reaches the connection
# call as an integer, exactly like the default value does (optparse otherwise
# stores option values as strings)
g.add_option("--port", dest = "connection_port", type = "int", metavar = "INTEGER", default = 27017,
	help = "Port of the MongoDB server (optional). Default: %default")

g.add_option("--db", dest = "connection_db", metavar = "STRING", default = "MetagenomeDB",
	help = "Name of the database in the MongoDB server (optional). Default: '%default'")

g.add_option("--user", dest = "connection_user", metavar = "STRING", default = '',
	help = "User for the MongoDB server connection (optional). Default: '%default'")

g.add_option("--password", dest = "connection_password", metavar = "STRING", default = '',
	help = "Password for the MongoDB server connection (optional). Default: '%default'")

p.add_option_group(g)

(p, a) = p.parse_args()

def error (msg):
	# Report a fatal error on stderr as "ERROR: <msg>." (one trailing period,
	# never two) and terminate the script with exit status 1.
	text = str(msg)
	if text[-1:] == '.':
		msg = text[:-1]
	sys.stderr.write("ERROR: %s.\n" % msg)
	sys.exit(1)

# Validate mandatory options and option values before touching the database.
if (p.input_fn is None):
	error("A XML-formatted BLAST alignment output file must be provided")

if (not os.path.exists(p.input_fn)):
	error("File '%s' not found" % p.input_fn)

if (p.queries_collection is None):
	error("A collection must be provided for query sequences")

if (not p.date):
	# default to the (local time) modification date of the input file
	date = time.localtime(os.path.getmtime(p.input_fn))
	p.date = (date.tm_year, date.tm_mon, date.tm_mday)

else:
	# explicit checks instead of 'assert': assertions are stripped when Python
	# runs with -O, which would silently accept any date
	y, m, d = p.date
	if (y <= 1990):
		error("Invalid date: value '%s' is incorrect for year" % y)
	if not (0 < m < 13):
		error("Invalid date: value '%s' is incorrect for month" % m)
	if not (0 < d < 32):
		error("Invalid date: value '%s' is incorrect for day" % d)

# Turn the identifier getters into functions. Every '%' in the user-supplied
# code is replaced by the lambda argument, so the getter code cannot itself use
# the '%' operator.
# NOTE(review): this eval()s code taken from the command line; acceptable only
# because the person running the script supplies it -- never feed it external data.
try:
	get_query_id = eval("lambda x: " + p.query_id_getter.replace('%', 'x'))
	get_hit_id = eval("lambda x: " + p.hit_id_getter.replace('%', 'x'))

except SyntaxError as e:
	# the printed line is prefixed by "ERROR: Invalid getter: " (23 characters)
	# and e.offset is 1-based, hence offset + 22 spaces to put the caret under
	# the offending character; e.offset may be None, in which case point at the start
	error("Invalid getter: %s\n%s^" % (e.text, ' ' * ((e.offset or 1) + 22)))

if (p.max_e_value is not None) and (p.max_e_value < 0):
	error("Invalid E-value cut-off: %s" % p.max_e_value)

if (p.min_identity is not None) and not (0 <= p.min_identity <= 100):
	error("Invalid percent of identity cut-off: %s" % p.min_identity)

if (p.max_hits is not None) and (p.max_hits < 0):
	error("Invalid number of hits cut-off: %s" % p.max_hits)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

import MetagenomeDB as mdb

# BioPython is needed only for its NCBI BLAST XML parser. Catch ImportError
# alone, so that interrupts and unrelated failures are not misreported as a
# missing library.
try:
	from Bio.Blast import NCBIXML
except ImportError:
	error("The BioPython library is not installed.\nTry 'easy_install biopython'")

# Optionally make MetagenomeDB verbose, then open the database connection with
# the command-line connection settings; any connection failure is fatal.
if p.verbose:
	mdb.max_verbosity()

connection_args = (p.connection_host, p.connection_port, p.connection_db,
	p.connection_user, p.connection_password)

if any(connection_args):
	try:
		mdb.connect(*connection_args)
	except Exception as msg:
		error(msg)

print "Importing '%s' ..." % p.input_fn

# Check query and hit sequences
print "  validating query and hit sequences ..."

QuerySequences, HitSequences = {}, {}
DuplicateQueries, DuplicateHits = {}, {}

try:
	queries = mdb.Collection.find_one({"name": p.queries_collection})
	if (queries == None):
		error("Unknown collection '%s'" % p.queries_collection)

	for sequence in queries.list_sequences():
		sequence_name = str(sequence["name"]) 
		if (sequence_name in QuerySequences):
			DuplicateQueries[sequence_name] = True

		QuerySequences[sequence_name] = sequence

	if (p.hits_collection):
		hits = mdb.Collection.find_one({"name": p.hits_collection})
		if (hits == None):
			error("Unknown collection '%s'" % p.hits_collection)

		for sequence in hits.list_sequences():
			sequence_name = str(sequence["name"]) 
			if (sequence_name in HitSequences):
				DuplicateHits[sequence_name] = True

			HitSequences[sequence_name] = sequence

except Exception as msg:
	error(msg)

if (p.check):
	n_records = 0
	for record in NCBIXML.parse(open(p.input_fn, 'r')):
		query_id = get_query_id(record.query)

		if (not query_id in QuerySequences):
			error("Unknown query sequence '%s'" % query_id)

		if (query_id in DuplicateQueries):
			error("Duplicate query sequence '%s'" % query_id)

		if (p.hits_collection):
			for hit in record.alignments:
				hit_id = get_hit_id(hit.title)

				if (not hit_id in HitSequences):
					error("Unknown hit sequence '%s'" % hit_id)

				if (hit_id in DuplicateHits):
					error("Duplicate hit sequence '%s'" % hit_id)

		n_records += 1

# Import: second pass over the BLAST output. This pass stores every HSP that
# passes the filtering options, or only prints it when --dry-run is set.
print "  importing HSPs ..."

class ProgressBar:
	"""Single-line textual progress bar written to stdout.

	The bar is indented by 4 spaces, is 80 dots wide at 100%, and is followed
	by the percentage; each line ends with a carriage return so successive
	calls overwrite each other.
	"""
	def __init__ (self, upper = None):
		# upper: the value that corresponds to 100% (the lower bound is 0).
		# NOTE(review): the None default is not usable (None + 0.0 raises TypeError)
		self.__min = 0.0
		self.__max = upper + 0.0

	def display (self, value):
		"""Redraw the bar for 'value', clamped to the [0, upper] range."""
		span = self.__max - self.__min
		if (span > 0):
			f = (value - self.__min) / span # fraction
		else:
			f = 1.0 # empty range: there is nothing left to do
		# clamp so the bar never exceeds its 80-column budget
		f = min(max(f, 0.0), 1.0)
		p = 100 * f # percentage
		s = int(round(80 * f)) # bar size

		sys.stdout.write(' ' * 4 + ('.' * s) + " %4.2f%%\r" % p)
		sys.stdout.flush()

	def clear (self):
		"""Blank out the bar (indent + 80 dots + " 100.00%" = 92 columns)."""
		sys.stdout.write(' ' * (4 + 80 + 8) + "\r")
		sys.stdout.flush()

show_pb = p.check and (not p.dry_run)

if (show_pb):
	pb = ProgressBar(n_records)

external_hits = (p.hits_collection == None)

n = 0
for record in NCBIXML.parse(open(p.input_fn, 'r')):
	query_id = get_query_id(record.query)
	query_o = QuerySequences[query_id]

	if (external_hits):
		hits = query_o.get_property("alignments", [])

	m = 0
	for hit in record.alignments:
		hit_id = get_hit_id(hit.title)
		if (not external_hits):
			hit_o = HitSequences[hit_id]

		m += 1
		if (p.max_hits) and (m > p.max_hits):
			break

		for hsp in hit.hsps:
			identity = 100.0 * hsp.identities / hsp.align_length

			if (p.min_identity) and (identity < p.min_identity):
				continue

			if (p.max_e_value) and (hsp.expect > p.max_e_value):
				continue

			# documentation:
			# - ftp://ftp.ncbi.nlm.nih.gov/blast/documents/xml/README.blxml for information about the NCBI BLAST XML format
			# - http://www.biopython.org/DIST/docs/api/Bio.Blast.NCBIXML-pysrc.html for information about how the XML is parsed by BioPython
			# - http://www.biopython.org/DIST/docs/api/Bio.Blast.Record-pysrc.html for information about how the result is stored as a Record
			r = {
				"type": "similar-to",

				"run": {
					"date": {"year": p.date[0], "month": p.date[1], "day": p.date[2]},
					"algorithm": {
						"name": record.application,
						"version": record.version,
						"parameters": {
							"expect": float(record.expect),
							"matrix": record.matrix,
							"gap_open": record.gap_penalties[0],
							"gap_extend": record.gap_penalties[1],
							"sc_match": record.sc_match,
							"sc_mismatch": record.sc_mismatch,
							"filter": record.filter
						},
					},
					"database": {
						"name": record.database,
						"number_of_sequences": record.database_sequences,
						"number_of_letters": record.num_letters_in_database,
					}
				},

				"score": {
					"fraction_identical": identity,
					"fraction_conserved": 100.0 * hsp.positives / hsp.align_length,
					"e_value": hsp.expect,
					"gaps": hsp.gaps,
				}
			}

			if (p.include_alignment):
				r["alignment"] = {
					"source_coordinates": (hsp.query_start, hsp.query_end),
					"source": hsp.query,
					"match": hsp.match,
					"target": hsp.sbjct,
					"target_coordinates": (hsp.sbjct_start, hsp.sbjct_end),
				}

			# the hit should be in the database. In this case, we store the HSP
			# as properties of a relationship between query and hit sequences.
			if (not external_hits):
				if (p.dry_run):
					print "    query '%s' to hit '%s'" % (query_id, hit_id)
					for line in pprint.pformat(r).split('\n'):
						print "      %s" % line
				else:
					query_o.relate_to_sequence(hit_o, r)
					query_o.commit()

			# the hit is not in the database. In this case, we store the HSP as
			# a property of the query sequence.
			else:
				r["hit" ] = {
					"name": hit_id,
					"description": hit.hit_def,
					"length": hit.length
				}

				if (p.dry_run):
					print "    query '%s' to external hit '%s'" % (query_id, hit_id)
					for line in pprint.pformat(r).split('\n'):
						print "      %s" % line
				else:
					hits.append(r)

	if (external_hits) and (not p.dry_run):
		query_o["alignments"] = hits
		query_o.commit()

	if (show_pb):
		pb.display(n)

	n += 1

if (show_pb):
	pb.clear()

print "    done."

if (p.dry_run):
	print "(dry run)"
