#!/usr/bin/env python

"""
A helper script for many matgendb functions.
"""

__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "1.2"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyue@mit.edu"
__date__ = "Dec 1, 2012"

import os
import datetime
import logging
import multiprocessing
import json
import sys
import webbrowser
import time
import argparse

from pymongo import Connection, ASCENDING

from pymatgen.apps.borg.queen import BorgQueen

from matgendb.query_engine import QueryEngine
from matgendb.creator import VaspToDbTaskDrone
from matgendb.util import get_settings, DEFAULT_SETTINGS, MongoJSONEncoder

_log = logging.getLogger("mg")  # parent logger ("mg"); its handler and level are set in init_logging()

def init_db(args):
    """Interactively create a database configuration file.

    Prompts once for every (name, default) pair in DEFAULT_SETTINGS, keeping
    the default whenever the answer is empty, then writes the collected
    values as JSON to ``args.config_file``.

    :param args: Parsed command-line arguments; only ``config_file`` is read.
    """
    print("Please supply the following configuration values")
    print("(press Enter if you want to accept the defaults)\n")
    settings = {}
    for name, default in DEFAULT_SETTINGS:
        prompt = "Enter {} (default: {}) : ".format(name, default)
        answer = raw_input(prompt)
        settings[name] = answer or default
    # Typed answers arrive as text, but the port is stored as an integer.
    settings["port"] = int(settings["port"])
    target = args.config_file
    with open(target, "w") as out:
        json.dump(settings, out, indent=4, sort_keys=True)
    print("\nConfiguration written to {}!".format(target))


def update_db(args):
    """Assimilate the runs found under ``args.directory`` into the database.

    Configures root logging (to ``args.logfile`` when given, otherwise to the
    default stream), builds a VaspToDbTaskDrone from the settings in
    ``args.config_file`` and lets a BorgQueen process the directory in
    parallel.

    :param args: Parsed command-line arguments. Reads ``logfile``,
                 ``config_file``, ``author``, ``tag``, ``parse_dos``,
                 ``force_update_dupes``, ``ncpus`` and ``directory``.
    """
    log_format = "%(relativeCreated)d msecs : %(message)s"

    # The --logfile option is declared without nargs, so argparse delivers a
    # plain string; indexing it with [0] would log to a file named after the
    # first character of the path. A one-element list is still accepted for
    # callers that build the namespace by hand.
    logfile = args.logfile
    if isinstance(logfile, (list, tuple)):
        logfile = logfile[0] if logfile else None

    if logfile:
        logging.basicConfig(level=logging.INFO, format=log_format,
                            filename=logfile)
    else:
        logging.basicConfig(level=logging.INFO, format=log_format)

    d = get_settings(args.config_file)

    _log.info("Db insertion started at {}.".format(datetime.datetime.now()))
    additional_fields = {"author": args.author, "tags": args.tag}
    drone = VaspToDbTaskDrone(
        host=d["host"], port=d["port"],  database=d["database"],
        user=d["admin_user"], password=d["admin_password"],
        parse_dos=args.parse_dos,
        collection=d["collection"], update_duplicates=args.force_update_dupes,
        additional_fields=additional_fields, mapi_key=d.get("mapi_key", None))
    # Fall back to every detected CPU when --ncpus is absent (or zero).
    ncpus = args.ncpus or multiprocessing.cpu_count()
    _log.info("Using {} cpus...".format(ncpus))
    queen = BorgQueen(drone, number_of_drones=ncpus)
    queen.parallel_assimilate(args.directory)
    # Build a real list (falsy results are dropped) so len() works even
    # where map()/filter() return lazy iterators.
    tids = [int(tid) for tid in queen.get_data() if tid]
    _log.info("Db upate completed at {}.".format(datetime.datetime.now()))
    _log.info("{} new task ids inserted.".format(len(tids)))


