#!/usr/bin/env python

import optparse, sys, os, re, time, pprint, itertools
import MetagenomeDB as mdb

# Command-line interface. Options are declared group by group; `g` is reused
# as the "current group" variable and `p` is the parser itself.
p = optparse.OptionParser(description = """Part of the MetagenomeDB toolkit.
Imports ACE-formatted mapping between reads and contigs into the database.""")

# --- input: ACE file, collections to look reads and contigs up in ------------
g = optparse.OptionGroup(p, "Input")

g.add_option("-i", "--input", dest = "input_fn", metavar = "FILENAME",
	help = "ACE file (mandatory).")

g.add_option("--reads-collection", dest = "reads_collection_names", metavar = "STRING", action = "append",
	help = """Name of the collection the reads belong to (mandatory). More than
one reads collection can be provided by using this option multiple times.""")

g.add_option("--contigs-collection", dest = "contigs_collection_name", metavar = "STRING",
	help = "Name of the collection the contigs belong to (mandatory).")

g.add_option("--reads-mapping", dest = "reads_mapping_fn", metavar = "STRING",
	help = """Tab-delimited file with, for each read mentioned in the ACE file,
the name of this read in a reads collection and the name of this collection.
Cannot be used in combination with --reads-collection.""")

# the fallback date is computed from os.path.getmtime(), i.e., the last
# modification time of the file
g.add_option("--date", dest = "date", nargs = 3, type = "int", metavar = "YEAR MONTH DAY",
	help = "Date of the assembly (optional). By default, last modification date of the ACE file.")

# The getters are Python expressions in which every '%' character is replaced
# by the accessor of the object name; the help text documents that placeholder
# (optparse only expands the literal '%default' token in help strings).
g.add_option("--read-id-getter", dest = "read_id_getter", metavar = "PYTHON CODE", default = "%",
	help = """Python code to reformat read identifiers (optional); '%' will be
replaced by the name of the read, as found in the ACE file. Default: %default.""")

g.add_option("--contig-id-getter", dest = "contig_id_getter", metavar = "PYTHON CODE", default = "%",
	help = """Python code to reformat contigs identifiers (optional); '%' will
be replaced by the name of the contig, as found in the ACE file. Default: %default.""")

p.add_option_group(g)

# --- filtering: what part of the alignment gets stored -----------------------
g = optparse.OptionGroup(p, "Filtering")

g.add_option("--ignore-alignment", dest = "include_alignment", action = "store_false", default = True,
	help = "If set, will not store HSP sequences and conservation lines.")

g.add_option("--ignore-consensus", dest = "include_consensus", action = "store_false", default = True,
	help = "If set, will not store the contig consensus sequence.")

p.add_option_group(g)

# --- errors handling: which problems are downgraded to warnings --------------
g = optparse.OptionGroup(p, "Errors handling")

g.add_option("--ignore-missing-reads", dest = "ignore_missing_reads", action = "store_true", default = False,
	help = "If set, ignore reads that are not found in the reads collection.")

g.add_option("--ignore-missing-contigs", dest = "ignore_missing_contigs", action = "store_true", default = False,
	help = "If set, ignore contigs that are not found in the contigs collection.")

g.add_option("--ignore-duplicates", dest = "ignore_duplicates", action = "store_true", default = False,
	help = "If set, ignore duplicate objects errors.")

g.add_option("--ignore-large-entries", dest = "ignore_large_entries", action = "store_true", default = False,
	help = """If set, ignore cases where a large amount of contigs being associated
to a given read would result in this read to be too large for the database.""")

p.add_option_group(g)

# --- general switches, attached to the parser rather than to a group ---------
p.add_option("-v", "--verbose", dest = "verbose", action = "store_true", default = False)
p.add_option("--no-progress-bar", dest = "display_progress_bar", action = "store_false", default = True)
p.add_option("--dry-run", dest = "dry_run", action = "store_true", default = False)
p.add_option("--version", dest = "display_version", action = "store_true", default = False)

# --- connection: group created here, populated further down ------------------
g = optparse.OptionGroup(p, "Connection")

