#  _________________________________________________________________________
#
#  PyUtilib: A Python utility library.
#  Copyright (c) 2008 Sandia Corporation.
#  This software is distributed under the BSD License.
#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
#  the U.S. Government retains certain rights in this software.
#  _________________________________________________________________________

# Q: When passing options, these values do not initialize the startup task.
#    Is this a bug?
# TODO: only set option values for variables that show up in a workflow's inputs
# TODO: add graceful management of exceptions
#       show the task tree, etc...

__all__ = ['Workflow']

from task import Task, EmptyTask, NoTask
import argparse
from pyutilib.misc import Options
import pprint
from collections import deque


def _collect_parser_groups(t, parser=None):
    """Attach every argument group recorded on task ``t`` to ``parser``.

    Each value in ``t._parser_group`` has its ``parser`` attribute rebound
    to ``parser`` and is then passed to ``parser.add_argument_group``.

    Args:
        t: Object exposing a ``_parser_group`` mapping of argument groups.
        parser: The ``argparse.ArgumentParser`` that receives the groups.
            It defaults to ``None`` only so that existing one-argument
            calls still bind; a parser is required as soon as ``t``
            actually has a group to attach.

    Raises:
        TypeError: If ``t`` has at least one group and no parser was given.
            (The previous body read an undefined name ``parser`` and
            failed with ``NameError`` in exactly this situation.)
    """
    for key in t._parser_group:
        if parser is None:
            raise TypeError(
                "_collect_parser_groups() requires a parser when the task "
                "defines parser groups")
        group = t._parser_group[key]
        #
        # NOTE: we are changing the properties of the group
        # instances here.  This is OK _only_ because we are
        # printing the help info and then terminating.
        #
        group.parser = parser
        parser.add_argument_group(group)

def _set_arguments(t):
    """Register the argument specs stored on task ``t`` with ``t._parser``.

    Every entry of ``t._parser_arg`` is indexed as ``(args, kwargs)`` and
    forwarded to ``t._parser.add_argument``.  An ``argparse.ArgumentError``
    raised by one registration is ignored and processing moves on to the
    next entry.
    """
    for spec in t._parser_arg:
        positional, keywords = spec[0], spec[1]
        try:
            t._parser.add_argument(*positional, **keywords)
        except argparse.ArgumentError:
            # Best effort: skip a spec the parser rejects (presumably an
            # option that has already been registered).
            continue


