#!/usr/bin/env python

import optparse, sys, os, re, time, pprint
import MetagenomeDB as mdb

# Command-line interface. Options are declared group by group, in the order
# they are expected to be listed by --help.
p = optparse.OptionParser(
	description = """Part of the MetagenomeDB toolkit.
Imports XML-formatted NCBI BLAST alignments into the database.""")

# -- input: the BLAST output, and how its sequences map to the database
input_options = optparse.OptionGroup(p, "Input")

input_options.add_option("-i", "--input",
	dest = "input_fn",
	metavar = "FILENAME",
	help = "XML-formatted output of a NCBI BLAST sequence alignment (mandatory).")

input_options.add_option("-Q", "--query-collection",
	dest = "queries_collection",
	metavar = "STRING",
	help = "Name of the collection the query sequences belong to (mandatory).")

input_options.add_option("-H", "--hit-collection",
	dest = "hits_collection",
	metavar = "STRING",
	help = """Name of the collection the hit sequences belong to (optional). If not
provided, the hit sequences are assumed to be external to the database, and only
a summary of those hits will be stored: hit identifier, description and E-value.""")

# three integers on the command line, received as a (year, month, day) tuple
input_options.add_option("--date",
	dest = "date",
	type = "int",
	nargs = 3,
	metavar = "YEAR MONTH DAY",
	help = "Date of the BLAST run (optional). By default, creation date of the input file.")

# both getters are snippets of Python code in which '%' stands for the identifier
input_options.add_option("--query-id-getter",
	dest = "query_id_getter",
	default = "%.split()[0]",
	metavar = "PYTHON CODE",
	help = """Python code to reformat query identifiers (optional); '%' will be
replaced by the query identifier. Default: %default""")

input_options.add_option("--hit-id-getter",
	dest = "hit_id_getter",
	default = "%.split()[0]",
	metavar = "PYTHON CODE",
	help = """Python code to reformat hit identifiers (optional); '%' will be
replaced by the hit identifier. Default: %default""")

p.add_option_group(input_options)

# -- filtering: which hits and HSPs are kept, and how much of them is stored
filtering_options = optparse.OptionGroup(p, "Filtering")

filtering_options.add_option("--max-E-value",
	dest = "max_e_value",
	type = "float",
	metavar = "FLOAT",
	help = "If set, filter out all hits with a E-value above the provided cut-off.")

filtering_options.add_option("--min-identity",
	dest = "min_identity",
	type = "int",
	metavar = "INTEGER",
	help = "If set, filter out all hits with a percent of identity below the provided cut-off.")

filtering_options.add_option("--max-hits",
	dest = "max_hits",
	type = "int",
	metavar = "INTEGER",
	help = "If set, keep only the first '--max-hits' hits for each query.")

# note the inverted logic: the flag switches 'include_alignment' off
filtering_options.add_option("--ignore-alignment",
	dest = "include_alignment",
	action = "store_false",
	default = True,
	help = "If set, will not store HSP sequences and conservation lines.")

p.add_option_group(filtering_options)

# -- errors handling
errors_options = optparse.OptionGroup(p, "Errors handling")

errors_options.add_option("--ignore-large-entries",
	dest = "ignore_large_entries",
	action = "store_true",
	default = False,
	help = """If set, ignore cases where a large amount of hits being associated
to a given query would result in this query object to be too large for the database.""")

p.add_option_group(errors_options)

# -- general flags, attached to the parser itself rather than to a group
p.add_option("-v", "--verbose", action = "store_true", dest = "verbose", default = False)
p.add_option("--no-progress-bar", action = "store_false", dest = "display_progress_bar", default = True)
p.add_option("--dry-run", action = "store_true", dest = "dry_run", default = False)
p.add_option("--version", action = "store_true", dest = "display_version", default = False)

# -- connection to the MongoDB server; this group is left empty here
g = optparse.OptionGroup(p, "Connection")

# Connection settings provided on the command line, keyed by option name
# without its two leading dashes (e.g. 'host' for '--host').
connection_parameters = {}