# Connection settings explicitly given on the command line, keyed by the
# option name stripped of its two leading characters; later expanded as
# keyword arguments of mdb.connect().
connection_parameters = {}

def declare_connection_parameter (option, opt, value, parser):
	"""optparse callback recording one connection setting.

	option -- the optparse Option instance that triggered the callback (unused)
	opt    -- the option string as seen by optparse; its first two characters
	          are dropped to build the key
	value  -- the argument supplied for this option
	parser -- the OptionParser instance (unused)

	The pair is stored in the module-level `connection_parameters` dictionary;
	a repeated option overwrites the previous value.
	"""
	key = opt[2:]
	connection_parameters.update({key: value})

# Connection options, declared from a table: (name, metavar, help text).
# Each one becomes a '--<name>' string option whose value is handed over to
# declare_connection_parameter() instead of being stored on the options object.
_connection_options = (
	("host", "HOSTNAME",
	"""Host name or IP address of the MongoDB server (optional). Default:
'host' property in ~/.MetagenomeDB, or 'localhost' if not found."""),

	("port", "INTEGER",
	"""Port of the MongoDB server (optional). Default: 'port' property
in ~/.MetagenomeDB, or 27017 if not found."""),

	("db", "STRING",
	"""Name of the database in the MongoDB server (optional). Default:
'db' property in ~/.MetagenomeDB, or 'MetagenomeDB' if not found."""),

	("user", "STRING",
	"""User for the MongoDB server connection (optional). Default:
'user' property in ~/.MetagenomeDB, or none if not found."""),

	("password", "STRING",
	"""Password for the MongoDB server connection (optional). Default:
'password' property in ~/.MetagenomeDB, or none if not found."""),
)

for _name, _metavar, _help in _connection_options:
	g.add_option("--" + _name, dest = "connection_" + _name, metavar = _metavar,
		type = "string", action = "callback", callback = declare_connection_parameter,
		help = _help)

p.add_option_group(g)

# from here on `p` is the parsed options object (no longer the parser) and
# `a` holds the remaining positional arguments
p, a = p.parse_args()

def error (msg):
	"""Report a fatal error on standard error and terminate the script.

	msg -- the message, or any object convertible with str() (e.g., an
	       exception). A single trailing period is dropped, then the text is
	       printed as "ERROR: <msg>." followed by a newline.

	Never returns: exits through sys.exit(1), i.e., raises SystemExit.
	"""
	msg = str(msg)
	if msg.endswith('.'):
		msg = msg[:-1]

	# a plain stream write produces the same output as the Python 2-only
	# 'print >>' statement while also being valid on Python 3
	sys.stderr.write("ERROR: %s.\n" % msg)
	sys.exit(1)

# --version short-circuits every other check
if (p.display_version):
	print(mdb.version)
	sys.exit(0)

if (p.input_fn is None):
	error("An ACE file must be provided")

if (not os.path.exists(p.input_fn)):
	error("File '%s' not found" % p.input_fn)

if (not p.date):
	# no --date provided: fall back on the last modification time of the ACE file
	date = time.localtime(os.path.getmtime(p.input_fn))
	p.date = (date.tm_year, date.tm_mon, date.tm_mday)

else:
	# Explicit range checks rather than assert statements: assertions are
	# stripped when Python runs with -O, which would silently disable the
	# validation. Checks run in year, month, day order and error() exits on
	# the first failure.
	y, m, d = p.date

	if (y <= 1990):
		error("Invalid date: value '%s' is incorrect for year" % y)

	if (m < 1) or (m > 12):
		error("Invalid date: value '%s' is incorrect for month" % m)

	if (d < 1) or (d > 31):
		error("Invalid date: value '%s' is incorrect for day" % d)

# reads are located either through --reads-collection or through
# --reads-mapping; exactly one of the two must be provided
if (p.reads_collection_names is not None) and (p.reads_mapping_fn is not None):
	error("--reads-collection and --reads-mapping cannot be used simultaneously")

if (p.reads_collection_names is None) and (p.reads_mapping_fn is None):
	error("A collection must be provided for both reads and contigs")

