#!/usr/bin/env python

import optparse, sys, os, pprint, hashlib, time

# Command-line interface. Note that `p` is rebound at the end of this section
# from the OptionParser to the parsed options object; the rest of the script
# reads its settings from that object.
p = optparse.OptionParser(description = """Part of the MetagenomeDB toolkit.
Imports CRISPR annotations generated by CRISPRfinder into the database.""")

g = optparse.OptionGroup(p, "Input")

g.add_option("-i", "--input", dest = "input_fn", metavar = "FILENAME",
	help = "CRISPR annotations to import (mandatory). Must be the 'AnnotFasta' file produced by CRISPRfinder.")

g.add_option("-c", "--collection", dest = "sequences_collection_name", metavar = "STRING",
	help = "Name of the collection the annotated sequences belong to (mandatory).")

g.add_option("--id-getter", dest = "id_getter", metavar = "PYTHON CODE", default = "%",
	help = "Python code to reformat sequence identifiers (optional); '%' will be replaced by the sequence identifier. Default: %default")

g.add_option("--id-patches", dest = "id_patches_fn", metavar = "FILENAME",
	help = """Tab-delimited text file providing alternative sequence identifiers
(optional). The first column should be the identifier found in --input, and the
second column the identifier to consider for this sequence.""")

g.add_option("--date", dest = "date", nargs = 3, type = "int", metavar = "YEAR MONTH DAY",
	help = "Date of the CRISPRfinder run (optional). By default, creation date of the input file.")

p.add_option_group(g)

g = optparse.OptionGroup(p, "Collections")

g.add_option("-C", "--CRISPRs-collection", dest = "CRISPRs_collection_name", metavar = "STRING", default = "CRISPRs",
	help = "Name of the collection the CRISPRs belong to (optional). Default: '%default'")

g.add_option("-S", "--spacers-collection", dest = "spacers_collection_name", metavar = "STRING", default = "Spacers",
	help = "Name of the collection the spacers belong to (optional). Default: '%default'")

g.add_option("-D", "--DRs-collection", dest = "repeats_collection_name", metavar = "STRING", default = "DirectRepeats",
	help = "Name of the collection the direct repeats belong to (optional). Default: '%default'")

p.add_option_group(g)

p.add_option("-v", "--verbose", dest = "verbose", action = "store_true", default = False,
	help = "Set MetagenomeDB to maximum verbosity (optional).")
p.add_option("--dry-run", dest = "dry_run", action = "store_true", default = False,
	help = "Validate the input and print what would be imported, without modifying the database (optional).")

g = optparse.OptionGroup(p, "Connection")

g.add_option("--host", dest = "connection_host", metavar = "HOSTNAME", default = "localhost",
	help = "Host name or IP address of the MongoDB server (optional). Default: %default")

# type = "int" so that a port given on the command line is converted like the
# (integer) default, instead of being passed on as a string
g.add_option("--port", dest = "connection_port", metavar = "INTEGER", type = "int", default = 27017,
	help = "Port of the MongoDB server (optional). Default: %default")

g.add_option("--db", dest = "connection_db", metavar = "STRING", default = "MetagenomeDB",
	help = "Name of the database in the MongoDB server (optional). Default: '%default'")

g.add_option("--user", dest = "connection_user", metavar = "STRING", default = '',
	help = "User for the MongoDB server connection (optional). Default: '%default'")

g.add_option("--password", dest = "connection_password", metavar = "STRING", default = '',
	help = "Password for the MongoDB server connection (optional). Default: '%default'")

p.add_option_group(g)

# from here on, `p` holds the parsed options; positional arguments are ignored
(p, a) = p.parse_args()

def error (msg):
	"""Report `msg` on stderr as 'ERROR: <msg>.' and exit with status 1.

	A trailing period in `msg` is removed first, so the printed message always
	ends with exactly one period. `msg` may be a string or any object whose
	str() is meaningful (e.g. an exception).
	"""
	as_text = str(msg)
	if (as_text[-1:] == '.'):
		msg = as_text[:-1]

	print >>sys.stderr, "ERROR: %s." % msg
	sys.exit(1)

# Mandatory options: the annotation file (which must exist) and the name of
# the collection holding the annotated sequences.
if (p.input_fn is None):
	error("An annotation file must be provided")

if (not os.path.exists(p.input_fn)):
	error("File '%s' not found" % p.input_fn)

if (not p.sequences_collection_name):
	error("A collection name must be provided")

