#!/usr/bin/env python

import optparse, sys, os, pprint

p = optparse.OptionParser(description = """Part of the MetagenomeDB toolkit.
Annotate objects in the database. Annotations are provided as either JSON- or
CSV-formatted files.""")

p.add_option("-i", "--input", dest = "input_fn", metavar = "FILENAME",
	help = "Name of the file containing the annotations, or '-' to read from the standard input (mandatory).")

p.add_option("-f", "--format", dest = "input_format", choices = ("json", "csv"), metavar = "STRING", default = "csv",
	help = "Format of the input file, either 'json' or 'csv' (optional). Default: %default")

p.add_option("--ignore-unknown", dest = "ignore_unknown", action = "store_true", default = False,
	help = "If set, report entries referring to unknown objects as warnings instead of aborting.")

p.add_option("-v", "--verbose", dest = "verbose", action = "store_true", default = False,
	help = "If set, enable the maximum verbosity of the MetagenomeDB API (optional).")

p.add_option("--dry-run", dest = "dry_run", action = "store_true", default = False,
	help = "If set, list the changes that would be made without modifying the database (optional).")

g = optparse.OptionGroup(p, "Connection")

g.add_option("--host", dest = "connection_host", metavar = "HOSTNAME", default = "localhost",
	help = "Host name or IP address of the MongoDB server (optional). Default: %default")

# type = "int" so that a port given on the command line is converted like
# the default value (otherwise optparse would pass it on as a string)
g.add_option("--port", dest = "connection_port", metavar = "INTEGER", type = "int", default = 27017,
	help = "Port of the MongoDB server (optional). Default: %default")

g.add_option("--db", dest = "connection_db", metavar = "STRING", default = "MetagenomeDB",
	help = "Name of the database in the MongoDB server (optional). Default: '%default'")

g.add_option("--user", dest = "connection_user", metavar = "STRING", default = '',
	help = "User for the MongoDB server connection (optional). Default: '%default'")

g.add_option("--password", dest = "connection_password", metavar = "STRING", default = '',
	help = "Password for the MongoDB server connection (optional). Default: '%default'")

p.add_option_group(g)

# From here on `p` holds the parsed option values (the parser object is no
# longer needed); `a` holds the positional arguments, which are not used.
(p, a) = p.parse_args()

def error (msg):
	"""Report a fatal error on stderr and terminate with exit status 1.

	`msg` may be a string or an exception; it is converted with str().
	A single trailing period is stripped before one is appended, so messages
	that already end with '.' are not printed with a doubled period.
	"""
	msg = str(msg)
	if msg.endswith('.'):
		msg = msg[:-1]

	# sys.stderr.write behaves identically under Python 2 and 3, unlike the
	# Python 2-only `print >>sys.stderr, ...` statement.
	sys.stderr.write("ERROR: %s.\n" % msg)
	sys.exit(1)

# The input file is mandatory (see --input).
if (p.input_fn is None):
	error("An input file must be provided")

# '-' designates the standard input per the --input help text, so there is no
# path to check for it (assumes mdb.tools.parser accepts '-'; confirm).
if (p.input_fn != '-') and (not os.path.exists(p.input_fn)):
	error("File '%s' does not exist" % p.input_fn)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

import MetagenomeDB as mdb

# -v/--verbose raises the API's verbosity to its maximum.
if p.verbose:
	mdb.max_verbosity()

# Connect whenever at least one connection setting is non-empty; any failure
# to connect is fatal.
if any((p.connection_host, p.connection_port, p.connection_db,
		p.connection_user, p.connection_password)):
	try:
		mdb.connect(
			p.connection_host,
			p.connection_port,
			p.connection_db,
			p.connection_user,
			p.connection_password,
		)
	except Exception as msg:
		error(msg)

class NotFound (Exception):
	"""Raised when an annotation entry refers to an object absent from the database."""

def key_filter (key):
	"""Tell whether a key hierarchy may be annotated.

	`key` is a sequence of key components. Returns False as soon as one
	component starts with '_' (a special key), True otherwise, including
	for an empty sequence.
	"""
	return not any(component.startswith('_') for component in key)