if (p.reads_mapping_fn is not None) and (not os.path.exists(p.reads_mapping_fn)):
	error("File '%s' not found" % p.reads_mapping_fn)

if (p.contigs_collection_name is None):
	error("A collection must be provided for both reads and contigs")

# SECURITY NOTE: both getters are command-line text evaluated as Python code
# through eval(); only run this script with getter expressions you trust.
# Every '%' is replaced by the accessor of the name ("x.rd.name" for reads,
# "x.name" for contigs), then every '\x' sequence is turned back into '%'.
# NOTE(review): this looks like an escape for a literal '%', but an input of
# '\%' ends up as '%.rd.name' rather than '%' -- confirm the intended syntax.
try:
	get_read_id = eval("lambda x: " + p.read_id_getter.replace('%', "x.rd.name").replace("\\x", '%'))
	get_contig_id = eval("lambda x: " + p.contig_id_getter.replace('%', "x.name").replace("\\x", '%'))

except SyntaxError as e:
	# the padding accounts for the "ERROR: Invalid getter: " prefix so that
	# the caret is printed below the column reported by the parser
	error("Invalid getter: %s\n%s^" % (e.text, ' ' * (e.offset + 22)))

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# Biopython is imported late so that argument errors are reported even when
# the library is missing. Only ImportError is turned into the installation
# hint; any other failure raised while importing is left to propagate.
try:
	from Bio.Sequencing import Ace
except ImportError:
	error("The Biopython library must be installed\nTry 'easy_install Biopython'")

if (p.verbose):
	mdb.max_verbosity()

# only the connection settings given on the command line are passed along
try:
	mdb.connect(**connection_parameters)
except Exception as msg:
	error(msg)

print("Importing '%s' ..." % p.input_fn)

# After this section:
#   - with --reads-collection: `reads_collections` is the list of collection
#     objects to search, and `mapping` is None
#   - with --reads-mapping: `mapping` associates each read name found in the
#     ACE file with a (name in the collection, collection object) pair, and
#     `reads_collections` is no longer defined
#   - `contigs_collection` is the collection object holding the contigs
try:
	# test reads collections provided as arguments
	if (p.reads_collection_names):
		reads_collections = []
		for reads_collection_name in p.reads_collection_names:
			reads_collection = mdb.Collection.find_one({"name": reads_collection_name})
			if (reads_collection is None):
				error("Unknown reads collection '%s'" % reads_collection_name)

			reads_collections.append(reads_collection)

		mapping = None

	# test reads collections mentioned in the mapping file
	else:
		print("  loading '%s' ..." % p.reads_mapping_fn)

		# here `reads_collections` is a cache of collection objects, by name
		reads_collections, mapping = {}, {}

		# the context manager closes the mapping file even when error() exits
		with open(p.reads_mapping_fn, 'rU') as fh:
			for line in fh:
				line = line.strip()
				if (line == ''):
					continue

				# three whitespace-separated columns: read name in the ACE file,
				# read name in the collection, name of the collection
				items = line.split()
				if (len(items) != 3):
					error("Invalid mapping file. Line was \"%s\"" % line)

				current_read_id, original_read_id, reads_collection_name = items

				if (current_read_id in mapping):
					error("Duplicate mapping for read '%s'" % current_read_id)

				if (reads_collection_name in reads_collections):
					reads_collection = reads_collections[reads_collection_name]
				else:
					reads_collection = mdb.Collection.find_one({"name": reads_collection_name})
					if (reads_collection is None):
						error("Unknown reads collection '%s'" % reads_collection_name)

					reads_collections[reads_collection_name] = reads_collection

				mapping[current_read_id] = (original_read_id, reads_collection)

		# the by-name cache is only needed while loading the mapping
		del reads_collections

	# test the contigs collection
	contigs_collection = mdb.Collection.find_one({"name": p.contigs_collection_name})
	if (contigs_collection is None):
		error("Unknown contigs collection '%s'" % p.contigs_collection_name)

except mdb.errors.DBConnectionError as msg:
	error(msg)

print("  validating read and contig sequences ...")

# First pass over the ACE file: nothing is written; every contig and read is
# looked up in the database so that problems are reported before any import.
# `n` counts the read/contig pairs for which both sequences were found.
n = 0

