#!/home/sjk648/virtualenvs/production/bin/python
# -*- coding: utf-8 -*-

import sys
import argparse
import os.path
import itertools
from qmpy import *
from qmpy.computing.manager import TaskManager, JobManager

def main():
    '''Handles command line args, parses them to understand what sub-scripts to
    run and handles displaying the results.'''

    parser = argparse.ArgumentParser(prog='oqmd',
            formatter_class=argparse.RawTextHelpFormatter,
            description='''
Command line access to OQMD functionality. Allows a user to obtain information
about database entries, do thermodynamic analyisis or request calculations.
''')

    parser.add_argument('module', default='query',
            help='Select the general functionality group you want to access')

    parser.add_argument('--task','-T',
            nargs='*', default='',
            help='''What do you want to do, a given task may have a different
            meaning depending on the selected module.''')

    parser.add_argument('--formula','-F',
            nargs='*', default='',
            help='''A representation of a composition or compositions. 
            Fe2O3 -> {'Fe':2, 'O':3}
            FeO -> {'Fe':1, 'O':1}
            {Fe,Ni}O -> [{'Fe':1, 'O':1}, {'Ni':1, 'O':1}]''')

    parser.add_argument('--search','-S',
            nargs='*', default='',
            help='''
            Database searching.
            ''')

    parser.add_argument('--include', '-I',
            nargs='*', default='oqmd',
            help='')
    
    parser.add_argument('--data',
            nargs='*', default='',
            help='')

    runner, args = parser.parse_known_args(sys.argv[1:])
    if not runner.task:
        if runner.module == 'gclp':
            runner.task = ['single_point']
        else:
            return parser.parse_args(['--help'])

    runner.args = args

#==============================================================================#
# module: config
#==============================================================================#
    if runner.task[0] == 'add_host':
        Host.interactive_create()
    if runner.task[0] == 'add_user':
        User.interactive_create()
    if runner.task[0] == 'add_allocation':
        Allocation.interactive_create()
    if runner.task[0] == 'add_project':
        Project.interactive_create()

#==============================================================================#
# module: database
#==============================================================================#

    if runner.task[0] == 'describe':
        raise NotImplementedError

    if runner.task[0] == 'add_cifs':
        if runner.path is None:
            runner.path = '/home/sjk648/storage/add_cifs'

        cifs = os.listdir(runner.path)
        print 'Adding cifs from', runner.path, '...'
        for cif in cifs:
            cifnum = cif.replace('.cif','')
            print '   - ', cifnum,
            cifpath = '/home/sjk648/libraries/icsd/%s' % cifnum
            if Entry.objects.filter(path=cifpath).exists():
                continue
            f = open('/home/sjk648/storage/add_cifs/%s' % cif)
            if len(f.readlines()) > 300:
                print 'is really long'
                continue
            print 'is being added'
            os.system('mkdir %s &> /dev/null' % cifpath)
            os.system('cp /home/sjk648/storage/add_cifs/%s %s/%s' % (
                cif, cifpath, cif))
            entry = Entry.create(cifpath+'/'+cif, project='icsd')
            entry.save()

    #======================================================================#
    if runner.task[0] == 'add_structure':
        if not kwargs.get('project', False):
            runner.kwargs['project'] = raw_input('project: ')
        if not kwargs.get('keywords', False):
            runner.kwargs['keywords'] = raw_input('keywords: ').split()
        path = os.path.abspath(runner.task[1])
        try:
            entry = Entry.create(path,
                    keywords=runner.kwargs['keywords'],
                    project=runner.kwargs['project'])
            entry.save()
        except:
            print 'Failed to add structure file: %s' % path
            print 'Please verify that it is a valid structure file'

    #======================================================================#
    if runner.task[0] == 'add_directory':
        if 'project' not in runner.kwargs:
            runner.kwargs['project'] = raw_input('project: ')
        if 'keywords' not in runner.kwargs:
            runner.kwargs['keywords'] = raw_input('keywords: ')
        path = os.path.abspath(runner.task[1])
        for file in os.listdir(path):
            try:
                entry = Entry.create(path,
                    keywords=runner.kwargs['project'],
                    project=runner.kwargs['keywords'])
                entry.save()
            except:
                print 'Failed to add structure file: %s' % path
                print 'Please verify that it is a valid structure file'

