#!/usr/bin/python

import sys
sys.path.append('/home/sjk648/bin')
import argparse
import os.path
import itertools
from qmpy import *
from thermopy import PhaseSpace
#from qmpy.utils.query import query

def main():
    '''Handles command line args, parses them to understand what sub-scripts to
    run and handles displaying the results.'''

    parser = argparse.ArgumentParser(prog='oqmd',
            formatter_class=argparse.RawTextHelpFormatter,
            description='''
Description goes here.
''')
    parser.add_argument('module',default='query')
    parser.add_argument('--task','-T',
            nargs='*', default='',
            help='')
    parser.add_argument('--formula','-F',
            nargs='*', default='',
            help='')
    parser.add_argument('--search','-S',
            nargs='*', default='',
            help='')
    parser.add_argument('--args', '-A',
            nargs='*', default='',
            help='')
    parser.add_argument('--include', '-I',
            nargs='*', default='',
            help='')
    parser.add_argument('--kwargs','-K',
            nargs='*', default='',
            help='')
    parser.add_argument('--file',
            nargs='*', default='',
            help='')
    parser.add_argument('--data',
            nargs='*', default='',
            help='')

    runner = parser.parse_args(sys.argv[1:])
    if not runner.task:
        if runner.module == 'gclp':
            runner.task = ['single_point']
        else:
            return parser.parse_args(['--help'])

    kwargs = {}
    for k,v in [ kvpair.split('=') for kvpair in runner.kwargs] :
        kwargs[k] = assign_type(v)
    runner.kwargs = kwargs

    args = []
    for arg in runner.args:
        args.append(assign_type(arg))
    runner.args = args

    include = []
    for arg in runner.include:
        include.append(arg)
    if not include:
        include = ['oqmd']

#==============================================================================#
# 
#                       module: config
#
#==============================================================================#
    if runner.module == 'config':

        #======================================================================#
        if runner.task[0] == 'add_host':
            host = {}
            host['name'] = raw_input('Hostname:')
            if Host.objects.filter(name=host['name']).exists():
                print 'Host by that name already exists!'
                exit(-1)
            host['ip_address'] = raw_input('IP Address:')
            if Host.objects.filter(ip_address=host['ip_address']).exists():
                print 'Host at that address already exists!'
                exit(-1)
            host['ppn'] = raw_input('Processors per node:')
            host['nodes'] = raw_input('Max nodes to run on:')
            host['sub_script'] = raw_inputs('Command to submit a script '+
                    '(e.g. /usr/local/bin/qsub):')
            host['check_queue'] = raw_input('Command for showq (e.g.'+
                    '/usr/local/maui/bin/showq):')
            host['sub_text'] = raw_input('Path to qfile template:')
            h = Host(**host)
            h.save()

        #======================================================================#
        if runner.task[0] == 'add_user':
            user = {}
            user['name'] = raw_input('OQMD Username')
            if User.objects.filter(name=user['name']).exists():
                print 'User by that name already exists!'
                exit(-1)
            user = User(**user)
            user.save()
            print 'Add accounts now, if you don\'t have an account on a given'
            print 'host, leave blank.'
            for h in Host.objects.all():
                name = raw_input('On %s, you username is:' % h.name)
                if not name:
                    continue
                else:
                    p = raw_input('On %s, you want to do calculations at '+
                            '(e.g. /home/sjk648/auto_run)')
                    acct = Account(user=user, username=name, run_path=p)
                    acct.save()

        #======================================================================#
        if runner.task[0] == 'add_allocation':
            name = raw_input('Name your allocation:')
            if Allocation.objects.filter(name=name).exists():
                print 'Allocation by that name already exists!'
                exit(-1)
            host = raw_input('Which cluster is this allocation on?')
            if not Host.objects.filter(name=host).exists():
                print "This host doesn't exist!"
                exit(-1)
            host = Host.objects.get(name=host)
            alloc = Allocation(name=name, host=host)
            alloc.save()
            print 'Now we will assign users to this allocation'
            for acct in Account.objects.filter(host=host):
                inc = raw_input('Can %s use this allocation? y/n [y]:' % 
                        acct.user.name )
                if inc == 'y' or inc == '':
                    alloc.users.add(acct.user)
            print 'If this allocation requires a special password, enter',
            key = raw_input('it now:')
            alloc.key=key
            alloc.save()

        #======================================================================#
        if runner.task[0] == 'add_project':
            name = raw_input('Name your project: ')
            if Project.objects.filter(name=name).exists():
                print 'Project by that name already exists!'
                exit(-1)
            proj = Project(name=name)
            proj.save()
            proj.priority = raw_input('Project priority (1-100): ')
            users = raw_input('List project users (e.g. sjk648 jsaal531 bwm291): ')
            for u in users.split():
                if not User.objects.filter(name=u).exists():
                    print 'User named', u, 'doesn\'t exist!'
                else:
                    proj.users.add(User.objects.get(name=u))

            alloc = raw_input('List project allocations (e.g. byrd victoria b1004): ')
            for a in alloc.split():
                if not Allocation.objects.filter(name=a).exists():
                    print 'Allocation named', a, 'doesn\'t exist!'
                else:
                    proj.allocations.add(Allocation.objects.get(name=a))