def pull (map, key):
	"""Remove `key` from `map` and return the value half of its entry.

	Each entry of `map` is a (value, command) pair; the command is discarded.
	Raises KeyError, leaving `map` untouched, when `key` is absent. Removing
	the key keeps it out of the later iteration over the remaining entries.
	"""
	stored_pair = map[key]
	(extracted_value, _ignored_command) = stored_pair
	del map[key]
	return extracted_value

n_annotated = 0

try:
	for entry in mdb.tools.parser(p.input_fn, p.input_format):
		try:
			object_type = pull(entry, "_type").lower()

			# _type=sequence, _collection='...', name='...', ...
			if (object_type == "sequence"):
				collection_name, sequence_name = pull(entry, "_collection"), pull(entry, "name")

				# we first list all sequences having this name,
				candidates = mdb.Sequence.find({"name": sequence_name})

				# then we filter out those that are not linked to this collection
				candidates_ = filter(lambda s: s.count_collections({"name": collection_name}) > 0, candidates)

				if (len(candidates_) == 0):
					raise NotFound("Unknown sequence '%s' in collection '%s'" % (sequence_name, collection_name))

				if (len(candidates_) > 1):
					raise Exception("Duplicate sequence '%s' in collection '%s'" % (sequence_name, collection_name))

				object = candidates_[0]
				object_name = "sequence '%s' in collection '%s'" % (sequence_name, collection_name)

			# _type=collection, name='...', ...
			elif (object_type == "collection"):
				collection_name = pull(entry, "name")

				candidate = mdb.Collection.find_one({"name": collection_name})

				if (candidate == None):
					raise NotFound("Unknown collection '%s'" % collection_name)

				object = candidate
				object_name = "collection '%s'" % collection_name

			else:
				raise Exception("Unknown object type '%s'" % object_type)

			actions = []
			for key, (value, command) in mdb.tree.items(entry):
				# We ignore any key hierarchy which contains a
				# special key (i.e., key starting with a '_').
				# This would be caught by the API, but it is
				# easier to check this at this stage.
				if (not key_filter(key)):
					print >>sys.stderr, "WARNING: Key '%s' is invalid and was ignored." % key
					continue

				if (p.dry_run):
					key = '.'.join(key)

					if (command == mdb.tools.REPLACE):
						actions.append("SET \"%s\" as value for property '%s'" % (value, key))

					elif (command == mdb.tools.APPEND):
						actions.append("APPEND \"%s\" to property '%s'" % (value, key))

					elif (command == mdb.tools.APPEND_IF_UNIQUE):
						actions.append("APPEND \"%s\" (if unique) to property '%s'" % (value, key))

					elif (command == mdb.tools.REMOVE):
						actions.append("REMOVE property '%s'" % key)
					
				else:
					if (command == mdb.tools.REPLACE):
						object[key] = value

					elif (command == mdb.tools.APPEND) or (command == mdb.tools.APPEND_IF_UNIQUE):
						value_ = object.get_property(key, [])
						if (type(value) != list):
							value = [value]

						for v in value:
							if (command == mdb.tools.APPEND_IF_UNIQUE) and (v in value_):
								continue

							value_.append(v)

						object[key] = value_

					elif (command == mdb.tools.REMOVE):
						del object[key]

			if (not p.dry_run) and object.is_committed():
				print >>sys.stderr, "WARNING: %s has not been modified." % object_name

			else:
				n_annotated += 1

				if (p.dry_run):
					print "annotate: %s" % object_name
					for line in actions:
						print "  %s" % line
				else:
					object.commit()

		except mdb.errors.ConnectionError as msg:
			error(str(msg))

		except NotFound as msg:
			if (p.ignore_unknown):
				print >>sys.stderr, "WARNING: %s" % msg
			else:
				error(msg)

		except Exception as msg:
			error("Invalid entry: %s. Entry was:\n %s" % (msg, pprint.pformat(entry)))

	print "%s object%s annotated." % (n_annotated, {True: 's', False: ''}[n_annotated > 1])

	if (p.dry_run):
		print "(dry run)"

except Exception as msg:
	error("Error when processing the input: %s." % msg)