def optimize_indexes(args):
    """Drop and rebuild the indexes of the configured collection.

    Connects with the settings from ``args.config_file``, removes the
    existing indexes and then creates a unique index on ``task_id``, a
    single-field index for each commonly queried key and a compound index on
    (``nelements``, ``elements``).

    :param args: Parsed command-line arguments; only ``config_file`` is read.
    """
    d = get_settings(args.config_file)
    c = Connection(d["host"], d["port"])
    db = c[d["database"]]
    # A configuration without an admin user describes a no-authentication
    # database; attempting to authenticate with empty credentials would fail.
    if d.get("admin_user"):
        db.authenticate(d["admin_user"], d["admin_password"])
    coll = db[d["collection"]]
    coll.drop_indexes()
    coll.ensure_index('task_id', unique=True)
    for key in ['unit_cell_formula', 'reduced_cell_formula', 'chemsys',
                'nsites', 'pretty_formula', 'analysis.e_above_hull',
                "icsd_ids"]:
        print("Building {} index".format(key))
        coll.ensure_index(key)
    print("Building nelements and elements compound index")
    compound_index = [('nelements', ASCENDING), ('elements', ASCENDING)]
    # Requested once; the original repeated this identical call.
    coll.ensure_index(compound_index)


def query_db(args):
    """Query the database and print the results.

    Results are printed either as one JSON document per line
    (``args.dump_json``) or as a table built with prettytable.

    The properties to fetch come from ``args.properties``. As a special case,
    a single value starting with ':' (e.g. ``--props :names.txt``) is taken
    as the path of a text file listing one property per line; blank lines in
    that file are ignored.

    :param args: Parsed command-line arguments. Reads ``config_file``,
                 ``criteria``, ``properties`` and ``dump_json``.
    """
    d = get_settings(args.config_file)
    qe = QueryEngine(host=d["host"], port=d["port"], database=d["database"],
                     user=d["readonly_user"], password=d["readonly_password"],
                     collection=d["collection"],
                     aliases_config=d.get("aliases_config", None))
    criteria = None
    if args.criteria:
        try:
            criteria = json.loads(args.criteria)
        except ValueError:
            print("Criteria {} is not a valid JSON string!".format(
                args.criteria))
            sys.exit(-1)

    props = args.properties
    if len(props) == 1 and props[0].startswith(':'):
        # Text mode, so the names are text like the ones typed on the
        # command line; empty lines would otherwise become empty properties.
        with open(props[0][1:], 'r') as f:
            stripped = (line.strip() for line in f)
            props = [name for name in stripped if name]

    if args.dump_json:
        for r in qe.query(properties=props, criteria=criteria):
            print(json.dumps(r, cls=MongoJSONEncoder))
    else:
        # Imported here so that --dump works without prettytable installed.
        from prettytable import PrettyTable
        t = PrettyTable(props)
        t.float_format = "4.4"
        for r in qe.query(properties=props, criteria=criteria):
            t.add_row([r[p] for p in props])
        print(t)


def run_server(args):
    """Serve the web UI and, unless disabled, open it in a browser.

    Exports the Django settings module name and the JSON-encoded database
    settings through the environment, starts Django's ``runserver`` command
    in a child process bound to ``args.host``:``args.port`` and waits for
    that process to finish.

    :param args: Parsed command-line arguments. Reads ``config_file``,
                 ``host``, ``port`` and ``nobrowser``.
    """
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "matgendb.webui.settings")
    os.environ["MGDB_CONFIG"] = json.dumps(get_settings(args.config_file))
    # Django is imported only after the environment has been prepared.
    from django.core.management import call_command

    address = "{}:{}".format(args.host, args.port)
    server = multiprocessing.Process(target=call_command,
                                     args=("runserver", address))
    server.start()
    if not args.nobrowser:
        # Short pause before pointing the browser at the server.
        time.sleep(2)
        webbrowser.open("http://" + address)
    server.join()


