#!/usr/bin/env python

"""
This is a master VASP running script that uses Custodian to perform various
combinations of VASP runs (relax, static and rampU).
"""

from __future__ import division

# Module metadata; __version__ and __date__ are also shown in the CLI epilog.
__author__ = __maintainer__ = "Shyue Ping Ong"
__email__ = "ongsp@ucsd.edu"
__version__ = "0.5"
__status__ = "Beta"
__date__ = "12/31/13"

import logging
import sys

from custodian.custodian import Custodian
from custodian.vasp.jobs import VaspJob
from pymatgen.io.vaspio.vasp_input import VaspInput, Incar


def load_class(mod, name):
    """
    Import module ``mod`` and return its attribute ``name``.

    Args:
        mod (str): Dotted module path, e.g. "custodian.vasp.handlers".
        name (str): Attribute (typically a class) to fetch from that module.

    Returns:
        The object bound to ``name`` in the imported module.
    """
    # A non-empty fromlist makes __import__ return the leaf module of a
    # dotted path rather than the top-level package; level=0 = absolute.
    module = __import__(mod, globals(), locals(), [name], 0)
    target = getattr(module, name)
    return target


def do_run(args):
    """
    Build and run a Custodian-managed sequence of VASP jobs.

    Each entry of ``args.jobs`` ("relax", "static" or "rampU") becomes one
    VaspJob, run in order by a checkpointing Custodian instance.

    Args:
        args: Parsed command-line namespace. Reads ``handlers``, ``command``,
            ``jobs``, ``static_kpoint``, ``max_errors``, ``scratch`` and
            ``gzip``.

    An unsupported job type prints a message and exits with status -1
    before any file is read or any logging is configured.
    """
    supported = ("relax", "static", "rampU")
    # Validate all job types first so a typo aborts before any I/O happens.
    for job_type in args.jobs:
        if job_type not in supported:
            print("Unsupported job type: {}".format(job_type))
            sys.exit(-1)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO,
                        filename="run.log")
    logging.info("Handlers used are %s", args.handlers)
    handlers = [load_class("custodian.vasp.handlers", n)() for n in
                args.handlers]
    vasp_command = args.command.split()

    # Save the initial U/J values; rampU runs scale them from 0 up to these.
    n_ramp_u = args.jobs.count('rampU')
    ramps = 0
    if n_ramp_u:
        incar = Incar.from_file('INCAR')
        ldauu = incar['LDAUU']
        # LDAUJ is optional in an INCAR; VASP's default is 0 for each species.
        ldauj = incar.get('LDAUJ', [0] * len(ldauu))

    # The k-point grid for static runs is the current KPOINTS grid scaled by
    # args.static_kpoint. Nothing has run yet, so reading it once suffices.
    static_kpts = None
    if "static" in args.jobs:
        vinput = VaspInput.from_directory(".")
        static_kpts = [k * args.static_kpoint
                       for k in vinput["KPOINTS"].kpts[0]]

    jobs = []
    njobs = len(args.jobs)
    for i, job_type in enumerate(args.jobs):
        final = i == njobs - 1
        suffix = ".{}{}".format(job_type, i + 1)
        backup = i == 0
        copy_magmom = False
        settings = []
        if job_type == "static":
            # NOTE: the CONTCAR copy is applied even when static is the first
            # job, so a CONTCAR from an earlier run is expected to exist then.
            settings.extend([
                {"dict": "INCAR",
                 "action": {"_set": {"ISTART": 1, "NSW": 0}}},
                {"dict": "KPOINTS",
                 "action": {"_set": {"kpoints": [list(static_kpts)]}}},
                {"filename": "CONTCAR",
                 "action": {"_file_copy": {"dest": "POSCAR"}}}])
        elif i > 0:
            # Later relax/rampU jobs continue from the previous structure.
            settings.extend([
                {"dict": "INCAR",
                 "action": {"_set": {"ISTART": 1}}},
                {"filename": "CONTCAR",
                 "action": {"_file_copy": {"dest": "POSCAR"}}}])
        if job_type == "rampU":
            # Scale U/J linearly from 0 (first rampU) to the full INCAR values
            # (last rampU). A lone rampU job uses the full values instead of
            # dividing by zero.
            frac = ramps / (n_ramp_u - 1) if n_ramp_u > 1 else 1.0
            settings.append({"dict": "INCAR",
                             "action": {"_set": {
                                 "LDAUJ": [j * frac for j in ldauj],
                                 "LDAUU": [u * frac for u in ldauu]}}})
            copy_magmom = True
            ramps += 1

        jobs.append(VaspJob(vasp_command, final=final, suffix=suffix,
                            backup=backup, settings_override=settings,
                            copy_magmom=copy_magmom))

    c = Custodian(handlers, jobs, max_errors=args.max_errors,
                  scratch_dir=args.scratch, gzipped_output=args.gzip,
                  checkpoint=True)
    c.run()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="""
    run_vasp is a master script to perform various kinds of VASP runs.
    """,
                                     epilog="""
    Author: Shyue Ping Ong
    Version: {}
    Last updated: {}""".format(__version__, __date__))

    # Value-taking options below deliberately have no nargs="?": a bare flag
    # (e.g. "-c" with nothing after it) would otherwise yield None and crash
    # later in do_run instead of producing a clear argparse error.
    parser.add_argument(
        "-c", "--command", dest="command",
        default="pvasp", type=str,
        help="VASP command. Defaults to pvasp. If you are using mpirun, "
             "set this to something like \"mpirun pvasp\".")

    parser.add_argument(
        "-z", "--gzip", dest="gzip", action="store_true",
        help="Add this option to gzip the final output. Do not gzip if you "
             "are going to perform an additional static run."
    )

    parser.add_argument(
        "-s", "--scratch", dest="scratch", nargs="?",
        default=None, type=str,
        help="Scratch directory to perform run in. Specify the root scratch "
             "directory as the code will automatically create a temporary "
             "subdirectory to run the job.")

    parser.add_argument(
        "-ks", "--kpoint-static", dest="static_kpoint",
        default=1, type=int,
        help="The multiplier to use for the KPOINTS of a static run (if "
             "any). For example, setting this to 2 means that if your "
             "original run was done using a k-point grid of 2x3x3, "
             "the static run will be done with a k-point grid of 4x6x6. This "
             "defaults to 1, i.e., static runs are performed with the same "
             "k-point grid as relaxation runs."
    )

    parser.add_argument(
        "-me", "--max-errors", dest="max_errors",
        default=10, type=int,
        help="Maximum number of errors to allow before quitting")

    parser.add_argument(
        "-hd", "--handlers", dest="handlers", nargs="+",
        default=["VaspErrorHandler", "MeshSymmetryErrorHandler",
                 "UnconvergedErrorHandler", "NonConvergingErrorHandler",
                 "PotimErrorHandler", "BadVasprunXMLHandler"], type=str,
        help="The ErrorHandlers to use specified as string class names. Note "
             "that the error handlers will be initialized with no args, i.e., "
             "default args will be assumed."
    )

    # nargs="*" (rather than "+") so the default double relaxation is really
    # used when no jobs are given; with "+" argparse ignored this default.
    parser.add_argument("jobs", metavar="jobs", type=str, nargs='*',
                        default=["relax", "relax"],
                        help="Jobs to execute. Only sequences of relax, "
                             "static, and rampU are supported at the moment. "
                             "For example, \"relax relax static\" will run a "
                             "double relaxation followed by a static run. "
                             "Defaults to \"relax relax\".")

    args = parser.parse_args()
    do_run(args)