#==============================================================================#
# 
#                       module: database
#
#==============================================================================#
    if runner.module == 'database':
        if runner.task[0] == 'describe':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'add_cifs':
            cifs = os.listdir('/home/sjk648/storage/add_cifs')
            for cif in cifs:
                cifnum = cif.replace('.cif','')
                print cifnum,
                cifpath = '/home/sjk648/libraries/icsd/%s' % cifnum
                if Entry.objects.filter(path=cifpath).exists():
                    print 'already exists. It may be an update?',
                    print 'I really don\'t know...'
                else:
                    f = open('/home/sjk648/storage/add_cifs/%s' % cif)
                    if len(f.readlines()) > 300:
                        print 'is really long'
                        continue
                    print 'is being added'
                    os.system('mkdir %s &> /dev/null' % cifpath)
                    os.system('cp /home/sjk648/storage/add_cifs/%s %s/%s' % (
                        cif, cifpath, cif))
                    entry = Entry.create(cifpath+'/'+cif,
                            project='icsd')
                    entry.save()


        #======================================================================#
        if runner.task[0] == 'add_structure':
            if not kwargs.get('project', False):
                runner.kwargs['project'] = raw_input('project: ')
            if not kwargs.get('keywords', False):
                runner.kwargs['keywords'] = raw_input('keywords: ').split()
            path = os.path.abspath(runner.task[1])
            try:
                entry = Entry.create(path,
                        keywords=runner.kwargs['keywords'],
                        project=runner.kwargs['project'])
                entry.save()
            except:
                print 'Failed to add structure file: %s' % path
                print 'Please verify that it is a valid structure file'

        #======================================================================#
        if runner.task[0] == 'add_directory':
            if 'project' not in runner.kwargs:
                runner.kwargs['project'] = raw_input('project: ')
            if 'keywords' not in runner.kwargs:
                runner.kwargs['keywords'] = raw_input('keywords: ')
            path = os.path.abspath(runner.task[1])
            for file in os.listdir(path):
                try:
                    entry = Entry.create(path,
                        keywords=runner.kwargs['project'],
                        project=runner.kwargs['keywords'])
                    entry.save()
                except:
                    print 'Failed to add structure file: %s' % path
                    print 'Please verify that it is a valid structure file'