class StatsCommands(object):
    """Encapsulate statistics commands.

    Every statistic is implemented by a ``stat_<name>`` method returning
    formatted text; :meth:`stat` dispatches to it by name.
    """
    #: Constants for names of format options
    FORMAT_YAML, FORMAT_CLEAN = "yaml", "simple"

    def __init__(self, query_engine, output_format=None, latest_prop=None, **extra_kw):
        """Constructor.

        :param query_engine: Connection to Mongo
        :type query_engine: matgendb.query_engine.QueryEngine
        :param output_format: FORMAT_YAML or FORMAT_CLEAN. ``None`` selects
                              FORMAT_YAML.
        :param latest_prop: Name of the field used by :meth:`stat_latest`.
        :param extra_kw: Accepted and ignored, so a whole argument namespace
                         can be passed as keywords.
        :raises ValueError: If `output_format` is not a known format.
        """
        self._qe = query_engine
        self._db, self._coll = query_engine.db, query_engine.collection
        self.indent = ' ' * 4
        # Keywords
        self._latest = latest_prop
        if output_format is None or output_format == self.FORMAT_YAML:
            self._format = self.yaml_format
            self._show_header = True
        elif output_format == self.FORMAT_CLEAN:
            self._format = self.clean_format
            self._show_header = False
        else:
            # Fail here rather than with an AttributeError on first use.
            raise ValueError(
                "Unknown output format '{}' (expected '{}' or '{}')".format(
                    output_format, self.FORMAT_YAML, self.FORMAT_CLEAN))

    def header(self):
        """Return the database/collection header, or "" if the format has none."""
        if not self._show_header:
            return ""
        s = self._format((("Database", self._db.name),
                          ("Collection", self._coll.name),
                          ("Values", "")))
        return s

    def stat(self, name):
        """Run the statistic called `name` and return its formatted text.

        :raises AttributeError: If there is no ``stat_<name>`` method.
        """
        func = 'stat_{}'.format(name)
        return getattr(self, func)()

    def stat_latest(self):
        """Return the largest value of the 'latest' property, formatted.

        Logs an error and returns "" when no value can be read.
        """
        cur = self._coll.find({}, [self._latest])\
                        .sort(self._latest, -1)\
                        .limit(1)
        try:
            value = "{}".format(cur[0][self._latest])
        except (IndexError, KeyError):
            # IndexError: nothing was returned. KeyError: the first record
            # returned does not carry the field.
            _log.error("Cannot get latest value: no records with field '{}'"
                       .format(self._latest))
            return ""
        return self._format([("Latest", value)], depth=1)

    def stat_count(self):
        """Return the number of records in the collection, formatted."""
        value = "{:d}".format(self._coll.count())
        return self._format([("Count", value)], depth=1)

    def yaml_format(self, data, depth=0):
        """Format (key, value) pairs as ``key: value`` lines.

        Each line is indented by `depth` levels of ``self.indent``.
        """
        rows, ind = [], self.indent * depth
        for k, v in data:
            rows.append("{i}{k}: {v}".format(i=ind, k=k, v=v))
        return '\n'.join(rows) + "\n"

    def clean_format(self, data, depth=0):
        """Format (key, value) pairs as ``key value`` lines with lowercased keys.

        `depth` is accepted for signature parity with :meth:`yaml_format`
        and is not used.
        """
        rows = ["{} {}".format(d[0].lower(), d[1]) for d in data]
        return '\n'.join(rows) + "\n"


def db_stats(args):
    """Print database and/or collection statistics to standard output.

    Attributes of `args` whose names start with ``st_`` select statistics
    (the prefix is stripped); all remaining attributes are forwarded as
    keywords to StatsCommands. The ``showall`` selection switches every
    statistic on.

    :param args: Parsed command-line arguments of the 'stats' subcommand.
    """
    # Connect to the database.
    settings = get_settings(args.config_file)
    normalize_userpass(settings)
    engine = QueryEngine(**settings)

    # Separate statistic switches from the other keywords.
    selected, keywords = {}, {}
    for name, value in vars(args).items():
        if name.startswith("st_"):
            selected[name[3:]] = value
        else:
            keywords[name] = value

    # 'Show all' enables every statistic and is not a statistic itself.
    if selected['showall']:
        for name in list(selected):
            selected[name] = True
    del selected['showall']

    # Write the header, then each enabled statistic.
    out = sys.stdout
    commands = StatsCommands(engine, **keywords)
    out.write(commands.header())
    for name in selected:
        if not selected[name]:
            continue
        out.write(commands.stat(name))