class Workflow(Task):
    """A task that contains and runs a network of connected tasks.

    The workflow owns two ``EmptyTask`` instances.  The start task's
    outputs feed every task input that has no upstream connection, and
    the final task's inputs collect every task output that has no
    downstream connection.  Registered tasks are kept in ``_tasks``,
    keyed by task id.
    """

    def __init__(self, id=None, name=None, parser=None):
        """Create a workflow holding only its start and final tasks.

        Args:
            id: Forwarded to ``Task.__init__``.
            name: Forwarded to ``Task.__init__``.
            parser: Accepted, but ``Task.__init__`` is always called with
                ``parser=None``.
        """
        # NOTE(review): ``parser`` is dropped here -- confirm whether it
        # should be forwarded to Task.__init__ instead of None.
        Task.__init__(self, id=id, name=name, parser=None)
        self._tasks = {}
        self._start_task = EmptyTask()
        self._final_task = EmptyTask()
        self.add(self._start_task)
        self.add(self._final_task)

    def add(self, task, loadall=True):
        """Register ``task`` and, by default, the tasks connected to it.

        Nothing happens when ``task`` carries ``NoTask.id`` or is already
        registered.  With ``loadall`` true the task's connections are
        followed recursively:

        * tasks feeding its inputs are added; an input with no feeding
          task is exposed as an optional workflow input and wired to a
          same-named output of the start task;
        * tasks reached through its output controls and outputs are
          added; an output with no connection is exposed as a workflow
          output and wired into the final task.

        Args:
            task: The task to register.
            loadall: When false, only ``task`` itself is registered.

        Raises:
            ValueError: If an unconnected output has the same name as one
                already collected by the final task.
        """
        if task.id == NoTask.id:
            return
        if task.id in self._tasks:
            return
        self._tasks[task.id] = task
        if not loadall:
            return
        #
        # Inputs: pull in upstream tasks, or bind to the start task.
        #
        for name in task.inputs:
            sources = task.inputs[name].from_tasks()
            for t in sources:
                self.add(t)
            if len(sources) > 0:
                continue

            if name not in self._start_task.outputs:
                self._start_task.outputs.declare(name)
                self.inputs.declare(name, optional=True)
            # TODO: this is a bit of a hack...
            # Remember the task's current input value, rewire the input to
            # the start task, then store that value on the workflow input.
            val = getattr(task.inputs, name).get_value()
            try:
                setattr(task.inputs, name,
                        getattr(self._start_task.outputs, name))
            except ValueError:
                # Deliberately ignored by the original author; the
                # condition that triggers it is not visible in this file.
                pass
            getattr(self.inputs, name).set_value(val)
        #
        # Output controls: pull in downstream tasks.
        #
        for name in task.output_controls:
            for c in task.output_controls[name].output_connections:
                self.add(c.to_port.task)
        #
        # Outputs: pull in downstream tasks, or bind to the final task.
        #
        for name in task.outputs:
            connections = task.outputs[name].output_connections
            if len(connections) > 0:
                for c in connections:
                    self.add(c.to_port.task)
            else:
                if name in self._final_task.inputs:
                    raise ValueError("Cannot declare a workplan with multiple output values that share the same name: %s" % name)
                self.outputs.declare(name)
                self._final_task.inputs.declare(name)
                setattr(self._final_task.inputs, name, task.outputs[name])

    def _call_init(self, *options, **kwds):
        """Seed the start task's outputs before the workflow runs.

        After ``Task._call_init``, every workflow input whose value is
        not ``None`` is copied to the same-named start-task output, which
        is then marked ready.  Keyword arguments are applied afterwards
        as start-task output values.

        Raises:
            ValueError: If a keyword does not name a start-task output.
        """
        Task._call_init(self, *options, **kwds)
        for i in self.inputs:
            val = self.inputs[i].get_value()
            if val is not None:
                self._start_task.outputs[i].set_value(val)
                self._start_task.outputs[i].set_ready()
        # Q: is this redundant???
        for key in kwds:
            if key not in self._start_task.outputs:
                raise ValueError("Cannot specify value for option %s.  Valid option names are %s" % (key, self._start_task.outputs.keys()))
            self._start_task.outputs[key].set_value(kwds[key])

    def _call_fini(self, *options, **kwds):
        """Collect the final task's input values as the workflow result.

        Each final-task input is computed, its value is stored on the
        same-named workflow output, and all workflow outputs are marked
        ready.

        Returns:
            An ``Options`` object mapping each final-task input name to
            its value.
        """
        ans = Options()
        for key in self._final_task.inputs:
            port = self._final_task.inputs[key]
            port.compute_value()
            ans[key] = port.get_value()
            getattr(self.outputs, key).set_value(ans[key])
        for key in self.outputs:
            self.outputs[key].set_ready()
        return ans

    def set_options(self, args):
        """Call ``set_options(args)`` on every task visited by ``_dfs_``."""
        self._dfs_([self._start_task.id], lambda t: t.set_options(args))

    def options(self):
        """Return the names of the start task's outputs."""
        return self._start_task.outputs.keys()

    def print_help(self):
        """Print argparse help that gathers the tasks' parser groups.

        A fresh ``ArgumentParser`` is created, each visited task's
        parser groups are attached to it, and its help is printed.
        """
        parser = argparse.ArgumentParser()

        def attach_groups(t):
            # The collector must see *this* parser, so it is a closure
            # rather than a module-level function reading a global name.
            #
            # NOTE: we are changing the properties of the group
            # instances here.  This is OK _only_ because we are
            # printing the help info and then terminating.
            for key in t._parser_group:
                t._parser_group[key].parser = parser
                parser.add_argument_group(t._parser_group[key])

        self._dfs_([self._start_task.id], attach_groups)
        parser.print_help()

    def set_arguments(self, parser=None):
        """Have every visited task register its arguments on its own parser.

        Args:
            parser: Defaults to ``self._parser`` when ``None``.
        """
        # NOTE(review): ``parser`` is resolved but never used below;
        # ``_set_arguments`` registers on each task's own ``_parser``.
        if parser is None:
            parser = self._parser
        self._dfs_([self._start_task.id], _set_arguments)

    def reset(self):
        """Call ``reset()`` on every visited task; return non-None results."""
        return self._dfs_([self._start_task.id], lambda t: t.reset())

    def execute(self):
        """Run the tasks in queue order, beginning with the start task.

        After a task runs, its successors are queued if they report
        ``ready()``; otherwise they are parked in ``waiting`` and
        re-checked after each later task has executed.
        """
        queued = set([self._start_task.id])
        queue = deque([self._start_task])
        waiting = {}
        while queue:
            task = queue.popleft()
            queued.remove(task.id)
            task()
            # Iterate over a snapshot: entries are deleted during the
            # scan, which is an error on a live dictionary view.
            for t in list(waiting.values()):
                if t.id not in queued and t.ready():
                    queue.append(t)
                    queued.add(t.id)
                    del waiting[t.id]
            for t in task.next_tasks():
                if t.id in queued:
                    continue
                if t.ready():
                    queue.append(t)
                    queued.add(t.id)
                    if t.id in waiting:
                        del waiting[t.id]
                else:
                    waiting[t.id] = t

    def __str__(self):
        """Return a header line followed by ``str()`` of each visited task."""
        return "\n".join(["Workflow %s:" % self.name]+self._dfs_([self._start_task.id], lambda t: str(t)))

    def __repr__(self):
        """Return a header, the ``Task`` repr, then each visited task's repr."""
        return "Workflow %s:\n" % self.name+Task.__repr__(self)+'\n'+"\n".join(self._dfs_([self._start_task.id], lambda t: repr(t)))

    def _dfs_(self, indices, fn, touched=None):
        """Apply ``fn`` to the tasks reachable from ``indices``.

        A task is visited only once every id in its ``prev_task_ids()``
        is either ``NoTask.id`` or already in ``touched``.  After a visit
        the walk recurses into the task's ``next_task_ids()``.

        Args:
            indices: Ids of the tasks to try to visit.
            fn: Callable invoked with each visited task.
            touched: Set of ids already visited; created on the first
                call and shared with the recursive calls.

        Returns:
            A list of the non-``None`` values returned by ``fn``, in
            visit order.
        """
        if touched is None:
            touched = set()
        ans = []
        for i in indices:
            if i in touched:
                # With this design, this condition should never be triggered
                # TODO: verify that this is an O(n) search algorithm; I think it's
                # O(n^2)
                continue        #pragma:nocover
            task = self._tasks[i]
            # Ids are compared by value, as in add().
            if not all(j == NoTask.id or j in touched
                       for j in task.prev_task_ids()):
                continue
            tmp = fn(task)
            if tmp is not None:
                ans.append(tmp)
            touched.add(i)
            ans.extend(self._dfs_(task.next_task_ids(), fn, touched))
        return ans

