#!/usr/bin/env python

import optparse, sys, os, re, pprint, hashlib, time
import MetagenomeDB as mdb

# Command-line interface of the importer. Options are declared in the order
# they should appear in the --help output.
p = optparse.OptionParser(
	description = "Part of the MetagenomeDB toolkit.\n"
		"Imports CRISPR annotations generated by CRISPRfinder into the database."
)

# --- what to import, and how to interpret it
input_options = optparse.OptionGroup(p, "Input")

input_options.add_option(
	"-i", "--input",
	metavar = "FILENAME",
	dest = "input_fn",
	help = "CRISPR annotations to import (mandatory). Must be the 'AnnotFasta' file produced by CRISPRfinder.",
)

input_options.add_option(
	"-C", "--collection",
	metavar = "STRING",
	dest = "sequences_collection_name",
	help = "Name of the collection that contains the sequences to annotate (mandatory).",
)

input_options.add_option(
	"--id-getter",
	metavar = "PYTHON CODE",
	dest = "id_getter",
	default = "%",
	help = "Python code to reformat sequence identifers (optional); '%' will be replaced by the sequence identifier. Default: %default",
)

input_options.add_option(
	"--id-patches",
	metavar = "FILENAME",
	dest = "id_patches_fn",
	help = "Tab-delimited text files providing alternative sequence identifiers\n"
		"(optional). The first column should be the identifier found in the input file,\n"
		"and the second column the identifier to consider for this sequence.",
)

# the three *-property options can be repeated; each occurrence appends one
# (KEY, VALUE) pair to the destination list
input_options.add_option(
	"--CRISPR-property",
	metavar = "KEY VALUE",
	dest = "CRISPR_properties",
	action = "append",
	nargs = 2,
	help = "CRISPRs property (optional).",
)

input_options.add_option(
	"--DR-property",
	metavar = "KEY VALUE",
	dest = "DR_properties",
	action = "append",
	nargs = 2,
	help = "Direct repeats property (optional).",
)

input_options.add_option(
	"--spacer-property",
	metavar = "KEY VALUE",
	dest = "spacer_properties",
	action = "append",
	nargs = 2,
	help = "Spacers property (optional).",
)

input_options.add_option(
	"--date",
	metavar = "YEAR MONTH DAY",
	dest = "date",
	type = "int",
	nargs = 3,
	help = "Date of the CRISPRfinder run (optional). By default, creation date of the input file.",
)

p.add_option_group(input_options)

# --- names of the collections receiving the imported objects
collection_options = optparse.OptionGroup(p, "Collections")

collection_options.add_option(
	"-c", "--CRISPRs-collection",
	metavar = "STRING",
	dest = "CRISPRs_collection_name",
	default = "CRISPRs",
	help = "Name of the collection the CRISPRs belong to (optional). Default: '%default'",
)

collection_options.add_option(
	"-d", "--DRs-collection",
	metavar = "STRING",
	dest = "repeats_collection_name",
	default = "DirectRepeats",
	help = "Name of the collection the direct repeats belong to (optional). Default: '%default'",
)

collection_options.add_option(
	"-s", "--spacers-collection",
	metavar = "STRING",
	dest = "spacers_collection_name",
	default = "Spacers",
	help = "Name of the collection the spacers belong to (optional). Default: '%default'",
)

p.add_option_group(collection_options)

# --- general switches
p.add_option("-v", "--verbose", action = "store_true", dest = "verbose", default = False)
p.add_option("--no-progress-bar", action = "store_false", dest = "display_progress_bar", default = True)
p.add_option("--dry-run", action = "store_true", dest = "dry_run", default = False)
p.add_option("--version", action = "store_true", dest = "display_version", default = False)

# Connection settings explicitly given on the command line, keyed by option
# name without its leading dashes ("host", "port", "db", "user", "password").
# Options that were not given are absent from the dictionary.
connection_parameters = {}

def declare_connection_parameter (option, opt, value, parser):
	"""optparse callback storing a connection option in `connection_parameters`.

	`opt` is the option string that triggered the callback (e.g. "--host");
	its first two characters are dropped to obtain the dictionary key.
	`value` is stored as received (a string, per the option declarations).
	"""
	connection_parameters[opt[2:]] = value