def normalize_userpass(cfg):
    """Collapse prefixed credentials in a DB connection config, in place.

    For the 'readonly' and then the 'admin' prefix: when both
    ``<prefix>_user`` and ``<prefix>_password`` are present, they are removed
    from `cfg` and their values are stored under QueryEngine.USER_KEY and
    QueryEngine.PASSWORD_KEY. 'admin' is handled last, so its values win when
    both pairs are present.
    """
    # Lowest priority first, so that a later prefix overwrites an earlier one.
    for prefix in ('readonly', 'admin'):
        user_field = prefix + '_user'
        password_field = prefix + '_password'
        if user_field not in cfg or password_field not in cfg:
            continue
        cfg[QueryEngine.USER_KEY] = cfg.pop(user_field)
        cfg[QueryEngine.PASSWORD_KEY] = cfg.pop(password_field)


def init_logging(args):
    """Initialize verbosity of the 'mg' logger from the command-line flags.

    Adds a stream handler to the logger, stops propagation to the root
    logger and sets the level: CRITICAL when ``quiet`` is set, otherwise
    WARN, INFO or DEBUG for a verbosity count of 0, 1 or 2 and above.

    :param args: Parsed command-line arguments. ``quiet`` and ``vb`` are
                 optional and default to False and 0, because not every
                 subcommand defines the verbosity options.
    """
    _log.propagate = False
    hndlr = logging.StreamHandler()
    hndlr.setFormatter(logging.Formatter("[%(levelname)-6s] %(asctime)s %(name)s :: %(message)s"))
    _log.addHandler(hndlr)
    quiet = getattr(args, 'quiet', False)
    verbosity = getattr(args, 'vb', 0) or 0
    if quiet:
        lvl = logging.CRITICAL
    else:
        # Level:  default      -v            -vv
        lvl = (logging.WARN, logging.INFO, logging.DEBUG)[min(verbosity, 2)]
    _log.setLevel(lvl)


