#!/usr/bin/env python

import optparse, sys, os, re, time, pprint
import MetagenomeDB as mdb

# Command-line interface. The grouped options are declared as data (one table
# per option group) and registered in declaration order.
p = optparse.OptionParser(description = """Part of the MetagenomeDB toolkit.
Imports FASTA alignments into the database.""")

_GROUPED_OPTIONS = (
	# where the alignments come from and which collections they refer to
	("Input", (
		(("-i", "--input"), {
			"dest": "input_fn",
			"metavar": "FILENAME",
			"help": "Results of a FASTA sequence alignment (mandatory). Note: the file MUST be formatted with the '-m 10' option.",
		}),
		(("-Q", "--query-collection"), {
			"dest": "queries_collection",
			"metavar": "STRING",
			"help": "Name of the collection the query sequences belong to (mandatory).",
		}),
		(("-H", "--hit-collection"), {
			"dest": "hits_collection",
			"metavar": "STRING",
			"help": "Name of the collection the hit sequences belong to (mandatory).",
		}),
		(("--date",), {
			"dest": "date",
			"nargs": 3,
			"type": "int",
			"metavar": "YEAR MONTH DAY",
			"help": "Date of the FASTA run (optional). By default, creation date of the input file.",
		}),
	)),

	# which hits are kept, and how much of each hit is stored
	("Input filtering", (
		(("--max-E-value",), {
			"dest": "max_e_value",
			"type": "float",
			"metavar": "FLOAT",
			"help": "If set, filter out all hits with a E-value above the provided cut-off.",
		}),
		(("--min-identity",), {
			"dest": "min_identity",
			"type": "int",
			"metavar": "INTEGER",
			"help": "If set, filter out all hits with a percentage of identity below the provided cut-off.",
		}),
		(("--max-hits",), {
			"dest": "max_hits",
			"type": "int",
			"metavar": "INTEGER",
			"help": "If set, keep only the first '--max-hits' hits for each query.",
		}),
		(("--ignore-alignment",), {
			"dest": "include_alignment",
			"action": "store_false",
			"default": True,
			"help": "If set, will not store information about the sequence alignment (HSP coordinates and sequences).",
		}),
	)),
)

for _title, _declarations in _GROUPED_OPTIONS:
	g = optparse.OptionGroup(p, _title)
	for _flags, _settings in _declarations:
		g.add_option(*_flags, **_settings)

	p.add_option_group(g)

# Stand-alone switches attached to the parser itself; all are off by default
for _flags, _dest in (
	(("-v", "--verbose"), "verbose"),
	(("--dry-run",), "dry_run"),
	(("--version",), "display_version"),
):
	p.add_option(*_flags, dest = _dest, action = "store_true", default = False)

# Group that will receive the database connection options
g = optparse.OptionGroup(p, "Connection")

# Connection settings given on the command line, keyed by option name
# without its leading dashes (e.g. "host" for --host)
connection_parameters = {}
def declare_connection_parameter (option, opt, value, parser):
	"""optparse callback: record 'value' under the name of the option.

	'opt' is the option string that triggered the callback; its first two
	characters (the leading "--") are dropped to form the dictionary key.
	'option' and 'parser' are part of the callback signature but unused.
	"""
	parameter_name = opt[len("--"):]
	connection_parameters.update({parameter_name: value})

# The five connection options share the same shape: a string value handed to
# declare_connection_parameter. Only the name, the placeholder shown in the
# help and the help text differ, so they are declared as a table.
# NOTE(review): --port is collected as a string like the others; confirm that
# mdb.connect() accepts it in that form.
for _name, _metavar, _help in (
	("host", "HOSTNAME", """Host name or IP address of the MongoDB server (optional). Default:
'host' property in ~/.MetagenomeDB, or 'localhost' if not found."""),

	("port", "INTEGER", """Port of the MongoDB server (optional). Default: 'port' property
in ~/.MetagenomeDB, or 27017 if not found."""),

	("db", "STRING", """Name of the database in the MongoDB server (optional). Default:
'db' property in ~/.MetagenomeDB, or 'MetagenomeDB' if not found."""),

	("user", "STRING", """User for the MongoDB server connection (optional). Default:
'user' property in ~/.MetagenomeDB, or none if not found."""),

	("password", "STRING", """Password for the MongoDB server connection (optional). Default:
'password' property in ~/.MetagenomeDB, or none if not found."""),
):
	g.add_option("--" + _name, dest = "connection_" + _name, metavar = _metavar,
		type = "string", action = "callback", callback = declare_connection_parameter,
		help = _help)