g = optparse.OptionGroup(p, "Connection")

# (option string, metavar, help text) of each connection option; they all
# share the same type, action and callback
for (option_string, metavar, help_text) in (
	("--host", "HOSTNAME",
		"""Host name or IP address of the MongoDB server (optional). Default:
'host' property in ~/.MetagenomeDB, or 'localhost' if not found."""),

	("--port", "INTEGER",
		"""Port of the MongoDB server (optional). Default: 'port' property
in ~/.MetagenomeDB, or 27017 if not found."""),

	("--db", "STRING",
		"""Name of the database in the MongoDB server (optional). Default:
'db' property in ~/.MetagenomeDB, or 'MetagenomeDB' if not found."""),

	("--user", "STRING",
		"""User for the MongoDB server connection (optional). Default:
'user' property in ~/.MetagenomeDB, or none if not found."""),

	("--password", "STRING",
		"""Password for the MongoDB server connection (optional). Default:
'password' property in ~/.MetagenomeDB, or none if not found."""),
):
	g.add_option(option_string,
		dest = "connection_" + option_string[2:],
		metavar = metavar,
		type = "string",
		action = "callback",
		callback = declare_connection_parameter,
		help = help_text)

p.add_option_group(g)

# NOTE: from this point on `p` no longer refers to the parser but to the
# parsed option values; `a` receives the leftover positional arguments
(p, a) = p.parse_args()

def error (msg):
	"""Print `msg` on stderr as "ERROR: <msg>." and exit with status 1.

	`msg` is converted with str(); a single trailing period, if any, is
	removed first since the output format always appends one.
	"""
	text = str(msg)
	text = text[:-1] if text.endswith('.') else text

	print >>sys.stderr, "ERROR: %s." % text
	sys.exit(1)

# --version: print the toolkit version and stop, whatever else was given
if p.display_version:
	print(mdb.version)
	sys.exit(0)

# Mandatory arguments; error() terminates the program, so each check below
# is only reached when the previous ones passed
if p.input_fn is None:
	error("An annotation file must be provided")

input_found = os.path.exists(p.input_fn)
if not input_found:
	error("File '%s' not found" % p.input_fn)

if not p.sequences_collection_name:
	error("A collection name must be provided")

# Build the function that reformats sequence identifiers from --id-getter.
# Every '%' of the user expression becomes the lambda argument 'x'; a '%'
# that was escaped as '\%' (which the first replacement turned into '\x') is
# then restored as a literal '%'.
#
# SECURITY: the expression is evaluated as arbitrary Python code with the
# privileges of the current user. This is acceptable only as long as the
# command line comes from a trusted operator; never feed this option from
# untrusted input.
try:
	get_sequence_id = eval("lambda x: " + p.id_getter.replace('%', 'x').replace("\\x", '%'))

except SyntaxError as e:
	# the caret is padded so as to point at the faulty column of the
	# expression printed on the line above; e.offset may be None, in which
	# case the caret falls back to the start of the expression instead of
	# raising a TypeError that would hide the real problem
	error("Invalid getter: %s\n%s^" % (e.text, ' ' * ((e.offset or 0) + 22)))

# Mapping from an identifier as found in the input file to the identifier to
# use instead; filled from the --id-patches file, empty when none is given
Patch = {}
if (p.id_patches_fn):
	if (not os.path.exists(p.id_patches_fn)):
		error("File '%s' not found" % p.id_patches_fn)

	# the context manager releases the file handle in all cases, including
	# when error() terminates the program in the middle of the file
	with open(p.id_patches_fn, 'r') as o:
		for line in o:
			line = line.strip()

			# blank lines and comment lines are ignored
			if (line == '') or (line.startswith('#')):
				continue

			columns = line.split('\t')
			if (len(columns) != 2):
				error("Malformated patch file '%s': the file should contains only two columns" % p.id_patches_fn)

			original_id, alternative_id = [item.strip() for item in columns]

			# an identifier can be patched only once
			if (original_id in Patch):
				error("Duplicate original identifier '%s' found in '%s'" % (original_id, p.id_patches_fn))

			Patch[original_id] = alternative_id