#==============================================================================#
# 
#                       module: queue
#
#==============================================================================#
    if runner.module == 'queue':

        #======================================================================#
        if runner.task[0] == 'start_host':
            h = Host.objects.filter(name=runner.task[1])
            if h.exists():
                h = h[0]
                h.state = 1
                h.save()
                print 'Starting host: %s' % h
            else:
                print 'Host', runner.task[1], 'doesn\'t exist!'

        #======================================================================#
        if runner.task[0] == 'stop_host':
            h = Host.objects.filter(name=runner.task[1])
            if h.exists():
                h = h[0]
                h.state = -2
                h.save()
                print 'Disabling host: %s' % h
            else:
                print 'Host', runner.task[1], 'doesn\'t exist!'

        #======================================================================#
        if runner.task[0] == 'running':
            if len(runner.task) == 1:
                hosts = Host.objects.all().values_list('name', flat=True)
            else:
                hosts = runner.task[1:]

            print 'Currently running:'
            for h in hosts:
                print '!'+''.ljust(78, '=')+'!'
                print h.center(80)
                print '!'+''.ljust(78, '=')+'!'
                print 'PATH'.rjust(60), 'NCPUS'.ljust(20)
                host = Host.objects.get(name=h)
                for j in host.jobs:
                    print j.entry.path.rjust(60),
                    print str(j.ncpus).ljust(20)

        #======================================================================#
        if runner.task[0] == 'utilization':
            print 'Resource utilization:'
            for h in Host.objects.all():
                print '   - %s : %s' % (h.name, h.utilization)

        #======================================================================#
        if runner.task[0] == 'detail':
            if len(runner.task) == 1:
                tasks = Task.objects.all()
            else:
                tasks = Task.objects.filter(project=runner.task[1])
            print '%s tasks for this project' % tasks.count()
            print '   - FAILED:   %s' % tasks.filter(state=-1).count()
            print '   - RUNNING:  %s' % tasks.filter(state=1).count()
            print '   - COMPLETE: %s' % tasks.filter(state=2).count()
            print 
            tasks = Task.objects.filter(project=runner.task[1])
            print '%s tasks for this project' % tasks.count()
            num = float(tasks.count())
            print 'Progress:'
            for t in tasks:
                prog = t.entry.calculation_set.filter(done=True).count()
                print t.entry.name.rjust(20), '\t', t.entry.path,
                print ' \t [' +\
                        ''.join('##' for i in range(prog) )+\
                        ''.join('  ' for i in range(4-prog) )+\
                        ']'

        #======================================================================#
        if runner.task[0] == 'progress':
            if len(runner.task) == 1:
                projects = Project.objects.all().values_list('name', flat=True)
            else:
                projects = runner.task[1:]
            
            for project in projects:
                print 'Project: %s' % project
                tasks = Task.objects.filter(project=project)
                init = Calculation.objects.filter(done=True, settings='initialize', 
                        entry__task__project=project)
                cr = Calculation.objects.filter(done=True, settings='coarse_relax', 
                        entry__task__project=project)
                fr = Calculation.objects.filter(done=True, settings='fine_relax',
                        entry__task__project=project)
                std = Calculation.objects.filter(done=True, settings='standard',
                        entry__task__project=project)

                print '  %s tasks for this project' % tasks.count()
                print '     - FAILED:   %s' % tasks.filter(state=-1).count()
                print '     - RUNNING:  %s' % tasks.filter(state=1).count()
                print '     - COMPLETE: %s' % tasks.filter(state=2).count()
                num = float(tasks.count())
                print '  Progress:'
                print '     - initialized: {0:.3f}%'.format(init.count()/num*100)
                print '     - coarse relaxed: {0:.3f}%'.format(cr.count()/num*100)
                print '     - fine_relaxed: {0:.3f}%'.format(fr.count()/num*100)
                print '     - completed: {0:.3f}%'.format(std.count()/num*100)
                print 

        #======================================================================#
        if runner.task[0] == 'failed':
            if len(runner.task) == 1:
                tasks = Task.objects.filter(state=-1)
            else:
                project = runner.task[1]
                tasks = Task.objects.filter(project=project, state=-1)
            for t in tasks:
                print t.entry.name.rjust(10),
                errs = set()
                for err in t.entry.errors.values():
                    errs |= set(err)
                print t.entry.path.ljust(30),
                print ', '.join(errs)

        #======================================================================#
        if runner.task[0] == 'cancel':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'check_queue':
            raise NotImplementedError

#==============================================================================#
# 
#                       module: adhoc / calculate
#
#==============================================================================#
    if runner.module == 'adhoc':

        #======================================================================#
        if runner.task[0] in [ 
                'initialize', 
                'coarse_relax', 
                'fine_relax', 
                'standard'] :
            for struct in runner.task[1:]:
                structure = Structure.create(struct)
                path = os.path.abspath(struct)
                path = os.path.dirname(path)
                calc = Calculation.do(
                        input=structure,
                        type=runner.task[0],
                        path=path+'/'+runner.task[0])
                if calc.done:
                    print 'Calculation of %s is done' % struct
                else:
                    if calc.instructions:
                        calc.write()
                        print 'Wrote calculation of %s' % struct 

        #======================================================================#
        if runner.task[0] == 'kpoints':
            c = Calculation()
            c.POSCAR = runner.task[1]
            c.kppra = int(runner.task[2])
            print c.KPOINTS

    if runner.module == 'calculate':
        for struct in runner.task:
            e = Entry.create(struct)
            e.save()
            job = e.do('standard')

#==============================================================================#
# 
#                       module: fitting
#
#==============================================================================#
    if runner.module == 'fitting':
        raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'mus':
            raise NotImplementedError