p.add_option_group(g)

# From this point on 'p' no longer refers to the parser but to the parsed
# option values; 'a' receives the remaining positional arguments
(p, a) = p.parse_args()

def error (msg):
	"""Report a fatal error on the standard error stream and exit.

	'msg' may be any object (e.g. an exception); it is converted with str().
	A single trailing period is removed so that the final message always
	ends with exactly one. The process then terminates with exit status 1
	(sys.exit raises SystemExit).
	"""
	msg = str(msg)
	if msg.endswith('.'):
		msg = msg[:-1]
	# sys.stderr.write() emits the same text as the former
	# 'print >>sys.stderr' statement without relying on Python 2-only syntax
	sys.stderr.write("ERROR: %s.\n" % msg)
	sys.exit(1)

# --version takes precedence over everything else: print the library
# version and stop before any other argument is checked
if (p.display_version):
	print(mdb.version)
	sys.exit(0)

# Mandatory arguments
if (p.input_fn is None):
	error("A FASTA alignment output file must be provided")

if (not os.path.exists(p.input_fn)):
	error("File '%s' not found" % p.input_fn)

if (p.queries_collection is None) or (p.hits_collection is None):
	error("A collection must be provided for both query and hit sequences")

# Date of the run: defaults to the last modification time of the input file
if (not p.date):
	date = time.localtime(os.path.getmtime(p.input_fn))
	p.date = (date.tm_year, date.tm_mon, date.tm_mday)

else:
	# Explicit tests rather than 'assert' statements, which are removed when
	# Python runs with -O and would then let an invalid date through. The
	# checks are range checks only; the day is not validated against the
	# month. error() exits, so only the first problem is reported.
	y, m, d = p.date

	if (y <= 1990):
		error("Invalid date: value '%s' is incorrect for year" % y)

	if (m < 1) or (m > 12):
		error("Invalid date: value '%s' is incorrect for month" % m)

	if (d < 1) or (d > 31):
		error("Invalid date: value '%s' is incorrect for day" % d)

# Optional cut-offs; each is None when the option was not given
if (p.max_e_value is not None) and (p.max_e_value < 0):
	error("Invalid E-value cut-off: %s" % p.max_e_value)

if (p.min_identity is not None) and ((p.min_identity < 0) or (p.min_identity > 100)):
	error("Invalid percentage of identity cut-off: %s" % p.min_identity)

if (p.max_hits is not None) and (p.max_hits < 0):
	error("Invalid number of hits cut-off: %s" % p.max_hits)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# -v/--verbose: call mdb.max_verbosity() (presumably raises the amount of
# information reported by the MetagenomeDB library)
if p.verbose:
	mdb.max_verbosity()

# Connect using the --host/--port/--db/--user/--password values collected on
# the command line, if any; a failure is reported through error(), which
# terminates the script
try:
	mdb.connect(**connection_parameters)
except Exception as failure:
	error(failure)

print("Importing '%s' ..." % p.input_fn)

print("  validating query and hit sequences ...")

def _index_collection (collection_name):
	"""Index the sequences of a collection by name.

	Returns a tuple (collection, sequences, duplicates) where 'sequences'
	maps each sequence name to a sequence object (the last one listed when
	a name occurs several times) and 'duplicates' has one key, mapped to
	True, for each name found more than once. Stops the script through
	error() if no collection has this name.
	"""
	collection = mdb.Collection.find_one({"name": collection_name})
	if (collection is None):
		error("Unknown collection '%s'" % collection_name)

	sequences, duplicates = {}, {}
	for sequence in collection.list_sequences():
		sequence_name = str(sequence["name"])
		if (sequence_name in sequences):
			duplicates[sequence_name] = True

		sequences[sequence_name] = sequence

	return collection, sequences, duplicates