# the context manager closes the ACE file once parsing is over (or on exit)
with open(p.input_fn, 'r') as ace_fh:
	for contig in Ace.parse(ace_fh):
		contig_id = get_contig_id(contig)

		candidates = list(contigs_collection.list_sequences({"name": contig_id}))

		if (len(candidates) == 0):
			msg = "Unknown contig '%s'" % contig_id
			if (p.ignore_missing_contigs):
				sys.stderr.write("WARNING: " + msg + "\n")
				continue
			else:
				error(msg)

		if (len(candidates) > 1):
			error("Ambiguous contig '%s'" % contig_id)

		for read in contig.reads:
			read_id = get_read_id(read)

			# no mapping file: search all the reads collections listed on the
			# command line
			if (mapping is None):
				candidates = list(itertools.chain(*[reads_collection.list_sequences({"name": read_id}) for reads_collection in reads_collections]))

			# mapping file without an entry for this read: treated as missing
			elif (read_id not in mapping):
				candidates = []

			# mapping file: search the designated collection, using the name
			# the read has in that collection
			else:
				read_id, reads_collection = mapping[read_id]
				candidates = list(reads_collection.list_sequences({"name": read_id}))

			if (len(candidates) == 0):
				msg = "Unknown read '%s' (mapped to contig '%s')" % (read_id, contig_id)
				if (p.ignore_missing_reads):
					sys.stderr.write("WARNING: " + msg + "\n")
					continue
				else:
					error(msg)

			if (len(candidates) > 1):
				error("Ambiguous read '%s'" % read_id)

			n += 1

if (n == 0):
	error("No mapping in the input")

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

print("  importing mapping ...")

class ProgressBar:
	"""Minimal text progress bar written to standard output.

	The bar is a run of up to 80 dots followed by a percentage, redrawn in
	place by ending each output with a carriage return.
	"""

	def __init__ (self, upper = None):
		"""Create a bar tracking values from 0 to `upper`.

		upper -- value corresponding to 100%. When omitted (None) the range is
		         empty, like an upper bound of 0, instead of failing on the
		         None + 0.0 addition.
		"""
		self.__min = 0.0
		self.__max = (upper or 0) + 0.0

	def display (self, value):
		"""Redraw the bar for `value`, a number between 0 and the upper bound."""
		span = self.__max - self.__min

		# an empty range has nothing left to process: draw a complete bar
		# rather than dividing by zero
		f = (value - self.__min) / span if (span != 0) else 1.0 # fraction
		p = 100 * f # percentage
		s = int(round(80 * f)) # bar size

		sys.stdout.write(' ' * 2 + ('.' * s) + " %4.2f%%\r" % p)
		sys.stdout.flush()

	def clear (self):
		"""Blank out the line used by the bar (margin, 80 dots and percentage)."""
		sys.stdout.write(' ' * (2 + 80 + 8) + "\r")
		sys.stdout.flush()

# The current value of `n` becomes the upper bound of the progress bar; `n`
# is then reset to zero (the right-hand side is evaluated before either name
# is rebound).
pb, n = ProgressBar(n), 0

def get_sequence (collection, name, is_list = False):
	"""Return the first sequence called `name`, or None if there is none.

	collection -- an object with a list_sequences() method accepting a filter
	              dictionary, or an iterable of such objects when `is_list`
	              is True
	name       -- value the "name" property of the sequence must have
	is_list    -- whether `collection` is a group of collections, which are
	              then searched in order

	When several sequences match, the first one returned is used.
	"""
	sources = collection if is_list else [collection]

	matches = []
	for source in sources:
		matches.extend(source.list_sequences({"name": name}))

	return matches[0] if matches else None

# documentation for ACE file format: http://bcr.musc.edu/manuals/CONSED.txt
# see also http://www.cbcb.umd.edu/research/contig_representation.shtml#ACE