def declare_connection_parameter (option, opt, value, parser):
	""" optparse callback recording the value of a connection option.

	'opt' is the option string reported by optparse (e.g. '--host'); its first
	two characters are dropped to obtain the key under which 'value' is stored
	in 'connection_parameters'. 'option' and 'parser' are not used.
	"""
	key = opt[2:]
	connection_parameters[key] = value

# The connection options only differ by their name, metavar and help text, so
# they are declared from a table. All of them are string-valued and handed to
# the declare_connection_parameter() callback.
for _name, _metavar, _help in (
	("host", "HOSTNAME", """Host name or IP address of the MongoDB server (optional). Default:
'host' property in ~/.MetagenomeDB, or 'localhost' if not found."""),

	("port", "INTEGER", """Port of the MongoDB server (optional). Default: 'port' property
in ~/.MetagenomeDB, or 27017 if not found."""),

	("db", "STRING", """Name of the database in the MongoDB server (optional). Default:
'db' property in ~/.MetagenomeDB, or 'MetagenomeDB' if not found."""),

	("user", "STRING", """User for the MongoDB server connection (optional). Default:
'user' property in ~/.MetagenomeDB, or none if not found."""),

	("password", "STRING", """Password for the MongoDB server connection (optional). Default:
'password' property in ~/.MetagenomeDB, or none if not found."""),
):
	g.add_option("--" + _name,
		dest = "connection_" + _name,
		metavar = _metavar,
		type = "string",
		action = "callback",
		callback = declare_connection_parameter,
		help = _help)

p.add_option_group(g)

# from this point 'p' refers to the parsed option values, no longer to the
# parser; 'a' receives the remaining positional arguments
(p, a) = p.parse_args()

def error (msg):
	""" Report a fatal error on the standard error stream, then exit.

	'msg' is converted to a string and stripped of one trailing period, if
	any, before being written as 'ERROR: <msg>.' followed by a newline. The
	script then stops with exit status 1 (sys.exit() raises SystemExit), so
	this function never returns.
	"""
	msg = str(msg)
	if msg.endswith('.'):
		msg = msg[:-1]

	# sys.stderr.write() instead of the 'print >>' statement: the latter only
	# exists in Python 2, while this form writes the same text everywhere
	sys.stderr.write("ERROR: %s.\n" % msg)
	sys.exit(1)

if (p.display_version):
	print mdb.version
	sys.exit(0)

if (p.input_fn == None):
	error("A XML-formatted BLAST alignment output file must be provided")

if (not os.path.exists(p.input_fn)):
	error("File '%s' not found" % p.input_fn)

if (p.queries_collection == None):
	error("A collection must be provided for query sequences")

if (not p.date):
	date = time.localtime(os.path.getmtime(p.input_fn))
	p.date = (date.tm_year, date.tm_mon, date.tm_mday)

else:
	try:
		y, m, d = p.date
		assert (y > 1990), "value '%s' is incorrect for year" % y
		assert (m > 0) and (m < 13), "value '%s' is incorrect for month" % m
		assert (d > 0) and (d < 32), "value '%s' is incorrect for day" % d

	except Exception, msg:
		error("Invalid date: %s" % msg)

try:
	get_query_id = eval("lambda x: " + p.query_id_getter.replace('%', 'x').replace("\\x", '%'))
	get_hit_id = eval("lambda x: " + p.hit_id_getter.replace('%', 'x').replace("\\x", '%'))

except SyntaxError as e:
	error("Invalid getter: %s\n%s^" % (e.text, ' ' * (e.offset + 22)))

if (p.max_e_value):
	if (p.max_e_value < 0):
		error("Invalid E-value cut-off: %s" % p.max_e_value)

if (p.min_identity):
	if (p.min_identity < 0) or (p.min_identity > 100):
		error("Invalid percent of identity cut-off: %s" % p.min_identity)

if (p.max_hits):
	if (p.max_hits < 0):
		error("Invalid number of hits cut-off: %s" % p.max_hits)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# BioPython is only imported at this point, so that a missing installation is
# reported through error(). Only ImportError is caught: a bare 'except' would
# also intercept SystemExit and KeyboardInterrupt, and would report any other
# failure as a missing library.
try:
	from Bio.Blast import NCBIXML