# Queries are loaded before hits; any exception raised along the way is
# turned into an error message (error() itself exits through SystemExit,
# which 'except Exception' does not intercept)
try:
	queries, QuerySequences, DuplicateQueries = _index_collection(p.queries_collection)
	hits, HitSequences, DuplicateHits = _index_collection(p.hits_collection)

except Exception as msg:
	error(msg)

# Patterns for the parts of a FASTA '-m 10' report the parser relies on
# (raw strings; the expressions themselves are unchanged)
DB_DIMENSIONS = re.compile(r"\s*([0-9]+) residues in\s*([0-9]+) sequences\n")
QUERY_HEADER = re.compile(r">>>(.*?), [0-9]+ nt vs (.*?) library\n")
HIT_HEADER = re.compile(r">>([^ .]*).*\n")
KEY_VALUE = re.compile(r"; ([a-z]{2}_[a-zA-Z\-_0-9]+):(.*)\n")

previous = None
have_statistics = False

# query identifier -> (run properties, {(hit rank, hit identifier): HSP properties})
HSP = {}

# The 'with' statement guarantees the report is closed once parsed, including
# when parsing stops on an error. Lines are fetched with readline() on
# purpose: the inner loop reads from the same handle, which must not be
# combined with iterating over the file object.
with open(p.input_fn, 'r') as fh:
	while True:
		line = fh.readline()
		if (line == ''):
			break

		# The line right before the first "Statistics:" line is expected to
		# hold the dimensions of the database that was searched
		if (not have_statistics) and line.startswith("Statistics:"):
			m = DB_DIMENSIONS.match(previous)
			assert (m is not None), previous
			n_residues, n_sequences = int(m.group(1)), int(m.group(2))

			have_statistics = True

		if (line == ">>><<<\n"):
			continue

		# new query
		if line.startswith(">>>"):
			m = QUERY_HEADER.match(line)
			assert (m is not None), line
			query_id, database = m.group(1), os.path.basename(m.group(2))

			if (not query_id in QuerySequences):
				error("Unknown query sequence '%s'" % query_id)

			if (query_id in DuplicateQueries):
				error("Duplicate query sequence '%s'" % query_id)

			run = {"database": database}
			hsp = {}
			# block_n counts the '>>' and '>' headers seen so far for this
			# query: 0 means run-level properties, then blocks come in
			# groups of three (hit summary, query side, hit side)
			block_n = 0
			hit_n = 0
			line_n = 0

			# read everything that belongs to this query, up to a blank
			# line, the '>>><<<' marker or the end of the file
			while True:
				line = fh.readline()
				if (line == '') or (line == "\n") or (line == ">>><<<\n"):
					break

				elif (line.startswith(">>")):
					m = HIT_HEADER.match(line)
					assert (m is not None), line

					hit_id = m.group(1)
					hit_n += 1
					# the rank keeps several hits on the same sequence apart
					hit_key = (hit_n, hit_id)

					if (not hit_id in HitSequences):
						error("Unknown hit sequence '%s'" % hit_id)

					if (hit_id in DuplicateHits):
						error("Duplicate hit sequence '%s'" % hit_id)

					block_n += 1

				elif (line.startswith(">")):
					# header of the query side, then of the hit side, of the
					# current HSP (the line content itself is not checked)
					block_n += 1

				else:
					if (line[0] == ';'):
						m = KEY_VALUE.match(line)
						assert (m is not None), line

						key, value = m.group(1), m.group(2).strip()
						last_key = key
					else:
						# a line without the '; key: value' form is stored
						# under the last key seen, together with its line
						# number so that the original order can be rebuilt
						key, value = (last_key, line_n), line.rstrip('\n')

					# run
					if (block_n == 0):
						run[key] = value

					# hsp
					elif ((block_n - 1) % 3 == 0):
						if (not hit_key in hsp):
							hsp[hit_key] = {"query": {}, "hit": {}}

						hsp[hit_key][key] = value

					# query in hsp
					elif ((block_n - 2) % 3 == 0):
						hsp[hit_key]["query"][key] = value

					# hit in hsp
					elif ((block_n - 3) % 3 == 0):
						hsp[hit_key]["hit"][key] = value

				line_n += 1

			HSP[query_id] = (run, hsp)

		previous = line

print("  importing HSPs ...")