# Date of the CRISPRfinder run, stored in p.date as a (year, month, day) tuple
if (not p.date):
	# no --date given: fall back on the input file.
	# NOTE(review): this uses the last *modification* time of the file,
	# whereas the --date help text mentions its creation date
	date = time.localtime(os.path.getmtime(p.input_fn))
	p.date = (date.tm_year, date.tm_mon, date.tm_mday)

else:
	# explicit checks rather than assert statements, which are stripped when
	# Python runs with -O and would then let any value through
	y, m, d = p.date

	if (y <= 1990):
		error("Invalid date: value '%s' is incorrect for year" % y)

	if (m < 1) or (m > 12):
		error("Invalid date: value '%s' is incorrect for month" % m)

	# only the 1..31 range is checked; the day is not validated against the
	# actual length of the month
	if (d < 1) or (d > 31):
		error("Invalid date: value '%s' is incorrect for day" % d)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

# -v/--verbose: ask the toolkit for its maximum verbosity
if p.verbose:
	mdb.max_verbosity()

# Open the database connection using only the settings given on the command
# line; any failure is reported through error(), which ends the program
try:
	mdb.connect(**connection_parameters)

except Exception as connection_failure:
	error(connection_failure)

def parse_properties (properties):
	"""Validate and convert a list of user-provided (key, value) pairs.

	Each value is passed through mdb.tools.parse_value_and_modifier() and the
	result is returned as a list of (key, converted value) pairs, in the
	original order. The program is terminated through error() if a key is,
	case-insensitively, one of "name", "sequence" or "length".
	"""
	reserved_keys = ("name", "sequence", "length")
	parsed = []

	for pair in properties:
		key, raw_value = pair

		if key.lower() in reserved_keys:
			error("Property '%s' cannot be modified" % key)

		converted = mdb.tools.parse_value_and_modifier(raw_value)
		parsed.append((key, converted))

	return parsed

# Convert the user-provided properties; an option that was never given is
# None and yields an empty list
p.CRISPR_properties = parse_properties(p.CRISPR_properties) if (p.CRISPR_properties) else []
p.DR_properties = parse_properties(p.DR_properties) if (p.DR_properties) else []
p.spacer_properties = parse_properties(p.spacer_properties) if (p.spacer_properties) else []

try:
	# collection holding the sequences to annotate; it must already exist
	collection = mdb.Collection.find_one({"name": p.sequences_collection_name})
	if (collection is None):
		error("Unknown collection '%s'" % p.sequences_collection_name)

	# creation of the structure: the three target collections are looked up
	# first, then the missing ones are created. Identity tests against None
	# are used so that the result never depends on how the database objects
	# implement equality.
	CRISPRs = mdb.Collection.find_one({"name": p.CRISPRs_collection_name})
	DRs = mdb.Collection.find_one({"name": p.repeats_collection_name})
	Spacers = mdb.Collection.find_one({"name": p.spacers_collection_name})

	if (CRISPRs is None):
		print("Creating CRISPRs collection '%s'" % p.CRISPRs_collection_name)
		CRISPRs = mdb.Collection({"name": p.CRISPRs_collection_name, "class": "CRISPRs"})
		CRISPRs.commit()

	# newly created direct repeats and spacers collections are attached to
	# the CRISPRs collection; pre-existing ones are used as found
	if (DRs is None):
		print("Creating direct repeats collection '%s'" % p.repeats_collection_name)
		DRs = mdb.Collection({"name": p.repeats_collection_name, "class": "direct repeats"})
		DRs.add_to_collection(CRISPRs)
		DRs.commit()

	if (Spacers is None):
		print("Creating spacers collection '%s'" % p.spacers_collection_name)
		Spacers = mdb.Collection({"name": p.spacers_collection_name, "class": "spacers"})
		Spacers.add_to_collection(CRISPRs)
		Spacers.commit()

except mdb.errors.DBConnectionError as msg:
	error(msg)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

def clean_sequence (sequence):
	"""Return `sequence` without any character other than A, T, G or C.

	The test is case-insensitive and the case of the retained characters is
	left as is.
	"""
	return ''.join([base for base in sequence if base.lower() in "atgc"])