except ImportError:
	error("The BioPython library is not installed.\nTry 'easy_install biopython'")

if (p.verbose):
	mdb.max_verbosity()

# connection options given on the command line are passed as keyword arguments
try:
	mdb.connect(**connection_parameters)
except Exception as msg:
	error(msg)

# both collections must already exist; 'hits' is only defined when a hits
# collection was named on the command line
try:
	queries = mdb.Collection.find_one({"name": p.queries_collection})
	if (queries is None):
		error("Unknown queries collection '%s'" % p.queries_collection)

	if (p.hits_collection):
		hits = mdb.Collection.find_one({"name": p.hits_collection})
		if (hits is None):
			error("Unknown hits collection '%s'" % p.hits_collection)

except mdb.errors.DBConnectionError as msg:
	error(msg)

print "Importing '%s' ..." % p.input_fn

print "  validating query and hit sequences ..."

n = 0
for record in NCBIXML.parse(open(p.input_fn, 'r')):
	query_id = get_query_id(record.query)

	candidates = list(queries.list_sequences({"name": query_id}))

	if (len(candidates) == 0):
		error("Unknown query sequence '%s'" % query_id)

	if (len(candidates) > 1):
		error("Duplicate query sequence '%s'" % query_id)

	if (p.hits_collection):
		for hit in record.alignments:
			hit_id = get_hit_id(hit.title)
			candidates = list(hits.list_sequences({"name": hit_id}))

			if (len(candidates) == 0):
				error("Unknown hit sequence '%s'" % hit_id)

			if (len(candidates) > 1):
				error("Duplicate hit sequence '%s'" % hit_id)

	n += 1

if (n == 0):
	error("No BLAST hit in the input")

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

print "  importing HSPs ..."

class ProgressBar:
	""" Text progress bar drawn on a single line of the standard output.

	The bar covers the range from 0 to 'upper'. Each call to display() rewrites
	the line (it ends with a carriage return, not a newline) with up to 80
	dots followed by the percentage reached.
	"""

	def __init__ (self, upper = None):
		""" 'upper' is the value standing for 100%. When it is omitted, None
		or 0 the range is empty and the bar is always drawn as complete.
		"""
		self.__min = 0.0
		self.__max = (upper or 0) + 0.0

	def display (self, value):
		""" Draw the bar for 'value', which is expected between 0 and 'upper'.
		"""
		span = self.__max - self.__min

		if (span > 0):
			f = (value - self.__min) / span # fraction
			# values outside of the range are clamped, so that the line never
			# gets wider than what clear() erases
			f = min(max(f, 0.0), 1.0)
		else:
			# empty range: nothing to wait for (and nothing to divide by)
			f = 1.0

		p = 100 * f # percentage
		s = int(round(80 * f)) # bar size

		sys.stdout.write(' ' * 4 + ('.' * s) + " %4.2f%%\r" % p)
		sys.stdout.flush()

	def clear (self):
		""" Overwrite the bar with spaces: 4 of margin, 80 of bar and 8 for
		the percentage.
		"""
		sys.stdout.write(' ' * (4 + 80 + 8) + "\r")
		sys.stdout.flush()

# at this point 'n' still holds its previous value, which becomes the upper
# bound of the progress bar; it is then reset to zero to be used as a counter
pb = ProgressBar(n)
n = 0

def get_sequence (collection, name):
	""" Return the first sequence of 'collection' named 'name'.

	All the sequences returned by collection.list_sequences() for this name
	are collected in a list, of which the first item is returned; IndexError
	is raised if there is none.
	"""
	matches = list(collection.list_sequences({"name": name}))
	return matches[0]

# Hits are external to the database unless a hits collection was named. The
# test is on truthiness, like the one guarding the loading of this collection:
# comparing to None would treat an empty name (-H "") as an internal
# collection although none was loaded.
external_hits = (not p.hits_collection)