#==============================================================================#
# queue
#==============================================================================#
    if runner.module == 'daemon':
        ts = TaskManager('/tmp/oqmd-tasksever.pid')
        js = JobManager('/tmp/oqmd-jobsever.pid')
        if runner.task[0] == 'start':
            ts.start()
            js.start()
        elif runner.task[1] == 'stop':
            ts.stop()
            js.stop()
        elif runner.task[1] == 'restart':
            ts.restart()
            js.restart()
        else:
            print 'Unknown command'
            sys.exit(2)

    #======================================================================#
    if runner.module == 'jobserver':
        js = JobManager('/tmp/oqmd-jobsever.pid')
        getattr(js, runner.task[0])()

    #======================================================================#
    if runner.module == 'taskserver':
        ts = TaskManager('/tmp/oqmd-tasksever.pid')
        getattr(ts, runner.task[0])()

    #======================================================================#
    if runner.task[0] == 'start_host':
        h = Host.objects.filter(name=runner.task[1])
        if h.exists():
            h = h[0]
            h.state = 1
            h.save()
            print 'Starting host: %s' % h
        else:
            print 'Host', runner.task[1], 'doesn\'t exist!'

    #======================================================================#
    if runner.task[0] == 'stop_host':
        h = Host.objects.filter(name=runner.task[1])
        if h.exists():
            h = h[0]
            h.state = -2
            h.save()
            print 'Disabling host: %s' % h
        else:
            print 'Host', runner.task[1], 'doesn\'t exist!'

    #======================================================================#
    if runner.task[0] == 'running':
        if len(runner.task) == 1:
            hosts = Host.objects.all().values_list('name', flat=True)
        else:
            hosts = runner.task[1:]

        print 'Currently running:'
        for h in hosts:
            print '!'+''.ljust(78, '=')+'!'
            print h.center(80)
            print '!'+''.ljust(78, '=')+'!'
            print 'PATH'.rjust(60), 'NCPUS'.ljust(20)
            host = Host.objects.get(name=h)
            for j in host.jobs:
                print j.entry.path.rjust(60),
                print str(j.ncpus).ljust(20)

    #======================================================================#
    if runner.task[0] == 'utilization':
        print 'Resource utilization:'
        for h in Host.objects.all():
            print '   - %s : %s' % (h.name, h.utilization)

    #======================================================================#
    if runner.task[0] == 'detail':
        if len(runner.task) == 1:
            tasks = Task.objects.all()
        else:
            tasks = Task.objects.filter(project=runner.task[1])
        print '%s tasks for this project' % tasks.count()
        print '   - FAILED:   %s' % tasks.filter(state=-1).count()
        print '   - RUNNING:  %s' % tasks.filter(state=1).count()
        print '   - COMPLETE: %s' % tasks.filter(state=2).count()
        print 
        tasks = Task.objects.filter(project=runner.task[1])
        print '%s tasks for this project' % tasks.count()
        num = float(tasks.count())
        print 'Progress:'
        for t in tasks:
            prog = t.entry.calculation_set.filter(done=True).count()
            print t.entry.name.rjust(20), '\t', t.entry.path,
            print ' \t [' +\
                    ''.join('##' for i in range(prog) )+\
                    ''.join('  ' for i in range(4-prog) )+\
                    ']'

    #======================================================================#
    if runner.task[0] == 'progress':
        if len(runner.task) == 1:
            projects = Project.objects.all().values_list('name', flat=True)
        else:
            projects = runner.task[1:]
        
        for project in projects:
            print 'Project: %s' % project
            tasks = Task.objects.filter(project=project)
            init = Calculation.objects.filter(done=True, settings='initialize', 
                    entry__task__project=project)
            cr = Calculation.objects.filter(done=True, settings='coarse_relax', 
                    entry__task__project=project)
            fr = Calculation.objects.filter(done=True, settings='fine_relax',
                    entry__task__project=project)
            std = Calculation.objects.filter(done=True, settings='standard',
                    entry__task__project=project)

            print '  %s tasks for this project' % tasks.count()
            print '     - FAILED:   %s' % tasks.filter(state=-1).count()
            print '     - RUNNING:  %s' % tasks.filter(state=1).count()
            print '     - COMPLETE: %s' % tasks.filter(state=2).count()
            num = float(tasks.count())
            print '  Progress:'
            print '     - initialized: {0:.3f}%'.format(init.count()/num*100)
            print '     - coarse relaxed: {0:.3f}%'.format(cr.count()/num*100)
            print '     - fine_relaxed: {0:.3f}%'.format(fr.count()/num*100)
            print '     - completed: {0:.3f}%'.format(std.count()/num*100)
            print 

    #======================================================================#
    if runner.task[0] == 'failed':
        if len(runner.task) == 1:
            tasks = Task.objects.filter(state=-1)
        else:
            project = runner.task[1]
            tasks = Task.objects.filter(project=project, state=-1)
        for t in tasks:
            print t.entry.name.rjust(10),
            errs = set()
            for err in t.entry.errors.values():
                errs |= set(err)
            print t.entry.path.ljust(30),
            print ', '.join(errs)

    #======================================================================#
    if runner.task[0] == 'cancel':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'check_queue':
        raise NotImplementedError

#==============================================================================#
#                       module: adhoc / calculate
#==============================================================================#
    if runner.task[0] in [ 
            'initialize', 
            'coarse_relax', 
            'fine_relax', 
            'standard'] :
        for struct in runner.task[1:]:
            structure = Structure.create(struct)
            path = os.path.abspath(struct)
            path = os.path.dirname(path)
            calc = Calculation.do(
                    input=structure,
                    type=runner.task[0],
                    path=path+'/'+runner.task[0])
            if calc.done:
                print 'Calculation of %s is done' % struct
            else:
                if calc.instructions:
                    calc.write()
                    print 'Wrote calculation of %s' % struct 

    #======================================================================#
    if runner.task[0] == 'kpoints':
        c = Calculation()
        c.POSCAR = runner.task[1]
        c.kppra = int(runner.task[2])
        print c.KPOINTS