if __name__ == "__main__":
    # Command-line entry point: build one sub-parser per subcommand, parse
    # the arguments, set up logging and call the function that the chosen
    # subcommand registered through set_defaults(func=...).
    parser = argparse.ArgumentParser(description="""
    mgdb is a complete command line db management script for pymatgen-db. It
    provides the facility to insert vasp runs, perform queries, and run a web
    server for exploring databases that you create. Type mgdb -h to see the
    various options.

    Author: Shyue Ping Ong
    Version: 3.0
    Last updated: Mar 23 2013""")

    # Parents for all subparsers.
    parent_vb = argparse.ArgumentParser(add_help=False)
    parent_vb.add_argument('--quiet', '-q', dest='quiet', action="store_true", default=False,
                           help="Minimal verbosity.")
    # The default level applied by init_logging() is WARN, not ERROR.
    parent_vb.add_argument('--verbose', '-v', dest='vb', action="count", default=0,
                           help="Print more verbose messages to standard error. Repeatable. (default=WARN)")
    parent_cfg = argparse.ArgumentParser(add_help=False)
    parent_cfg.add_argument("-c", "--config", dest="config_file", type=str,
                            help="Config file to use. Generate one using mgdb "
                            "init --config filename.json if necessary. "
                            "Otherwise, the code searches for a db.json. If "
                            "none is found, an no-authentication "
                            "localhost:27017/vasp database and tasks "
                            "collection is assumed.")

    # Init for all subparsers.
    subparsers = parser.add_subparsers()

    # The 'init' subcommand.
    pinit = subparsers.add_parser("init", help="Initialization tools.", parents=[parent_vb])
    pinit.add_argument("-c", "--config", dest="config_file", type=str,
                       nargs='?', default="db.json",
                       help="Creates an db config file for the database. "
                            "Default filename is db.json.")
    pinit.set_defaults(func=init_db)

    # The 'optimize' subcommand. It needs the verbosity options as well,
    # because init_logging() runs for every subcommand.
    popt = subparsers.add_parser("optimize", help="Optimization tools.",
                                 parents=[parent_vb])
    popt.add_argument("-c", "--config", dest="config_file", type=str,
                      nargs='?', default="db.json",
                      help="Config file of the database to optimize. "
                           "Default filename is db.json.")
    popt.set_defaults(func=optimize_indexes)

    # The 'insert' subcommand.
    pinsert = subparsers.add_parser("insert", help="Insert vasp runs.",
                                    parents=[parent_vb, parent_cfg])
    # nargs='?' makes the positional optional, so the default is reachable.
    pinsert.add_argument("directory", metavar="directory", type=str,
                         nargs='?', default=".",
                         help="Root directory for runs.")
    pinsert.add_argument("-l", "--logfile", dest="logfile", type=str,
                         help="File to log db insertion. Defaults to stdout.")
    # 'append' keeps "-t x" producing ['x'] while allowing "-t x -t y".
    pinsert.add_argument("-t", "--tag", dest="tag", type=str, action="append",
                         default=[],
                         help="Tag your runs for easier search."
                              " Repeat the option to give multiple tags.")
    pinsert.add_argument("-f", "--force", dest="force_update_dupes",
                         action="store_true",
                         help="Force update duplicates. This forces the "
                              "analyzer to reanalyze already inserted data.")
    pinsert.add_argument("-d", "--parse_dos", dest="parse_dos",
                         action="store_true",
                         help="Whether to parse the dos.")
    pinsert.add_argument("-a", "--author", dest="author", type=str, nargs=1,
                         default=None,
                         help="Enter a *unique* author field so that you can "
                              "trace back what you ran.")
    pinsert.add_argument("-n", "--ncpus", dest="ncpus", type=int,
                         default=None,
                         help="Number of CPUs to use in inserting. If "
                              "not specified, multiprocessing will use "
                              "the number of cpus detected.")
    pinsert.set_defaults(func=update_db)

    # The 'query' subcommand.
    pquery = subparsers.add_parser("query",
                                   help="Query tools. Requires the "
                                        "use of pretty_table.",
                                   parents=[parent_vb, parent_cfg])
    pquery.add_argument("--crit", dest="criteria", type=str, default=None,
                        help="Query criteria in typical json format. E.g., "
                             "{\"task_id\": 1}.")
    pquery.add_argument("--props", dest="properties", type=str, default=[],
                        nargs='+', required=True,
                        help="Desired properties. Repeatable. E.g., pretty_formula, "
                             "task_id, energy...")
    pquery.add_argument("--dump", dest="dump_json", action='store_true', default=False,
                        help="Simply dump results to JSON instead of a tabular view")
    pquery.set_defaults(func=query_db)

    # The 'stats' subcommand. Destinations prefixed with 'st_' name the
    # statistics that db_stats() can run.
    pstats = subparsers.add_parser("stats", help="Database and/or collection statistics.",
                                   parents=[parent_vb, parent_cfg])
    pstats.add_argument("--count", help="Show count of records", dest="st_count", action="store_true")
    pstats.add_argument("--latest", help="Show time of latest record", dest="st_latest", action="store_true")
    #pstats.add_argument("--size", help="Show min/max/avg/total record sizes", dest="", action="store_true")
    pstats.add_argument("--all", help="Show all stats", dest="st_showall", action="store_true")
    pstats.add_argument("--latest-prop", help="Property to use for latest record (default='updated_at')",
                        default="updated_at", dest="latest_prop")
    pstats.add_argument("-f", "--format", help="Output format", choices=[StatsCommands.FORMAT_YAML,
                                                                         StatsCommands.FORMAT_CLEAN],
                        default="yaml", dest="output_format")
    pstats.set_defaults(func=db_stats)

    # The 'runserver' subcommand.
    pserve = subparsers.add_parser("runserver",
                                   help="Run a server to the database.",
                                   parents=[parent_vb, parent_cfg])
    pserve.add_argument("--nobrowser", dest="nobrowser", action="store_true",
                        help="Don't automatically open a browser window "
                             "(e.g. for shell-based servers)")
    pserve.add_argument("--port", dest="port", type=int, default=8000,
                        help="Port to run the server on (default: 8000)")
    pserve.add_argument("--host", dest="host", type=str, default="127.0.0.1",
                        help="Host to run the server on (default: 127.0.0.1)")
    pserve.set_defaults(func=run_server)

    # Parse args
    args = parser.parse_args()

    init_logging(args)

    # Run appropriate subparser function.
    args.func(args)