# Run properties that are stored under their own entry (or deliberately left
# out) and therefore excluded from the algorithm "parameters"
_RESERVED_RUN_KEYS = ("pg_name", "pg_ver", "pg_name_alg", "pg_ver_rel", "database", "mp_Algorithm")

def _join_rows (section, name):
	"""Concatenate the multi-line value stored in 'section' under 'name'.

	Multi-line values are stored one line per (name, line number) key; the
	lines are joined in increasing line number order. Plain string keys
	(including 'name' itself) are ignored.
	"""
	rows = [key for key in section if isinstance(key, tuple) and (key[0] == name)]
	rows.sort(key = lambda key: key[1])
	return ''.join([section[key] for key in rows])

for query_id in HSP:
	run, hits = HSP[query_id]
	query_o = QuerySequences[query_id]

	parameters = dict((key, run[key]) for key in run if key not in _RESERVED_RUN_KEYS)

	m = 0
	# Hit keys are (rank, identifier) pairs. Sorting them restores the order
	# of the report, which dictionary iteration does not guarantee; without
	# it --max-hits would keep an arbitrary subset instead of the first hits.
	for hit_key in sorted(hits):
		hit_id = hit_key[1]
		hit_o = HitSequences[hit_id]

		hsp = hits[hit_key]

		identity = 100 * float(hsp["bs_ident"])
		e_value = float(hsp["fa_expect"])

		# hits are counted before the other filters are applied, so
		# --max-hits refers to positions in the report (a value of 0 is
		# treated as "no limit")
		m += 1
		if (p.max_hits) and (m > p.max_hits):
			break

		# 'is not None' so that a cut-off of 0 is honored rather than
		# being mistaken for "option not set"
		if (p.min_identity is not None) and (identity < p.min_identity):
			continue

		if (p.max_e_value is not None) and (e_value > p.max_e_value):
			continue

		# description of the relationship between the query and the hit
		r = {
			"type": "similar-to",

			"run": {
				"date": {"year": p.date[0], "month": p.date[1], "day": p.date[2]},
				"algorithm": {
					"name": run["pg_name"],
					"version": run["pg_ver"],
					"parameters": parameters,
				},
				"database": {
					"name": run["database"],
					"number_of_sequences": n_sequences,
					"number_of_letters": n_residues,
				}
			},

			"score": {
				"percent_identity": identity,
				"percent_similarity": 100 * float(hsp["bs_sim"]),
				"e_value": e_value,
			}
		}

		if (p.include_alignment):
			# the aligned sequences follow the 'al_display_start' property
			# of each side; the consensus line follows 'al_cons' on the
			# hit side
			query_alignment = _join_rows(hsp["query"], "al_display_start")
			hit_alignment = _join_rows(hsp["hit"], "al_display_start")
			cons_alignment = _join_rows(hsp["hit"], "al_cons")

			source_coordinates = (int(hsp["query"]["al_start"]), int(hsp["query"]["al_stop"]))
			target_coordinates = (int(hsp["hit"]["al_start"]), int(hsp["hit"]["al_stop"]))

			source_offset = int(hsp["query"]["al_display_start"])
			target_offset = int(hsp["hit"]["al_display_start"])

			# the displayed sequences may start before the alignment
			# itself; skip the larger of the two leading overhangs
			offset1 = abs(source_offset - source_coordinates[0])
			offset2 = abs(target_offset - target_coordinates[0])

			start = max(offset1, offset2)
			stop = len(cons_alignment.strip()) + start
			# sanity check: the consensus line must not be blank where the
			# alignment is expected to start
			assert (cons_alignment[start] != ' ')

			r["alignment"] = {
				"source_coordinates": source_coordinates,
				"source": query_alignment[start:stop],
				"match": cons_alignment[start:stop],
				"target": hit_alignment[start:stop],
				"target_coordinates": target_coordinates,
			}

		if (p.dry_run):
			# show what would be stored, without touching the database
			print("    query '%s' to hit '%s'" % (query_id, hit_id))
			for line in pprint.pformat(r).split('\n'):
				print("      %s" % line)
		else:
			query_o.relate_to_sequence(hit_o, r)
			query_o.commit()

print("    done.")

if (p.dry_run):
	print("(dry run)")