#=============================================================================
# thermodynamic analysis
#==============================================================================#
    if runner.task[0] == 'vertical':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'reaction':
        for formula in runner.formula:
            s = PhaseSpace(formula+'-'+runner.task[1], load=include)
            for reaction in s.get_reactions(runner.task[1], 
                    electrons=kwargs.get('electrons', 1.0)):
                print reaction

    #======================================================================#
    if runner.task[0] == 'reaction_plot':
        for formula in runner.formula:
            print formula, '+', runner.task[1]
            s = PhaseSpace(formula+'-'+runner.task[1], load=include)
            s.plot_reactions(runner.task[1], 
                    electrons=kwargs.get('electrons', 1.0))
            plt.show()

    #======================================================================#
    if runner.task[0] == 'meta_stability':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'compounds':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'single_point':
        for rformula in runner.formula:
            formulae = parse_formula(rformula)
            for formula in formulae:
                print comp_to_name(formula, special='reduce'), ':'
                tot = sum(formula.values())
                formula = dict( (k, v/tot) for k, v in formula.items() )
                bstring = '-'.join(formula.keys())
                space = PhaseSpace(bstring, load=include)
                en, phases = space.gclp(formula)
                keys = sorted( phases.keys(), key=lambda x: -phases[x])
                print '   Hull formation energy: %.3f eV/atom' % en
                print '   Hull phases:', 
                pstr = ''
                for k in keys:
                    if phases[k] == 1:
                        pstr += k.name 
                    else:
                        pstr += '%s %s + ' % (phases[k], k.name)
                pstr = pstr.rstrip('+ ')
                print pstr,

                print '(',
                for k in keys:
                    print k.description,
                print ')'

                print '   Hull phases (LaTeX):',
                pstr = ''
                for k in keys:
                    if phases[k] == 1:
                        pstr += k.latex
                    else:
                        pstr += '%s %s + ' % (phases[k], k.latex)
                pstr = pstr.rstrip('+ ')
                print pstr,

                print '(',
                for k in keys:
                    print k.description,
                print ')'

    #======================================================================#
    if runner.task[0] == 'graph':
        for formula in runner.formula:
            bounds = formula.split('-')
            bounds = [ parse_formula(b) for b in bounds ]
            for region in itertools.product(*bounds):
                s = PhaseSpace(region, load=include)
                graph_plot(s)
            plt.show()

    #======================================================================#
    if runner.task[0] == 'phase_diagram':
        unstable = ( 'unstable' in args )
        for formula in runner.formula:
            bounds = formula.split('-')
            bounds = [ parse_formula(b) for b in bounds ]
            for region in itertools.product(*bounds):
                s = PhaseSpace(region, load=include)
                phase_diagram(s, unstable=unstable)
            plt.show()

    #======================================================================#
    if runner.task[0] == 'phase_diagram_script':
        unstable = ( 'unstable' in args )
        for formula in runner.formula:
            bounds = formula.split('-')
            bounds = [ parse_formula(b) for b in bounds ]
            for region in itertools.product(*bounds):
                s = PhaseSpace(region, load=include)
                print_script(s, unstable=unstable)

#==============================================================================#
# structure manipulation
#==============================================================================#
    if runner.task[0] == 'defects':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'symmetry':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'surface':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'vacancy':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'substitution':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'interstitial':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'replace':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'reshape':
        raise NotImplementedError

#==============================================================================#
# composition search
#==============================================================================#

    if runner.task[0] in ['query', 'search']:
        for formula in runner.formula:
            #if runner.task == ['count']:
            #    print query(formula=runner.formula, search=runner.search).count()
            if runner.task == ['summary']:
                if '-' in formula:
                    comps = Composition.get_space(formula)
                    for comp in comps:
                        comp.get_distinct(calculable=False)
                        print comp.summary
                else:
                    comp = Composition.get(formula)
                    print comp.summary
            elif runner.task == ['count']:
                phases = set()
                a = PhaseSpace(load=[])
                a._data.load_oqmd(fit='standard')
                for rformula in runner.formula:
                    formulae = parse_formula(rformula)
                    for formula in formulae:
                        tot = sum(formula.values())
                        formula = dict( (k, v/tot) for k, v in formula.items() )
                        bstring = '-'.join(formula.keys())
                        space = PhaseSpace(bstring, pdata=a._data)
                        phases |= set(space.phases)
                print '\n'.join([ p.name for p in phases])

            else:
                result = query(type=runner.task[0], 
                        formula=runner.formula, 
                        search=runner.search,
                        columns=runner.task[1:])
                print result

if __name__ == '__main__':
    main()
    # Remove any 'gurobi.log' left in the current working directory
    # (presumably written by the Gurobi solver during the run).
    leftover_log = 'gurobi.log'
    if os.path.exists(leftover_log):
        os.remove(leftover_log)
