
"""
The global state for interpreted predicates. It stores tool wrappers,
but also which goals have been interpreted.  Cut down etb3 version
"""
import os, threading, sys, traceback
import terms, parser, wrapper
import logging, inspect


from datalog import model

class InterpretState(object):
    """
    The state for interpretation of predicates.

    It depends on an ETB instance to add claims after a predicate has
    been interpreted, and because some tool wrappers may need
    this ETB instance.

    The object is also a context manager: ``with state:`` holds an
    internal re-entrant lock, which guards ``_handlers`` and ``results``.
    """

    def __init__(self, etb):
        "initializes the state"
        self.log = logging.getLogger('etb.interpret_state')
        self.etb = etb

        # predicate symbol -> callable that interprets goals on it
        self._handlers = {}
        # NOTE(review): not read or written anywhere in this class; kept in
        # case other modules rely on the attribute.
        self._being_interpreted = {}

        # goals already interpreted somewhere -> list of resulting claims
        # (only non-volatile goals with at least one claim are stored)
        self.results = {}

        # concurrency mechanisms
        self._rlock = threading.RLock()

    def __enter__(self):
        self._rlock.acquire()
        return self

    def __exit__(self, t, v, tb):
        self._rlock.release()
        # never swallow the exception raised inside the ``with`` body
        return False

    def __repr__(self):
        with self:
            lines = ['    ' + str(k) + ('(*)' if k.is_volatile() else '')
                     for k in self._handlers]
        return '  Interpreted predicates:\n' + '\n'.join(lines)

    def _import_wrapper(self, wrapper_name):
        """Import the module `wrapper_name` and call its ``register``
        function with the ETB instance.

        Errors are logged and otherwise ignored, so that one broken wrapper
        does not prevent the remaining ones from loading.
        """
        try:
            mod = __import__(wrapper_name, fromlist=['register'])
            self.log.info("register tool wrapper {0}".format(wrapper_name))
            # the module registers its predicates against the ETB instance
            mod.register(self.etb)
        except Exception as e:
            self.log.error("error while importing wrapper {0}: {1}".format(
                wrapper_name, e))
            self.log.debug(traceback.format_exc())

    def _get_handler(self, goal, default=None):
        """Get the handler that should be used to interpret
        the given goal (looked up by its first symbol)."""
        return self._handlers.get(goal.first_symbol(), default)

    def _validate_fileref(self, arg):
        """Turn a ground file-reference term into a ``{'file', 'sha1'}`` dict.

        If the referenced file is not local (according to ``etb.git``), it is
        fetched via ``etb.get_file_from_somewhere``. Non-ground terms are
        returned unchanged.

        Raises whatever error occurred while reading the term's ``file`` /
        ``sha1`` arguments, after logging it.
        """
        if not arg.is_ground():
            return arg
        try:
            args = arg.get_args()
            fileref = {'file': str(args[terms.mk_term('file')]),
                       'sha1': str(args[terms.mk_term('sha1')])}
        except Exception as e:
            self.log.error('e: %s' % e)
            self.log.error('tb: %s' % traceback.format_exc(None))
            self.log.error('Invalid file reference: %s' % arg)
            # bare raise keeps the original traceback (``raise e`` loses it
            # on Python 2)
            raise
        if not self.etb.git.is_local(fileref):
            self.etb.get_file_from_somewhere(fileref)
        return fileref

    def _validate_handle(self, arg):
        """Turn a ground handle term into a dict with the keys ``etb``,
        ``tool``, ``session`` and ``timestamp``.

        Non-ground terms are returned unchanged. Lookup errors are logged
        and re-raised.
        """
        if not arg.is_ground():
            return arg
        try:
            fields = arg.get_args()
            handle = {'etb': str(fields[terms.mk_term('etb')]),
                      'tool': str(fields[terms.mk_term('tool')]),
                      'session': fields[terms.mk_term('session')],
                      'timestamp': fields[terms.mk_term('timestamp')]}
        except Exception:
            self.log.error('Invalid handle: %s' % arg)
            raise
        return handle

    def is_valid(self, goal):
        """Return False if `goal` is a locally interpreted predicate whose
        arguments fail `_validate_args`; True otherwise (including for
        ``None`` and for predicates not handled here)."""
        if goal is None:
            return True
        try:
            if str(goal.first_symbol().val) in self.predicates():
                self._validate_args(goal)
        except Exception as e:
            self.log.error('Goal %s is interpreted but not valid with error %s' % (goal, e))
            return False
        return True

    def _validate_args(self, goal):
        """Check the goal's arguments against the handler's argspec and
        convert them to the values passed to the handler.

        ``file`` arguments become fileref dicts, ``files`` arguments become
        lists of fileref dicts, ``handle`` arguments become handle dicts;
        anything else is passed through.

        Raises:
            Exception: wrong number of arguments.
            AssertionError: a ``+`` (input) argument is a variable.
        """
        pred = goal.first_symbol()
        argspecs = wrapper.ArgSpec.parse(self.predicates()[str(pred.val)])
        goal_args = goal.get_args()

        if len(argspecs) != len(goal_args):
            self.log.error('Invalid number of arguments')
            raise Exception('Invalid number of arguments to %s' % pred)

        args = []
        for argspec, arg in zip(argspecs, goal_args):
            if argspec.mode == '+' and arg.is_var():
                error_string = 'Argument %s should be a value, not %s; validate_args unhappy' % (argspec, arg)
                self.log.debug(error_string)
                # explicit raise: an ``assert`` would be stripped under -O
                raise AssertionError(error_string)

            if argspec.kind == 'file':
                args.append(self._validate_fileref(arg))
            elif argspec.kind == 'files':
                if arg.is_ground():
                    arg = [self._validate_fileref(a) for a in arg.get_args()]
                args.append(arg)
            elif argspec.kind == 'handle':
                args.append(self._validate_handle(arg))
            else:
                args.append(arg)

        return args

    def _interpret(self, goal, handler):
        """
        Actually interpret the goal using the handler. Results are
        added to the engine. The goal will be interpreted only if
        it has not yet been interpreted (volatile goals always are).

        If the handler raises or returns ``None``, the engine is told that
        the goal has no solutions.

        Raises:
            Exception: the handler returned something other than a
                `wrapper.Result` or a list.
        """
        self.log.info('Interpreter: %s', goal.first_symbol())

        with self:
            if self.has_been_interpreted(goal):
                # Cached answer: do not run the tool again.
                self.add_results(goal, tuple(self.results[goal]))
                return

        try:
            args = self._validate_args(goal)
        except Exception as e:
            self.log.info('e = %s' % e)
            # NOTE(review): add_results ignores empty claim lists, so this is
            # a no-op and the engine is not notified here.
            self.add_results(goal, [])
            return

        try:
            self.log.info('Calling %s', goal.first_symbol())
            output = handler(*args)
        except Exception as e:
            self.log.error('While interpreting {0}, error {1}'.format(goal, e))
            self.log.debug(traceback.format_exc())
            # Previously a dict was built here that no branch below could
            # handle, so it always ended in a misleading exception.
            self.etb.engine.push_no_solutions(goal)
            return

        if output is None:
            self.log.error('Nothing returned for goal {0}'.format(goal))
            self.log.error('Did you forget to return an output?')
            self.etb.engine.push_no_solutions(goal)
            return

        self.log.info('_interpret: {0}'.format(output))

        if isinstance(output, wrapper.Result):
            self._handle_output_new_api(goal, output)
        elif isinstance(output, list):
            self._process_output(goal, output)
        else:
            self.log.error('ETB3 wrappers only return lists of substitutions')
            raise Exception('ETB3 wrappers only return lists of substitutions')

    def _handle_output_new_api(self, goal, output):
        """Feed a `wrapper.Result` produced for `goal` to the engine.

        Pending rules, if any, are handed to the engine and nothing is
        cached. Otherwise:
        - with no claims, the engine is told there are no solutions;
        - for a `wrapper.Errors` result, the claims are added as errors and
          the engine is told there are no solutions;
        - for any other result, the claims are added.
        The claims are then cached via `add_results`.
        """
        self.log.info('_handle_output_new_api: output {0}'.format(output))
        rules = output.get_pending_rules(goal)
        self.log.info('_handle_output_new_api: rules {0}'.format(rules))
        if rules:
            for r in rules:
                self.log.info('Adding new rule: %s', r)
                self.etb.engine.add_pending_rule(r, goal)
            self.add_results(goal, [])
            return

        claims = output.get_claims(goal)
        self.log.info('_handle_output_new_api: claims {0}'.format(claims))
        if not claims:
            self.etb.engine.push_no_solutions(goal)
        else:
            self.log.info('_handle_output_new_api: adding claims {0}'.format(claims))
            if isinstance(output, wrapper.Errors):
                self.etb.engine.add_errors(goal, claims)
                self.etb.engine.push_no_solutions(goal)
            else:
                self.etb.engine.add_claims(claims)

        self.add_results(goal, claims)

    def _process_output(self, goal, output):
        """Feed a list-style wrapper output for `goal` to the engine.

        Each item is handled by type:
        - a `terms.Claim` whose predicate is ``error`` is added as an error;
        - any other `terms.Claim` is added as a claim;
        - a dict is treated as a substitution: it is applied to `goal` and
          the resulting fact is added with an external explanation.
        Other items are logged and ignored. An empty list means no
        solutions.
        """
        self.log.info('_process_output: goal = %s output = %s' % (goal, output))
        if not output:
            self.etb.engine.push_no_solutions(goal)
            return

        for obj in output:
            self.log.info('obj: %s of type %s' % (obj, type(obj)))
            if isinstance(obj, terms.Claim):
                if obj.literal.get_pred() == terms.mk_const('error'):
                    self.etb.engine.add_errors(goal, [obj])
                else:
                    self.etb.engine.add_claim(terms.Claim(obj.literal, obj.reason))
            elif isinstance(obj, dict):
                fact = terms.Subst(obj)(goal)
                self.log.info('fact: {0}'.format(fact))
                # we add the ground goal to the claims of the engine
                claim = terms.Claim(fact, model.create_external_explanation())
                self.etb.engine.add_claim(claim)
            else:
                self.log.warning('Ignoring unexpected output item %s for goal %s'
                                 % (obj, goal))

    #  --------- API -------

    def reset(self):
        """Reset the state of the component (forget all cached results)."""
        with self:
            self.results.clear()

    # result cacheing
    def add_results(self, goal, claims):
        """
        Update the state: goal -> claims is asserted (maybe from a
        remote node). The claims are stored as a list only when the goal's
        predicate is not volatile and `claims` is non-empty. Any previously
        stored claims for `goal` are replaced.
        """
        if not goal.first_symbol().is_volatile() and len(claims) > 0:
            with self:
                self.results[goal] = list(claims)

    def add_tool(self, tool):
        """Register every member of `tool` that carries an ``_argspec``.

        The predicate name is the member's ``_predicate_name`` attribute if
        set, else the member name. Members with a true ``_volatile``
        attribute have their predicate symbol marked volatile.
        """
        for name, obj in inspect.getmembers(tool):
            if not getattr(obj, '_argspec', False):
                continue
            symbol = terms.mk_term(getattr(obj, '_predicate_name', name))
            self.set_handler(symbol, obj)
            if getattr(obj, '_volatile', False):
                symbol.set_volatile()

    def set_handler(self, symbol, handler):
        """Add a handler for the given symbol.

        Raises:
            AssertionError: the symbol is not a constant, or it already has
                a handler (raised explicitly so ``-O`` cannot disable it).
        """
        if not symbol.is_const():
            raise AssertionError('The symbol %s is not constant; set_handler unhappy' % symbol)
        with self:
            if symbol in self._handlers:
                raise AssertionError('symbol %s already defined, set_handler unhappy' % symbol)
            self._handlers[symbol] = handler
            self.log.debug('  predicate %s now interpreted by \'%s\'',
                           symbol, handler)

    def load_wrappers(self, wrapper_dir):
        """
        Load all wrappers from the given directory, and add the
        handlers they contain to self.

        We try to load all .py files (in sorted order), except those
        starting with '__'. The directory is added to ``sys.path`` (once)
        so the modules can be imported by name.

        .. todo::
            Not loading rules yet!

        """
        self.log.info("Loading wrappers from {0}".format(wrapper_dir))
        if not os.path.isdir(wrapper_dir):
            self.log.error('Not a valid directory: %s' % wrapper_dir)
            return
        if wrapper_dir not in sys.path:
            sys.path.append(wrapper_dir)
        for f in sorted(os.listdir(wrapper_dir)):
            if f.startswith('__') or not f.endswith('.py'):
                continue
            self._import_wrapper(os.path.splitext(f)[0])

    def is_interpreted(self, goal):
        """Checks whether the goal is interpreted, locally or by a
        neighbor/link known to the networking layer."""
        self.log.debug('is_interpreted: %s', goal)
        pred = goal.first_symbol()
        return pred in self._handlers or \
            self.etb.networking.neighbors_able_to_interpret(pred) or \
            self.etb.networking.links_able_to_interpret(pred)

    def predicates(self):
        """Dict of predicate name -> argspec of the predicate"""
        with self:
            return dict((str(k.val), v._argspec)
                        for k, v in self._handlers.items())

    def has_been_interpreted(self, goal):
        """Checks whether this goal has already been interpreted
        (always False for volatile predicates)."""
        if goal.first_symbol().is_volatile():
            return False
        return goal in self.results

    def handler_is_async(self, goal):
        """True unless the goal's handler sets ``_async`` to a false value
        (also True when there is no local handler)."""
        return getattr(self._get_handler(goal), '_async', True)

    def interpret_goal_somewhere(self, goal, engine):
        """
        The engine doesn't know about the etb node, so we pass on the request for it.
        The engine argument is ignored.
        """
        self.etb.interpret_goal_somewhere(goal)

    def interpret(self, goal, sync=False):
        """Schedule goal to be interpreted, either now or later.

        Async handlers run as a task on ``etb.long_pool``; others run
        immediately in the calling thread.

        NOTE(review): `sync` is currently ignored; the handler's ``_async``
        attribute alone decides.
        """
        self.log.debug('interpreted: %s', goal)
        handler = self._get_handler(goal)
        with self:
            if self.has_been_interpreted(goal):
                self.add_results(goal, tuple(self.results[goal]))
                return
        if self.handler_is_async(goal):
            def task(etb, goal=goal, handler=handler):
                self._interpret(goal, handler)
            self.etb.long_pool.schedule(task)
        else:
            self._interpret(goal, handler)

    def interpreted_predicates(self):
        """Fresh list of which predicates (symbols) are interpreted
        by this component."""
        with self:
            return list(self._handlers)