# Compile --id-getter into a one-argument function, '%' standing for the raw
# identifier. NOTE(review): this eval()s code taken from the command line; fine
# for a locally run tool, but never feed this option from untrusted input.
try:
	get_sequence_id = eval("lambda x: " + p.id_getter.replace('%', 'x'))

except SyntaxError as e:
	# e.offset is 1-based within e.text; the message is printed after
	# "ERROR: Invalid getter: " (23 characters), hence offset + 22 spaces to
	# put the caret under the offending character. offset can be None, in
	# which case the caret is placed under the first character.
	offset = e.offset if (e.offset is not None) else 1
	error("Invalid getter: %s\n%s^" % (e.text, ' ' * (offset + 22)))

# Optional identifier patches, read from --id-patches: maps an identifier as
# found in --input to the identifier to use instead. Blank lines and lines
# starting with '#' are ignored; every other line must have exactly two
# tab-separated columns, and an original identifier may appear only once.
Patch = {}
if (p.id_patches_fn):
	if (not os.path.exists(p.id_patches_fn)):
		error("File '%s' not found" % p.id_patches_fn)

	# `with` guarantees the file is closed once read
	with open(p.id_patches_fn, 'r') as patches_fh:
		for line in patches_fh:
			line = line.strip()
			if (line == '') or (line.startswith('#')):
				continue

			fields = line.split('\t')

			if (len(fields) != 2):
				error("Malformed patch file '%s': the file should contain only two columns" % p.id_patches_fn)

			original_id, alternative_id = fields
			if (original_id in Patch):
				error("Duplicate original identifier '%s' found in '%s'" % (original_id, p.id_patches_fn))

			Patch[original_id] = alternative_id

# Date of the CRISPRfinder run, stored as a (year, month, day) tuple in p.date:
# taken from --date when given, otherwise from the local modification time of
# the input file.
if (not p.date):
	date = time.localtime(os.path.getmtime(p.input_fn))
	p.date = (date.tm_year, date.tm_mon, date.tm_mday)

else:
	y, m, d = p.date

	# Explicit checks rather than `assert`, which is removed when Python runs
	# with -O and would then let invalid dates through. These are coarse range
	# checks only; e.g. February 31 is not rejected.
	if (y <= 1990):
		error("Invalid date: value '%s' is incorrect for year" % y)

	if not (0 < m < 13):
		error("Invalid date: value '%s' is incorrect for month" % m)

	if not (0 < d < 32):
		error("Invalid date: value '%s' is incorrect for day" % d)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

import MetagenomeDB as mdb

if (p.verbose):
	mdb.max_verbosity()

# Open the MongoDB connection; any failure is reported through error().
connection_settings = (p.connection_host, p.connection_port, p.connection_db, p.connection_user, p.connection_password)
if any(connection_settings):
	try:
		mdb.connect(*connection_settings)
	except Exception as msg:
		error(msg)

# The collection of annotated sequences must already exist.
collection = mdb.Collection.find_one({"name": p.sequences_collection_name})
if (collection == None):
	error("Unknown collection '%s'" % p.sequences_collection_name)

# Name of the collection receiving each kind of object created by the import.
Collections = {
	"CRISPR": p.CRISPRs_collection_name,
	"Spacer": p.spacers_collection_name,
	"DR": p.repeats_collection_name,
}

# Outside dry-run mode, turn each entry into a (collection, {name: sequence})
# pair: missing collections are created and committed, and the sequences of
# existing ones are indexed by name so they can be reused.
if (not p.dry_run):
	for c_class, c_name in list(Collections.items()):
		target = mdb.Collection.find_one({"name": c_name})
		known = {}

		if (target == None):
			target = mdb.Collection({"name": c_name, "class": "%ss" % c_class})
			target.commit()
		else:
			for member in target.list_sequences():
				known[str(member["name"])] = member

		Collections[c_class] = (target, known)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

def clean_sequence (sequence):
	"""Return `sequence` with every character other than A, T, G or C removed.

	The test is case-insensitive and kept characters retain their original
	case; whitespace, gaps and ambiguity codes (e.g. 'N') are dropped.
	"""
	# join() builds the result in one pass instead of repeated string +=
	return ''.join(c for c in sequence if c.lower() in "atgc")