# Second pass over the ACE file: for each read of each contig found in the
# database, build a "similar-to" relationship from the read to the contig and
# commit it (or only print it with --dry-run). Contigs and reads missing from
# the database are skipped silently at this stage.
with open(p.input_fn, 'r') as ace_fh:
	for contig in Ace.parse(ace_fh):
		contig_id = get_contig_id(contig)
		contig_o = get_sequence(contigs_collection, contig_id)

		if (contig_o is None):
			continue

		contig_complemented = (contig.uorc == "C")
		contig_sequence = contig.sequence.upper()

		for read_idx, read in enumerate(contig.reads):
			read_id = get_read_id(read)

			if (mapping is not None):
				# With a mapping file there is no list of reads collections to
				# fall back on (`reads_collections` is deleted once the file
				# is loaded), so a read without an entry is skipped instead of
				# failing with a NameError. The validation pass has reported
				# such reads already; they only get this far when
				# --ignore-missing-reads is set.
				if (read_id not in mapping):
					continue

				read_id, reads_collection = mapping[read_id]
				read_o = get_sequence(reads_collection, read_id)
			else:
				read_o = get_sequence(reads_collections, read_id, is_list = True)

			if (read_o is None):
				continue

			# the AF entry at the same index as the read carries its
			# orientation and its padded start position
			read_complemented = (contig.af[read_idx].coru == "C")
			read_sequence = read.rd.sequence.upper()

			read_start = read.qa.align_clipping_start
			read_stop = read.qa.align_clipping_end

			# coordinates are stored in reverse order for complemented sequences
			if (read_complemented):
				read_start_, read_stop_ = read_stop, read_start
			else:
				read_start_, read_stop_ = read_start, read_stop

			offset = contig.af[read_idx].padded_start
			if (offset < 0):
				contig_start = 1
			else:
				contig_start = offset + read_start - 1

			contig_stop = contig_start + (read_stop - read_start)

			if (contig_complemented):
				contig_start_, contig_stop_ = contig_stop, contig_start
			else:
				contig_start_, contig_stop_ = contig_start, contig_stop

			r = {
				"type": "similar-to",
				"run": {
					"date": {"year": p.date[0], "month": p.date[1], "day": p.date[2]}
				},
				"alignment": {
					"source_coordinates": (read_start_, read_stop_),
					"target_coordinates": (contig_start_, contig_stop_)
				}
			}

			if (p.include_consensus):
				r["alignment"]["target_consensus"] = contig_sequence

			if (p.include_alignment):
				# NOTE(review): 'source' receives the contig slice and 'target'
				# the read slice, whereas 'source_coordinates' holds the read
				# coordinates and 'target_coordinates' the contig ones --
				# confirm which orientation the database expects.
				source = contig_sequence[contig_start-1:contig_stop].replace('*', '-')
				target = read_sequence[read_start-1:read_stop].replace('*', '-')

				# conservation line: ':' where both sequences carry the same
				# non-gap character, ' ' otherwise (built in one pass rather
				# than by repeated string concatenation)
				match = ''.join(
					':' if (c != '-') and (c == target[i]) else ' '
					for i, c in enumerate(source)
				)

				r["alignment"]["source"] = source
				r["alignment"]["match"] = match
				r["alignment"]["target"] = target

			if (p.dry_run):
				print("    read '%s' to contig '%s'" % (read_id, contig_id))
				for line in pprint.pformat(r).split('\n'):
					print("      %s" % line)
			else:
				try:
					read_o.relate_to_sequence(contig_o, r)
					read_o.commit()

				except mdb.errors.DuplicateObjectError as msg:
					if (p.ignore_duplicates):
						sys.stderr.write("WARNING: %s\n" % str(msg))
					else:
						error(msg)

				except mdb.errors.DBOperationError as msg:
					# only the "too large" failure can be downgraded to a warning
					if ("too large" in str(msg)):
						if (p.ignore_large_entries):
							sys.stderr.write("WARNING: Too many contigs for read '%s'; this information will be ignored.\n" % read_id)
							continue
						else:
							error("Too many contigs for read '%s'" % read_id)
					else:
						error(msg)

			if (p.display_progress_bar):
				pb.display(n)

			n += 1

if (p.display_progress_bar):
	pb.clear()

print("    done.")

if (p.dry_run):
	print("(dry run)")