# Regular expression fragment matching one nucleotide stretch in the input.
# NOTE(review): 'N' and 'w' are accepted here but are later removed by
# clean_sequence() -- presumably on purpose; confirm against actual
# CRISPRfinder output
SEQUENCE = "[ATGCNw]+"

def parser (fn):
	"""Iterate over the CRISPR entries of the annotation file `fn`.

	An entry starts with a '>' header line and runs until the next header.
	Each of its lines must be either "<DR> <spacer> <number>",
	"<spacer> <number>" or a lone "<DR>".

	Yields one tuple per entry:
	  - the sequence identifier (header text passed through
	    get_sequence_id(), then through the Patch table),
	  - the CRISPR sequence (all DRs and spacers of the entry, concatenated
	    in file order),
	  - the sorted list of distinct direct repeat sequences,
	  - the list of spacer sequences, in file order.

	The program is terminated through error() on a malformed line or when
	the identifier cannot be computed. Lines found before the first header
	are discarded.
	"""
	sequence_id, entry = None, []

	DR_spacer_pattern = re.compile(r"^(%s)\s+(%s)\s+[0-9]+$" % (SEQUENCE, SEQUENCE))
	spacer_pattern = re.compile(r"^(%s)\s+[0-9]+$" % SEQUENCE)
	DR_pattern = re.compile(r"^(%s)$" % SEQUENCE)

	def send():
		# turn the lines accumulated for the current entry into a result tuple
		CRISPR, DRs, spacers = [], set(), []

		for line in entry:
			# the most specific pattern is tried first
			DR_spacer_match = DR_spacer_pattern.match(line)
			if (DR_spacer_match is not None):
				DR, spacer = [clean_sequence(sequence) for sequence in DR_spacer_match.groups()]

				CRISPR.append(DR)
				CRISPR.append(spacer)
				DRs.add(DR)
				spacers.append(spacer)
				continue

			spacer_match = spacer_pattern.match(line)
			if (spacer_match is not None):
				spacer = clean_sequence(spacer_match.group(1))

				CRISPR.append(spacer)
				spacers.append(spacer)
				continue

			DR_match = DR_pattern.match(line)
			if (DR_match is not None):
				DR = clean_sequence(DR_match.group(1))

				CRISPR.append(DR)
				DRs.add(DR)
				continue

			error("Malformed input: '%s'" % line)

		return sequence_id, ''.join(CRISPR), sorted(DRs), spacers

	# the context manager makes sure the file handle is released once the
	# whole file has been read (the file is parsed more than once per run)
	with open(fn, 'rU') as fh:
		for line in fh:
			line = line.strip()
			if (line == ''):
				continue

			if (line.startswith('>')):
				# a new header closes the previous entry, if any
				if (sequence_id is not None):
					yield send()

				try:
					sequence_id = get_sequence_id(line[1:])
				except Exception as msg:
					error("Unable to retrieve the sequence identifier. Message was \"%s\"; original identifier was \"%s\"" % (msg, line[1:]))

				sequence_id = Patch.get(sequence_id, sequence_id)
				entry = []
			else:
				entry.append(line)

	# last entry of the file
	if (sequence_id is not None):
		yield send()

print("Importing '%s' ..." % p.input_fn)

print("  validating sequences ...")

# First pass over the input: every annotated sequence must be listed only
# once in the file and must match exactly one sequence of the target
# collection. Nothing is written to the database during this pass.
Sequences = {}
for (sequence_id, CRISPR_sequence, DR_sequences, spacer_sequences) in parser(p.input_fn):
	if sequence_id in Sequences:
		error("Duplicate sequence '%s' in input" % sequence_id)

	matches = list(collection.list_sequences({"name": sequence_id}))
	n_matches = len(matches)

	if n_matches == 0:
		error("Unknown sequence '%s'" % sequence_id)

	if n_matches > 1:
		error("Duplicate sequence '%s'" % sequence_id)

	Sequences[sequence_id] = matches[0]

if not Sequences:
	error("No sequence in the input")

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

print("  importing annotations ...")