def parser (fn):
	"""Parse a CRISPRfinder 'AnnotFasta' file.

	A header line ('>...') gives the identifier of an annotated sequence,
	reformatted by get_sequence_id() and then remapped through Patch. The
	following lines describe its CRISPR array:
	  - three whitespace-separated fields: a direct repeat and a spacer (the
	    third field is not used here);
	  - a single field: the final direct repeat, which closes the array.

	Yields (sequence_id, CRISPR, DRs, spacers) when the closing repeat is read:
	CRISPR is the concatenation of all cleaned repeats and spacers in file
	order, DRs the distinct cleaned repeats (order not guaranteed) and spacers
	the cleaned spacers in file order. Blank lines are skipped; any other
	unexpected line aborts the program through error().
	"""
	header_seen = False
	sequence_id = None
	CRISPR, DRs, spacers = '', [], []

	# `with` closes the file once the generator is exhausted or discarded
	with open(fn, 'rU') as fh:
		for line in fh:
			line = line.strip()

			# tolerate blank lines, e.g. at the end of the file
			if (line == ''):
				continue

			if (line.startswith('>')):
				sequence_id = get_sequence_id(line[1:])
				sequence_id = Patch.get(sequence_id, sequence_id)

				CRISPR, DRs, spacers = '', [], []
				header_seen = True
				continue

			# annotation lines only make sense after a sequence header
			if (not header_seen):
				error("Malformed input: annotation line found before any sequence header: '%s'" % line)

			items = line.split()
			n_items = len(items)

			if (n_items == 3):
				DRs.append(clean_sequence(items[0]))
				CRISPR += DRs[-1]

				spacers.append(clean_sequence(items[1]))
				CRISPR += spacers[-1]

			elif (n_items == 1):
				DRs.append(clean_sequence(items[0]))
				CRISPR += DRs[-1]

				# deduplicate the direct repeats; kept as a list so the result
				# supports the same operations as before deduplication
				DRs = list(dict.fromkeys(DRs))

				yield (sequence_id, CRISPR, DRs, spacers)

			else:
				error("Malformed input: '%s'" % line)

print "Importing '%s' ..." % p.input_fn

print "  validating sequences ..."

Sequences, DuplicateSequences = {}, {}

for sequence in collection.list_sequences():
	sequence_name = str(sequence["name"])
	if (sequence_name in Sequences):
		DuplicateSequences[sequence_name] = True

	Sequences[sequence_name] = sequence

seen = {}
for (sequence_id, CRISPR, DRs, spacers) in parser(p.input_fn):
	if (sequence_id in seen):
		error("Duplicate sequence '%s' in input" % sequence_id)

	seen[sequence_id] = True

	if (not sequence_id in Sequences):
		error("Unknown sequence '%s'" % sequence_id)

	if (sequence_id in DuplicateSequences):
		error("Duplicate sequence '%s'" % sequence_id)

print "  importing annotations ..."

class ProgressBar:
	"""Single-line text progress bar drawn on stdout.

	display() redraws the bar in place (its output ends with a carriage return
	rather than a newline); clear() blanks that line once processing is done.
	"""
	def __init__ (self, upper = None):
		# Progress is measured on [0, upper]. float() raises TypeError when
		# upper is None, so in practice callers must provide it.
		self.__min = 0.0
		self.__max = float(upper)

	def display (self, value):
		"""Draw the bar for `value`, clamped to the [0, upper] range."""
		span = self.__max - self.__min
		if (span > 0):
			fraction = (value - self.__min) / span
		else:
			# empty range (e.g. nothing to process): report completion rather
			# than dividing by zero
			fraction = 1.0

		# clamp so the bar never exceeds its 80-column width, which clear() relies on
		fraction = min(max(fraction, 0.0), 1.0)

		percentage = 100 * fraction
		size = int(round(80 * fraction))

		sys.stdout.write(' ' * 2 + ('.' * size) + " %4.2f%%\r" % percentage)
		sys.stdout.flush()

	def clear (self):
		"""Blank out the line used by display()."""
		sys.stdout.write(' ' * (2 + 80 + 8) + "\r")
		sys.stdout.flush()

def hash_CRISPR_sequence (sequence):
	"""Return the identifier of a CRISPR: the hex MD5 digest of its upper-cased sequence.

	Upper-casing first makes the identifier independent of letter case.
	"""
	normalized = sequence.upper()
	digest = hashlib.md5(normalized)
	return digest.hexdigest()

