#  _________________________________________________________________________
#
#  PyUtilib: A Python utility library.
#  Copyright (c) 2008 Sandia Corporation.
#  This software is distributed under the BSD License.
#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
#  the U.S. Government retains certain rights in this software.
#  _________________________________________________________________________

# Q: When passing options, these values do not initialize the startup task.
#    Is this a bug?
# TODO: only set option values for variables that show up in a workflow's inputs
# TODO: add graceful management of exceptions
#       show the task tree, etc...

__all__ = ['Workflow']

from task import Task, EmptyTask, NoTask
from options import OptionParser
from pyutilib.misc import Options
import pprint


def _collect_parser_groups(t, parser):
    """Register every option group of task ``t`` with ``parser``.

    Each group in ``t._parser_group`` is re-parented onto ``parser`` and then
    added to it with ``parser.add_option_group``.

    Args:
        t: a task exposing a ``_parser_group`` mapping of option groups.
        parser: the ``OptionParser`` that will print the combined help.

    Returns:
        None, so ``Workflow._dfs_`` collects nothing for this callback.
    """
    for key in t._parser_group:
        group = t._parser_group[key]
        #
        # NOTE: we are changing the properties of the group
        # instances here.  This is OK _only_ because this is used to
        # print the help info, after which the program is presumably
        # expected to terminate.
        #
        group.parser = parser
        parser.add_option_group(group)


class Workflow(Task):
    """A task composed of a graph of connected sub-tasks.

    The workflow owns two ``EmptyTask`` sentinels:

    * a start task, whose outputs are wired to the inputs of the contained
      tasks; every input name of a contained task is also declared as an
      optional input of the workflow;
    * a final task, whose inputs collect every task output that has no
      downstream connection; those names become outputs of the workflow.
    """

    def __init__(self, id=None, name=None):
        Task.__init__(self, id=id, name=name)
        # Maps task id -> task for every task registered via add().
        self._tasks = {}
        self._start_task = EmptyTask()
        self._final_task = EmptyTask()
        self.add(self._start_task)
        self.add(self._final_task)

    def add(self, task, loadall=True):
        """Register ``task`` and, optionally, every task connected to it.

        A task whose id equals ``NoTask.id``, or one that is already
        registered, is ignored.  When ``loadall`` is true, the tasks feeding
        ``task``'s inputs and the tasks consuming its outputs are added
        recursively.  ``task``'s inputs are also wired to the start task, and
        its unconnected outputs are wired to the final task.

        Raises:
            ValueError: if two unconnected outputs share the same name.
        """
        if task.id == NoTask.id or task.id in self._tasks:
            return
        self._tasks[task.id] = task
        if not loadall:
            return
        for name in task.inputs:
            # Recursively pull in the upstream producers of this input.
            for t in task.inputs[name].from_tasks():
                self.add(t)
            if name not in self._start_task.outputs:
                self._start_task.outputs.declare(name)
                self.inputs.declare(name, optional=True)
            # TODO: this is a bit of a hack...
            # Capture the input's current value before rewiring it to the
            # start task, then propagate it to the workflow-level input.
            val = getattr(task.inputs, name).get_value()
            try:
                setattr(task.inputs, name, getattr(self._start_task.outputs, name))
            except ValueError:
                # NOTE(review): presumably raised when this input cannot be
                # bound to the start task (e.g. it is already connected);
                # the rewiring is deliberately skipped in that case.
                pass
            # NOTE(review): when several tasks share an input name, the task
            # added last determines this workflow input's value.
            getattr(self.inputs, name).set_value(val)
        for name in task.outputs:
            if len(task.outputs[name].output_connections) > 0:
                # Recursively pull in the downstream consumers.
                for c in task.outputs[name].output_connections:
                    self.add(c.to_port.task)
            else:
                # An output with no consumer becomes a workflow output.
                if name in self._final_task.inputs:
                    raise ValueError("Cannot declare a workplan with multiple output values that share the same name: %s" % name)
                self.outputs.declare(name)
                self._final_task.inputs.declare(name)
                setattr(self._final_task.inputs, name, task.outputs[name])

    def _call_init(self, *options, **kwds):
        """Seed the start task's outputs before the workflow runs.

        The workflow's input values are copied onto the start-task outputs
        of the same name, and any keyword arguments are then applied on top.

        Raises:
            ValueError: if a keyword does not name a start-task output.
        """
        Task._call_init(self, *options, **kwds)
        for i in self.inputs:
            self._start_task.outputs[i].set_value(self.inputs[i].get_value())
        # Q: is this redundant???
        for key in kwds:
            if key not in self._start_task.outputs:
                raise ValueError("Cannot specify value for option %s.  Valid option names are %s" % (key, self._start_task.outputs.keys()))
            self._start_task.outputs[key].set_value(kwds[key])

    def _call_fini(self, *options, **kwds):
        """Collect the workflow's results from the final task.

        Each final-task input is computed, stored in the returned ``Options``
        object, and copied onto the workflow output of the same name.
        """
        ans = Options()
        for key in self._final_task.inputs:
            self._final_task.inputs[key].compute_value()
            ans[key] = self._final_task.inputs[key].get_value()
            getattr(self.outputs, key).set_value(ans[key])
        return ans

    def set_options(self, args):
        """Call ``set_options(args)`` on every task, in dependency order."""
        self._dfs_([self._start_task.id], lambda t: t.set_options(args))

    def options(self):
        """Return the names of the start task's outputs."""
        return self._start_task.outputs.keys()

    def print_help(self):
        """Print the combined option help of every task in the workflow."""
        parser = OptionParser()
        # Bind the local parser explicitly; the helper has no other way to
        # reach it.
        self._dfs_([self._start_task.id],
                   lambda t: _collect_parser_groups(t, parser))
        parser.print_help()

    def execute(self):
        """Call every task, in dependency order."""
        self._dfs_([self._start_task.id], lambda t: t.__call__())

    def __str__(self):
        return "\n".join(["Workflow:"] + self._dfs_([self._start_task.id], str))

    def __repr__(self):
        return "\n".join(["Workflow:"] + self._dfs_([self._start_task.id], repr))

    def _dfs_(self, indices, fn, touched=None):
        """Apply ``fn`` to tasks in dependency order, starting from ``indices``.

        A task is visited only once every predecessor id it reports (other
        than ``NoTask.id``) has been visited.  After a visit, its successors
        are explored recursively.  Each task is visited at most once.

        Args:
            indices: ids of the candidate tasks at this level.
            fn: callable applied to each visited task.
            touched: set of ids already visited; shared across the recursion.

        Returns:
            list of the non-None results of ``fn``, in visit order.
        """
        if touched is None:
            touched = set()
        ans = []
        for i in indices:
            if i in touched:
                # Already visited through another path (e.g. a task reachable
                # from several predecessors).
                # TODO: the predecessor check below is repeated each time a
                # task appears as a candidate; verify the overall complexity.
                continue        #pragma:nocover
            task = self._tasks[i]
            # Defer this task until all of its real predecessors have been
            # visited; presumably it is reached again via their successors.
            if any(j != NoTask.id and j not in touched
                   for j in task.prev_task_ids()):
                continue
            tmp = fn(task)
            if tmp is not None:
                ans.append(tmp)
            touched.add(i)
            # extend() instead of `ans = ans + ...` avoids re-copying the
            # accumulated results at every recursion level.
            ans.extend(self._dfs_(task.next_task_ids(), fn, touched))
        return ans