class ProgressBar:
	"""Text progress bar written to stdout and redrawn in place.

	Each update ends with a carriage return rather than a newline, so that
	the next update overwrites it. The bar is made of up to 80 dots (for
	values within the range) followed by the percentage done.
	"""

	def __init__ (self, upper = None):
		"""Set up a bar covering the range from 0 to `upper`.

		`upper` must be a number. The default of None is only kept for
		backward compatibility: it raises a TypeError, as it always did.
		"""
		self.__min = 0.0
		self.__max = upper + 0.0 # coerced to a float

	def display (self, value):
		"""Draw the bar for `value`, expected between 0 and the upper bound."""
		span = self.__max - self.__min

		# an empty range has nothing left to process: it is reported as
		# complete instead of failing on a division by zero
		if (span == 0):
			f = 1.0
		else:
			f = (value - self.__min) / span # fraction

		p = 100 * f # percentage
		s = int(round(80 * f)) # bar size

		sys.stdout.write(' ' * 2 + ('.' * s) + " %4.2f%%\r" % p)
		sys.stdout.flush()

	def clear (self):
		"""Blank out the line used by the bar (indent, 80 dots, percentage)."""
		sys.stdout.write(' ' * (2 + 80 + 8) + "\r")
		sys.stdout.flush()

# A dry run prints one line per CRISPR instead of drawing a progress bar
show_pb = (not p.dry_run)

# N: number of sequences to annotate; n: number of sequences processed so far
N, n = len(Sequences), 0

# The bar is created once, unconditionally: building it writes nothing to the
# terminal, and `pb` is then defined whatever the value of show_pb. (It used
# to be built a second time when show_pb was set, to no effect.)
pb = ProgressBar(N)

def generate_name (sequence):
	"""Return the hexadecimal MD5 digest of the upper-cased `sequence`.

	Two sequences differing only by their case therefore get the same name.
	"""
	digest = hashlib.md5()
	digest.update(sequence.upper())
	return digest.hexdigest()

# CRISPRs, spacers and direct repeats are all named the same way, through
# generate_name(); the three aliases keep the call sites explicit about which
# kind of object is being named
generate_CRISPR_name = generate_spacer_name = generate_DR_name = generate_name

# The string literal below is disabled code: it is evaluated as a plain
# expression and discarded, so nothing in it is ever defined or executed.
# It holds an alternative naming scheme for spacers and direct repeats.
# NOTE(review): despite its inner comment, the loop builds the complement of
# the sequence without reversing it.
"""
def generate_spacer_name (sequence):
	# find the reverse complement
	sequence_ = ''
	for c in sequence.upper():
		sequence_ += {
			'A': 'T',
			'T': 'A',
			'G': 'C',
			'C': 'G',
		}[c]

	return sorted((sequence.upper(), sequence_))[0]

generate_DR_name = generate_spacer_name
"""

# check if a sequence with a given name already exists in a collection
def exists (collection, name):
	"""Look for the sequence called `name` in `collection`.

	Returns the sequence when exactly one is found and None when there is
	none. The program is terminated through error() if several sequences of
	the collection share this name.
	"""
	matches = list(collection.list_sequences({"name": name}))

	if not matches:
		return None

	if len(matches) == 1:
		return matches[0]

	error("Duplicate sequence '%s' in collection %s" % (name, collection))

def declare_CRISPR (sequence):
	"""Return the CRISPR object for `sequence`, creating it if needed.

	The object is looked up in the CRISPRs collection under the name given by
	generate_CRISPR_name(). When none exists, a new sequence of class "CRISPR"
	is created with the user-provided CRISPR properties, added to the
	collection and committed.
	"""
	name = generate_CRISPR_name(sequence)
	CRISPR = exists(CRISPRs, name)

	# identity test: the outcome must not depend on how the database objects
	# implement equality
	if (CRISPR is None):
		CRISPR = mdb.Sequence({"name": name, "sequence": sequence, "class": "CRISPR"})
		for (key, value) in p.CRISPR_properties:
			CRISPR[key] = value

		CRISPR.add_to_collection(CRISPRs)
		CRISPR.commit()

	return CRISPR