def declare_CRISPR (sequence):
	"""Return the CRISPR object for `sequence`, creating it on first use.

	CRISPRs are keyed by hash_CRISPR_sequence(sequence). An object already
	cached in the "CRISPR" entry of Collections is returned as is. Otherwise a
	new sequence of class "CRISPR" is added to the CRISPRs collection,
	committed, cached and returned.
	"""
	target_collection, cache = Collections["CRISPR"]
	key = hash_CRISPR_sequence(sequence)

	if (key not in cache):
		created = mdb.Sequence({"name": key, "sequence": sequence, "class": "CRISPR"})
		created.add_to_collection(target_collection)
		created.commit()
		cache[key] = created

	return cache[key]

def hash_DR_or_spacer_sequence (sequence):
	"""Return a strand-independent identifier for a direct repeat or spacer.

	The identifier is the lexicographically smaller of the upper-cased sequence
	and its reverse complement. A sequence and its reverse complement (the same
	element read from the other strand) therefore share one identifier,
	whatever the letter case of the input.

	Raises KeyError if `sequence` contains characters other than A, T, G or C
	(in either case); the parser passes sequences through clean_sequence()
	beforehand.

	NOTE: earlier versions used the plain, non-reversed complement and compared
	it with the non-upper-cased sequence, so identifiers stored by those
	versions may differ from the ones computed here.
	"""
	complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G'}

	forward = sequence.upper()
	# the other strand read in its own 5'->3' direction: complemented and reversed
	reverse_complement = ''.join(complement[c] for c in reversed(forward))

	return min(forward, reverse_complement)

def declare_DR (sequence):
	"""Return the direct repeat object for `sequence`, creating it on first use.

	Repeats are keyed by hash_DR_or_spacer_sequence(sequence). An object
	already cached in the "DR" entry of Collections is returned as is.
	Otherwise a new sequence of class "direct repeat" is added to the repeats
	collection, committed, cached and returned.
	"""
	repeats_collection, known_repeats = Collections["DR"]
	key = hash_DR_or_spacer_sequence(sequence)

	if (key in known_repeats):
		return known_repeats[key]

	new_repeat = mdb.Sequence({"name": key, "sequence": sequence, "class": "direct repeat"})
	new_repeat.add_to_collection(repeats_collection)
	new_repeat.commit()

	known_repeats[key] = new_repeat
	return new_repeat

def declare_spacer (sequence):
	"""Return the spacer object for `sequence`, creating it on first use.

	Spacers are keyed by hash_DR_or_spacer_sequence(sequence). An object
	already cached in the "Spacer" entry of Collections is returned as is.
	Otherwise a new sequence of class "spacer" is added to the spacers
	collection, committed, cached and returned.
	"""
	spacers_collection, known_spacers = Collections["Spacer"]
	key = hash_DR_or_spacer_sequence(sequence)

	if (key not in known_spacers):
		fresh = mdb.Sequence({"name": key, "sequence": sequence, "class": "spacer"})
		fresh.add_to_collection(spacers_collection)
		fresh.commit()
		known_spacers[key] = fresh

	return known_spacers[key]

show_pb = (not p.dry_run)

N, n = len(seen), 0
pb = ProgressBar(N)

if (show_pb):
	pb = ProgressBar(N)

for (sequence_id, CRISPR, DRs, spacers) in parser(p.input_fn):
	sequence = Sequences[sequence_id]

	r = {
		"type": "part-of",

		"run": {
			"date": {"year": p.date[0], "month": p.date[1], "day": p.date[2]},
			"algorithm": {
				"name": "CRISPRfinder"
			}
		}
	}

	if (not p.dry_run):
		CRISPR = declare_CRISPR(CRISPR)

		seen_DR = {}
		for DR in DRs:
			DR = declare_DR(DR)
			if (DR in seen_DR):
				continue

			# DR -> CRISPR
			DR.relate_to_sequence(CRISPR, r)
			DR.commit()
			seen_DR[DR] = True

		seen_spacer = {}
		for spacer in spacers:
			spacer = declare_spacer(spacer)
			if (spacer in seen_spacer):
				continue

			# Spacer -> CRISPR
			spacer.relate_to_sequence(CRISPR, r)
			spacer.commit()
			seen_spacer[spacer] = True

	if (p.dry_run):
		print "    CRISPR '%s...' (%s spacers) to sequence '%s'" % (CRISPR[:10], len(spacers), sequence_id)
		for line in pprint.pformat(r).split('\n'):
			print "      %s" % line

	else:
		# CRISPR -> Sequence
		CRISPR.relate_to_sequence(sequence, r)
		CRISPR.commit()

	if (show_pb):
		pb.display(n)

	n += 1

if (show_pb):
	pb.clear()

print "  %s sequences annotated." % n

if (p.dry_run):
	print "(dry run)"