def try_commit (query_o, query_id):
	""" Commit 'query_o', turning database errors into error messages.

	A DBOperationError whose message contains 'too large' is reported as too
	many hits for 'query_id': as a warning on the standard error stream when
	--ignore-large-entries is set, as a fatal error otherwise. Any other
	DBOperationError is fatal. Fatal errors go through error().
	"""
	try:
		query_o.commit()

	except mdb.errors.DBOperationError as msg:
		if ("too large" not in str(msg)):
			error(msg)

		elif (not p.ignore_large_entries):
			error("Too many hits for query '%s'" % query_id)

		else:
			sys.stderr.write("WARNING: Too many hits for query '%s'; this information will be ignored.\n" % query_id)

for record in NCBIXML.parse(open(p.input_fn, 'r')):
	query_id = get_query_id(record.query)
	query_o = get_sequence(queries, query_id)

	if (external_hits):
		hits = query_o.get_property("alignments", [])

	m = 0
	for hit in record.alignments:
		hit_id = get_hit_id(hit.title)
		if (not external_hits):
			hit_o = get_sequence(hits, hit_id)

		m += 1
		if (p.max_hits) and (m > p.max_hits):
			break

		for hsp in hit.hsps:
			identity = 100.0 * hsp.identities / hsp.align_length

			if (p.min_identity) and (identity < p.min_identity):
				continue

			if (p.max_e_value) and (hsp.expect > p.max_e_value):
				continue

			# documentation:
			# - ftp://ftp.ncbi.nlm.nih.gov/blast/documents/xml/README.blxml for information about the NCBI BLAST XML format
			# - http://www.biopython.org/DIST/docs/api/Bio.Blast.NCBIXML-pysrc.html for information about how the XML is parsed by BioPython
			# - http://www.biopython.org/DIST/docs/api/Bio.Blast.Record-pysrc.html for information about how the result is stored as a Record
			r = {
				"type": "similar-to",
				"run": {
					"date": {"year": p.date[0], "month": p.date[1], "day": p.date[2]},
					"algorithm": {
						"name": record.application,
						"version": record.version,
						"parameters": {
							"expect": float(record.expect),
							"matrix": record.matrix,
							"gap_open": record.gap_penalties[0],
							"gap_extend": record.gap_penalties[1],
							"sc_match": record.sc_match,
							"sc_mismatch": record.sc_mismatch,
							"filter": record.filter
						},
					},
					"database": {
						"name": record.database,
						"number_of_sequences": record.database_sequences,
						"number_of_letters": record.num_letters_in_database,
					}
				},
				"score": {
					"percent_identity": identity,
					"percent_positives": 100.0 * hsp.positives / hsp.align_length,
					"e_value": hsp.expect,
					"gaps": hsp.gaps,
				},
				"alignment": {
					"source_coordinates": (hsp.query_start, hsp.query_end),
					"target_coordinates": (hsp.sbjct_start, hsp.sbjct_end),
				},
			}

			if (p.include_alignment):
				r["alignment"]["source"] = hsp.query
				r["alignment"]["match"] = hsp.match
				r["alignment"]["target"] = hsp.sbjct

			# the hit should be in the database. In this case, we store the HSP
			# as properties of a relationship between query and hit sequences.
			if (not external_hits):
				if (p.dry_run):
					print "    query '%s' to hit '%s'" % (query_id, hit_id)
					for line in pprint.pformat(r).split('\n'):
						print "      %s" % line
				else:
					query_o.relate_to_sequence(hit_o, r)
					try_commit(query_o, query_id)

			# the hit is not in the database. In this case, we store the HSP as
			# a property of the query sequence.
			else:
				r["hit" ] = {
					"name": hit_id,
					"description": hit.hit_def,
					"length": hit.length
				}

				if (p.dry_run):
					print "    query '%s' to external hit '%s'" % (query_id, hit_id)
					for line in pprint.pformat(r).split('\n'):
						print "      %s" % line
				else:
					hits.append(r)

	if (external_hits) and (not p.dry_run) and (len(hits) > 0):
		query_o["alignments"] = hits
		try_commit(query_o, query_id)

	if (not p.dry_run) and (p.display_progress_bar):
		pb.display(n)

	n += 1

if (not p.dry_run) and (p.display_progress_bar):
	pb.clear()

print "    %s quer%s imported." % (n, {True: 'ies', False: 'y'}[n > 1])

if (p.dry_run):
	print "(dry run)"