def declare_DR (sequence):
	"""Return the direct repeat object for `sequence`, creating it if needed.

	The object is looked up in the direct repeats collection under the name
	given by generate_DR_name(). When none exists, a new sequence of class
	"direct repeat" is created with the user-provided direct repeat
	properties, added to the collection and committed.
	"""
	name = generate_DR_name(sequence)
	DR = exists(DRs, name)

	# identity test: the outcome must not depend on how the database objects
	# implement equality
	if (DR is None):
		DR = mdb.Sequence({"name": name, "sequence": sequence, "class": "direct repeat"})
		for (key, value) in p.DR_properties:
			DR[key] = value

		DR.add_to_collection(DRs)
		DR.commit()

	return DR

def declare_spacer (sequence):
	"""Return the spacer object for `sequence`, creating it if needed.

	The object is looked up in the spacers collection under the name given by
	generate_spacer_name(). When none exists, a new sequence of class "spacer"
	is created with the user-provided spacer properties, added to the
	collection and committed.
	"""
	name = generate_spacer_name(sequence)
	spacer = exists(Spacers, name)

	# identity test: the outcome must not depend on how the database objects
	# implement equality
	if (spacer is None):
		spacer = mdb.Sequence({"name": name, "sequence": sequence, "class": "spacer"})
		for (key, value) in p.spacer_properties:
			spacer[key] = value

		spacer.add_to_collection(Spacers)
		spacer.commit()

	return spacer

def connect (source, target, relationship):
	"""Relate `source` to `target` with the given relationship description.

	A relationship reported as a duplicate by the toolkit is silently
	ignored, which makes the call safe to repeat; any other error propagates.
	"""
	try:
		source.relate_to_sequence(target, relationship)

	except mdb.errors.DuplicateObjectError:
		# already declared: nothing to do
		return

# Second pass over the input: declare each CRISPR with its direct repeats and
# spacers, and relate them to each other and to the annotated sequence. In a
# dry run nothing is written; a one-line summary is printed per CRISPR.
for (sequence_id, CRISPR_sequence, DR_sequences, spacer_sequences) in parser(p.input_fn):
	# guaranteed to be present by the validation pass
	sequence = Sequences[sequence_id]

	# description shared by all the relationships declared for this entry;
	# a fresh dictionary is built for each entry
	r = {
		"type": "part-of",

		"run": {
			"date": {"year": p.date[0], "month": p.date[1], "day": p.date[2]},
			"algorithm": {
				"name": "CRISPRfinder"
			}
		}
	}

	if (not p.dry_run):
		# declare the CRISPR, if not already done
		CRISPR = declare_CRISPR(CRISPR_sequence)

		# objects already related during this entry are skipped
		seen_DR = {}
		for DR_sequence in DR_sequences:
			DR = declare_DR(DR_sequence)
			if (DR in seen_DR):
				continue

			# declare the DR -> CRISPR|sequence relationships
			connect(DR, CRISPR, r)
			connect(DR, sequence, r)
			DR.commit()
			seen_DR[DR] = True

		seen_spacer = {}
		for spacer_sequence in spacer_sequences:
			spacer = declare_spacer(spacer_sequence)
			if (spacer in seen_spacer):
				continue

			# declare the spacer -> CRISPR|sequence relationships
			connect(spacer, CRISPR, r)
			connect(spacer, sequence, r)
			spacer.commit()
			seen_spacer[spacer] = True

		# declare the CRISPR -> sequence relationship
		connect(CRISPR, sequence, r)
		CRISPR.commit()

		if (p.display_progress_bar):
			# n only counts the entries completed before this one, hence the
			# +1: the bar now reaches 100% once the last entry is done
			pb.display(n + 1)

	else:
		print("    CRISPR '%s...' (%s spacer%s) added to sequence '%s'" % (
			CRISPR_sequence[:20],
			len(spacer_sequences),
			{True: 's', False: ''}[len(spacer_sequences) > 1],
			sequence_id
		))

	n += 1

# erase the bar before printing the summary
if (show_pb) and (p.display_progress_bar):
	pb.clear()

print("  %s sequence%s annotated." % (n, {True: 's', False: ''}[n > 1]))

if (p.dry_run):
	print("(dry run)")