#==============================================================================#
# 
#                       module: gclp
#
#==============================================================================#
    if runner.module == 'gclp':

        #======================================================================#
        if runner.task[0] == 'vertical':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'reaction':
            for formula in runner.formula:
                s = PhaseSpace(formula+'-'+runner.task[1], load=include)
                for reaction in s.get_reactions(runner.task[1], 
                        electrons=kwargs.get('electrons', 1.0)):
                    print reaction

        #======================================================================#
        if runner.task[0] == 'reaction_plot':
            for formula in runner.formula:
                print formula, '+', runner.task[1]
                s = PhaseSpace(formula+'-'+runner.task[1], load=include)
                s.plot_reactions(runner.task[1], 
                        electrons=kwargs.get('electrons', 1.0))
                plt.show()

        #======================================================================#
        if runner.task[0] == 'meta_stability':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'compounds':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'single_point':
            for rformula in runner.formula:
                formulae = parse_formula(rformula)
                for formula in formulae:
                    print comp_to_name(formula, special='reduce'), ':'
                    tot = sum(formula.values())
                    formula = dict( (k, v/tot) for k, v in formula.items() )
                    bstring = '-'.join(formula.keys())
                    space = PhaseSpace(bstring, load=include)
                    en, phases = space.gclp(formula)
                    keys = sorted( phases.keys(), key=lambda x: -phases[x])
                    print '   Hull formation energy: %.3f eV/atom' % en
                    print '   Hull phases:', 
                    pstr = ''
                    for k in keys:
                        if phases[k] == 1:
                            pstr += k.name 
                        else:
                            pstr += '%s %s + ' % (phases[k], k.name)
                    pstr = pstr.rstrip('+ ')
                    print pstr,

                    print '(',
                    for k in keys:
                        print k.description,
                    print ')'

                    print '   Hull phases (LaTeX):',
                    pstr = ''
                    for k in keys:
                        if phases[k] == 1:
                            pstr += k.latex
                        else:
                            pstr += '%s %s + ' % (phases[k], k.latex)
                    pstr = pstr.rstrip('+ ')
                    print pstr,

                    print '(',
                    for k in keys:
                        print k.description,
                    print ')'

        #======================================================================#
        if runner.task[0] == 'graph':
            for formula in runner.formula:
                bounds = formula.split('-')
                bounds = [ parse_formula(b) for b in bounds ]
                for region in itertools.product(*bounds):
                    s = PhaseSpace(region, load=include)
                    graph_plot(s)
                plt.show()

        #======================================================================#
        if runner.task[0] == 'phase_diagram':
            unstable = ( 'unstable' in args )
            for formula in runner.formula:
                bounds = formula.split('-')
                bounds = [ parse_formula(b) for b in bounds ]
                for region in itertools.product(*bounds):
                    s = PhaseSpace(region, load=include)
                    phase_diagram(s, unstable=unstable)
                plt.show()

        #======================================================================#
        if runner.task[0] == 'phase_diagram_script':
            unstable = ( 'unstable' in args )
            for formula in runner.formula:
                bounds = formula.split('-')
                bounds = [ parse_formula(b) for b in bounds ]
                for region in itertools.product(*bounds):
                    s = PhaseSpace(region, load=include)
                    print_script(s, unstable=unstable)

#==============================================================================#
# 
#                       module: cell
#
#==============================================================================#
    if runner.module == 'cell':

        #======================================================================#
        if runner.task[0] == 'defects':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'symmetry':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'surface':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'vacancy':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'substitution':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'interstitial':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'replace':
            raise NotImplementedError

        #======================================================================#
        if runner.task[0] == 'reshape':
            raise NotImplementedError

#==============================================================================#
# 
#                       module: search
#
#==============================================================================#

    if runner.module in ['query','search']:
        for formula in runner.formula:
            #if runner.task == ['count']:
            #    print query(formula=runner.formula, search=runner.search).count()
            if runner.task == ['summary']:
                if '-' in formula:
                    comps = Composition.get_space(formula)
                    for comp in comps:
                        comp.get_distinct(calculable=False)
                        print comp.summary
                else:
                    comp = Composition.get(formula)
                    print comp.summary
            elif runner.task == ['count']:
                phases = set()
                a = PhaseSpace(load=[])
                a._data.load_oqmd(fit='standard')
                for rformula in runner.formula:
                    formulae = parse_formula(rformula)
                    for formula in formulae:
                        tot = sum(formula.values())
                        formula = dict( (k, v/tot) for k, v in formula.items() )
                        bstring = '-'.join(formula.keys())
                        space = PhaseSpace(bstring, pdata=a._data)
                        phases |= set(space.phases)
                print '\n'.join([ p.name for p in phases])

            else:
                result = query(type=runner.task[0], 
                        formula=runner.formula, 
                        search=runner.search,
                        columns=runner.task[1:])
                print result

if __name__ == '__main__':
    try:
        main()
    finally:
        # Remove the gurobi.log left in the working directory. This runs even
        # when main() exits early via sys.exit() or an exception.
        if os.path.exists('gurobi.log'):
            os.unlink('gurobi.log')
