import codecs, base64
import logging
import os, platform, threading
import uuid
import traceback
from Queue import Queue, Empty

import datalog.engine
import git_interface
import interpret_state

import networking
import terms
import utils
import wrapper


class ETBConfig(object) :
    """Configuration wrapper that fills in default values.

    The wrapped ``config`` only needs a dict-like ``get(key, default)``
    method. Every setting that is missing from it falls back to one of
    the ``DEFAULT_*`` class constants (or to ``networking.PING_FREQUENCY``
    for the cron period).
    """

    DEFAULT_PORT = 26532
    DEFAULT_WRAPPER_DIR = 'wrappers_enabled'
    DEFAULT_CONFIG_FILE = 'etb_conf.ini'
    DEFAULT_GIT_DIR = 'etb_git'
    DEFAULT_LOGIC_FILE = 'logic.json.bz2'

    def __init__(self, config):
        """Read the settings out of ``config``.

        ``rule_files`` is expected to be a comma separated string; it is
        turned into a list of absolute paths, or left as None when absent.
        """
        self.logic_file = os.path.abspath(
            config.get('logic_file', self.DEFAULT_LOGIC_FILE))

        rule_files = config.get('rule_files', None)
        if rule_files is not None:
            # Tolerate "a.pl, b.pl" and trailing commas: strip the blanks
            # around each entry and drop empty ones, which would otherwise
            # resolve to a bogus path (abspath('') is the current directory).
            names = [name.strip() for name in rule_files.split(',')]
            rule_files = [os.path.abspath(name) for name in names if name]
        self.rule_files = rule_files

        # kept relative: wrappers are imported before the node changes
        # its working directory
        self.wrappers_dir = config.get('wrappers_dir', self.DEFAULT_WRAPPER_DIR)
        self.git_dir = os.path.abspath(
            config.get('git_dir', self.DEFAULT_GIT_DIR))
        self.port = config.get('port', self.DEFAULT_PORT)
        self.cron_period = config.get('cron_period', networking.PING_FREQUENCY)

class ETB(object) :
    '''
    An ETB node.

    We have:
    - Two pools of threads for running tasks
    - an event thread
    - a thread serving xmlrpc requests

    The node maintains:
    - a logic state (claims + rules)
    - an interpret state (tool wrappers)
    - a git repository
    '''

    def __init__(self, config=None):
        """Build the node and all of its components.

        ``config`` is any object with a dict-like ``get`` method; when it
        is omitted an empty configuration (all defaults) is used.

        Side effects: changes the process working directory to the git
        directory of the node and registers ``self.stop`` with ``atexit``.
        """
        # A fresh dict per call instead of a shared mutable default.
        if config is None:
            config = {}
        self.config = ETBConfig(config)
        #logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s %(process)d/%(threadName)s: %(message)s")
        self.log = logging.getLogger('etb')

        self._rlock = threading.RLock()

        # Interpret state
        self.interpret_state = interpret_state.InterpretState(self)
        self.interpret_state.load_wrappers(self.config.wrappers_dir)

        # New engine
        self.engine = datalog.engine.Engine(self.interpret_state,
                                            self.config.logic_file)
        for rule_file in (self.config.rule_files or []):
            self.engine.load_rules(rule_file)

        # We set up the ETB filesystem and move to its root
        # be careful to do this after the wrappers have been imported,
        # because importing can only be done with relative paths
        self.git = git_interface.ETBGIT(self.config.git_dir, self.log)
        os.chdir(self.git.git_dir)

        # to run tasks periodically (also used by networking)
        self.cron = utils.CronJob(self, period=self.config.cron_period)
        # self.cron.onIteration.add_handler(self.update_done_queries)

        # different threads/thread pools
        self.short_pool = utils.ThreadPool(self)     # for quick tasks
        self.long_pool = utils.ThreadPool(self)      # for long running tasks
        # TaskWorker calls .warning() on what it is given, so it must get
        # the logger itself: the node has no such method.
        self.task_worker = TaskWorker(self.log, daemon=True)    # main etb thread

        # networking component using xml-rpc
        self.networking = networking.Networking(self, self.config.port)

        # goals we subscribed to
        self.subscriptions = set()

        # queries
        self._queries = {}
        self._done_queries = {}
        self.active_local_queries = set()
        self.active_remote_queries = {}

        import atexit
        atexit.register(self.stop)

    def stop(self):
        """Save the engine state, then stop every component in turn.

        May block if one of the threads does not stop gracefully.
        """
        log = self.log
        log.info("stop ETB instance...")

        # save state
        log.debug("save ETB state")
        self.engine.save_to_default_file()

        # stop components, in this fixed order
        shutdown_steps = (
            ("stop cron thread", 'cron'),
            ("stop short tasks pool", 'short_pool'),
            ("stop long tasks pool", 'long_pool'),
            ("stop networking", 'networking'),
            ("stop main ETB task worker", 'task_worker'),
        )
        for message, attribute in shutdown_steps:
            log.debug(message)
            getattr(self, attribute).stop()

    @property
    def id(self):
        """Identifier of this node, as held by its networking component."""
        net = self.networking
        return net.id

    def __repr__(self):
        """Short textual form showing the node id."""
        template = "ETB(id={0})"
        return template.format(self.id)

    def __enter__(self):
        """Open the lock context: acquire the node's reentrant lock.

        Returns the node itself so that ``with etb as node:`` binds
        something useful instead of None.
        """
        self._rlock.acquire()
        return self

    def __exit__(self, t, v, tb):
        """Close the lock context: release the node lock.

        Nothing is returned, so an exception raised inside the ``with``
        body keeps propagating.
        """
        node_lock = self._rlock
        node_lock.release()

    def __eq__(self, other):
        """Two nodes are equal when both are ETB instances with the same id."""
        return isinstance(other, ETB) and self.id == other.id

    def __ne__(self, other):
        """Negation of ``__eq__``.

        Python 2 does not derive ``!=`` from ``__eq__``; without this
        method ``!=`` would compare object identity.
        """
        return not self.__eq__(other)

    def __hash__(self):
        """Hash of the node id, consistent with ``__eq__``."""
        node_id = self.id
        return hash(node_id)

    def add_tool(self, tool):
        """Hand ``tool`` over to the interpret state."""
        state = self.interpret_state
        state.add_tool(tool)

    def add_rule(self, rule):
        """Hand ``rule`` over to the engine."""
        engine = self.engine
        engine.add_rule(rule)

    def create_query(self, goals):
        """
        Create a new query.

        ``goals`` is stored as the query under a freshly generated id and
        scheduled for evaluation. Returns the query id (32 hex digits).
        """
        query = goals
        # UUID.hex is available on Python 2 and 3; get_hex() is Python 2 only.
        qid = uuid.uuid4().hex
        self.log.info("create_query: %s %s" % (query, qid))
        with self:
            self._queries[qid] = query
            self.schedule_query(qid, query)
        return qid

    def create_proof_query(self, goals):
        """
        Create a new proof query.

        ``goals`` is stored as the query under a freshly generated id and
        scheduled for evaluation. Returns the query id (32 hex digits).
        """
        query = goals
        # UUID.hex is available on Python 2 and 3; get_hex() is Python 2 only.
        qid = uuid.uuid4().hex
        with self:
            self._queries[qid] = query
            self.schedule_query(qid, query)
        return qid
    
    def schedule_query(self, qid, query):
        """Queue on the long pool a task that evaluates ``query``.

        The task adds every goal of the query to the engine, then, under
        the lock of the node it is called with, moves the query from the
        pending table to the done table.
        """
        def run_query(etb, qid=qid, query=query):
            add_goal = self.engine.add_goal
            for goal in query:
                add_goal(goal)
            with etb:
                etb._done_queries[qid] = query
                etb._queries.pop(qid)

        pool = self.long_pool
        pool.schedule(run_query)

    def get_query(self, qid):
        """
        Get the query object associated with this query id.
        Pending queries are looked up first, then finished ones.
        Returns None if not found.
        """
        with self:
            query = self._queries.get(qid, None)
            # Test against None explicitly: a pending query that happens
            # to be falsy (e.g. an empty goal list) is still a hit.
            if query is None:
                query = self._done_queries.get(qid, None)
            return query

#    #old version. needs a port to etb3
#    def query_derivation(self, qid, f):
#        self.log.info("in query_derivation")
#        query = self.get_query(qid)
#        if not query:
#            return False
#        claims = list(query.answer_facts())
#        return self.engine.claims_to_dot(claims, f)

#    #old version. needs a port to etb3
#    def query_proof(self, qid, f):
#        self.log.info("in query_proof")
#        query = self.get_query(qid)
#        if not query:
#            return False
#        claims = list(query.answer_facts())
#        return self.engine.claims_to_dot(claims, f, proof=True)

    #old version. needs a port to etb3
    def fact_explanation(self, fact, f):
        """Ask the engine for the dot output of ``fact``, passing ``f`` along."""
        self.log.info("in query_explanation of etb")
        # only deal with queries that contain 1 goal
        engine = self.engine
        engine.to_dot(fact, f)

    def update_predicates(self):
        """Have the engine re-check its stuck goals."""
        engine = self.engine
        engine.check_stuck_goals()

    def clear_claims_table(self):
        """Wipe the engine and the interpret state.

        Rules, claims and already interpreted goals are all erased.
        There is no way to undo this.
        """
        self.log.info('reset ETB state')
        for component in (self.engine, self.interpret_state):
            component.reset()

    @property
    def queries(self):
        """Ids of the pending queries, as a list snapshot.

        The table is read under the node lock, like ``get_query`` does,
        and copied so the caller never holds a live view of a dict that
        worker threads modify.
        """
        with self:
            return list(self._queries.keys())

    @property
    def done_queries(self):
        """Ids of the finished queries, as a list snapshot.

        The table is read under the node lock, like ``get_query`` does,
        and copied so the caller never holds a live view of a dict that
        worker threads modify.
        """
        with self:
            return list(self._done_queries.keys())
    
    def query_answers(self, qid):
        """Return the current answers for the given query.

        Only the first goal of the query is considered. The result is a
        dict with two sorted lists of serialized terms: ``'substs'`` (the
        substitutions for that goal) and ``'claims'`` (the claims that
        match it).
        """
        log = self.log
        log.info("in query_answers")
        query = self.get_query(qid)
        log.info("in query_answers %s of type %s" % (query, type(query)))

        goal = query[0]
        log.info("goal: %s of type %s" % (goal, type(goal)))

        raw_substs = self.engine.get_substitutions(goal)
        log.info("substitutions: %s" % raw_substs)
        dumped_substs = sorted(terms.dumps(subst) for subst in raw_substs)

        raw_claims = self.engine.get_claims_matching_goal(goal)
        log.info("claims: %s" % raw_claims)
        dumped_claims = sorted(terms.dumps(claim) for claim in raw_claims)

        return { 'substs' : dumped_substs,  'claims' : dumped_claims }


    def all_claims(self):
        """Return whatever the engine reports as its claims."""
        self.log.info("in all_claims")
        claims = self.engine.get_claims()
        return claims

#    #old version. needs a port to etb3
#    def proofs(self, qid):
#        query = self.get_query(qid)
#        query.proof()
        
    @property
    def load(self):
        """Load estimate of this node.

        For now a fixed placeholder value; it is not computed from the
        state of the node (length of long_pool.queue?).
        """
        placeholder_load = 42
        return placeholder_load

    def error(self, msg) :
        """Log the error and then fail by raising AssertionError(msg).

        The exception is raised explicitly rather than through an
        ``assert`` statement, which is removed when Python runs with -O
        and would let callers carry on after a fatal error.
        """
        self.log.error(msg)
        raise AssertionError(msg)

    def interpret_goal_somewhere(self, goal):
        """Interpret the goal on some node, possibly this one.

        Candidates are the neighbors able to interpret the goal's
        predicate, or, when there are none, the links able to. After
        filtering, the candidate with the smallest load is chosen. If it
        is this node the goal is interpreted locally and synchronously;
        otherwise the goal is recorded in ``active_remote_queries`` under
        a new id and sent through the chosen node's proxy.

        Fails through ``self.error`` when no node can interpret the goal
        or when the number of argument specs differs from the number of
        goal arguments.
        """
        pred = goal.first_symbol()
        self.log.info('Looking interpret %s somewhere.' % pred)
        candidates = self.networking.neighbors_able_to_interpret(pred)
        link = False
        if not candidates:
            # no direct neighbor: fall back on linked ETB networks
            candidates = self.networking.links_able_to_interpret(pred)
            link = True
            if not candidates:
                self.error("no node able to interpret goal {0}".format(goal))

        # sanity check of the arity against the first candidate's spec
        argspecstr = candidates[0].predicates[str(pred.val)]
        argspecs = wrapper.ArgSpec.parse(argspecstr)
        if len(argspecs) != len(goal.args):
            self.error(
                "Have %d argspecs, expect %d" % (len(argspecs), len(goal.args)))

        candidates = self.filter_candidates(candidates, goal)
        if not candidates:
            self.error("no node able to interpret goal {0}".format(goal))

        best_node = min(candidates, key=lambda n: n.load)

        if best_node.id == self.id :
            self.interpret_state.interpret(goal, sync=True)
            return

        # the spec is parsed again because the chosen node may differ
        # from candidates[0]
        argspecstr = best_node.predicates[str(goal.first_symbol().val)]
        argspecs = wrapper.ArgSpec.parse(argspecstr)
        # UUID.hex is available on Python 2 and 3; get_hex() is Python 2 only.
        new_id = uuid.uuid4().hex
        self.active_remote_queries[new_id] = (goal, best_node, argspecs, goal)
        proxy = best_node.proxy

        if link:
            self.log.info('Sending %s to remote ETB on %s' % (goal.first_symbol(), best_node.id))
            proxy.interpret_goal_remotely(self.id, terms.dumps(goal), new_id)
        else:
            self.log.info('Asking %s to interpret %s.' % (best_node.id, goal.first_symbol()))
            proxy.interpret_goal(self.id, terms.dumps(goal), new_id)

        best_node.increment_load()

    def get_goals_dependencies(self, goal, remote_etb):
        """Copy locally the files referenced by the ground arguments of ``goal``.

        Each ground argument that can be read as a file reference (a map
        with ``file`` and ``sha1`` entries) is requested from
        ``remote_etb``; the base64 encoded answer is decoded and written
        through ``create_file``. This is best effort: arguments that are
        not file references are skipped, and a failed transfer is logged
        without interrupting the remaining arguments.
        """
        fileterm = terms.mk_term('file')
        sha1term = terms.mk_term('sha1')
        for a in goal.get_args():
            if not a.is_ground():
                continue
            try:
                fields = a.get_args()
                fileref = { 'file' : str(fields[fileterm]),
                            'sha1' : str(fields[sha1term]) }
            except Exception:
                # Non-fileref will go there
                continue
            try:
                content = remote_etb.get_file(terms.dumps(a))
                if content:
                    content = base64.b64decode(content)
                    self.create_file(content, fileref['file'])
                else:
                    self.log.error('Unable to get remote file: %s', fileref)
            except Exception as e:
                # keep going with the other arguments, but leave a trace
                # instead of hiding the failure
                self.log.error('Failed to fetch remote file %s: %s', fileref, e)
            
    def process_interpreted_goal_answer(self, argspecs, goal, node, answer):
        """Decode an answer and pass it to the interpret state.

        ``answer`` is a dict with ``'substs'`` and ``'claims'`` lists of
        serialized terms. Substitutions are decoded first, claims after,
        and the combined list is given to ``_process_output`` for ``goal``.
        ``argspecs`` and ``node`` are not used here.
        """
        def decode_subst(raw):
            # the serialized empty substitution maps to a fresh Subst
            if str(raw) == '{"__Subst": []}':
                return terms.Subst()
            return terms.loads(raw)

        decoded = [decode_subst(raw) for raw in answer['substs']]
        decoded.extend(terms.loads(raw) for raw in answer['claims'])
        self.interpret_state._process_output(goal, decoded)

    def filter_candidates(self, candidates, goal) :
        """Filter candidate nodes by handle information.

        For every goal argument that is a dict with an ``'etb'`` key,
        only the candidates whose id equals that value are kept.
        Returns the (possibly unchanged) list of candidates.
        """
        for arg in goal.args :
            # `in` works on Python 2 and 3; dict.has_key is Python 2 only
            if isinstance(arg, dict) and 'etb' in arg :
                wanted = arg['etb']
                candidates = [ c for c in candidates if c.id == wanted ]
        return candidates

    def get_file_from_somewhere(self, fileref):
        """
        Given a name and sha1 hash pair, which are not available locally,
        find the corresponding file and copy it to the current directory.

        ``fileref`` is a dict with ``'file'`` and ``'sha1'`` entries.
        Raises IOError when the networking layer returns no contents.
        """
        name = fileref['file']
        sha1 = fileref['sha1']
        contents, execp = self.networking.get_contents_from_somewhere(sha1)
        if contents is None:
            msg = 'File [{0}, {1}] not found anywhere'.format(name, sha1)
            self.log.error(msg)
            # a bare `raise` has no active exception to re-raise here, so
            # raise an explicit error that carries the message
            raise IOError(msg)
        self.create_file(contents, name, execp)

    # Used when a wrapper returns a substitution, to fetch the files it mentions.
    def fetch_support(self, substitution):
        """Fetch the files mentioned by the bindings of ``substitution``.

        A bound value is considered when it is a map holding a ``file``
        entry, or an array made only of such maps. Every file reference
        found this way that git does not report as local is retrieved
        with ``get_file_from_somewhere``.
        """
        fileterm = terms.mk_term('file')
        sha1term = terms.mk_term('sha1')

        def is_file_map(term):
            return term.kind == terms.Term.MAP and fileterm in term.get_args()

        for _, value in substitution.get_bindings():
            if is_file_map(value):
                file_maps = [value]
            elif (value.kind == terms.Term.ARRAY and
                  all(is_file_map(item) for item in value.get_args())):
                file_maps = value.get_args()
            else:
                continue
            for file_map in file_maps:
                fileref = { 'file' : str(file_map.get_args()[fileterm]),
                            'sha1' : str(file_map.get_args()[sha1term]) }
                if not self.git.is_local(fileref):
                    self.get_file_from_somewhere(fileref)

    def create_file(self, contents, filename, execp=False):
        """Write ``contents`` to ``filename`` inside the git area and register it.

        The local path comes from ``self.git._make_local_path``; its
        separators are adapted to the running platform and missing parent
        directories are created. The file is written in binary mode. When
        ``execp`` is true the file mode is set to 0o755.

        Returns the result of ``self.git.register(filename)``.
        Raises OSError if changing the mode fails.
        """
        git_name = self.git._make_local_path(filename)
        if platform.system() == 'Windows':
            git_name = git_name.replace('/', '\\')
        else:
            git_name = git_name.replace('\\', '/')
        ndir = os.path.dirname(git_name)
        if ndir != '' and not os.path.isdir(ndir):
            os.makedirs(ndir)
        self.log.debug('Creating %s' % git_name)
        # plain binary write; the `with` block closes the file
        with open(git_name, 'wb') as fd:
            fd.write(contents)
        if execp:
            self.log.debug('Changing executable permission {}'.format(git_name))
            mode = os.stat(git_name).st_mode
            self.log.debug('Current mode is {}'.format(mode))
            #os.chmod(git_name, mode | stat.S_IXUSR)
            try:
                os.chmod(git_name, 0o755)
            except OSError as err:
                self.log.debug('mode change problem {}'.format(err))
                # bare raise keeps the original traceback
                raise
            self.log.debug('mode changed to {}'.format(0o755))
        return self.git.register(filename)
        
def debug_level_value(value):
    """Translate a level name into the matching ``logging`` constant.

    Accepted names are 'debug', 'info', 'warning', 'error' and
    'critical'. Any other value raises AssertionError; the exception is
    raised explicitly so the check is still performed when Python runs
    with -O, which removes ``assert`` statements.
    """
    LEVELS = {'debug': logging.DEBUG,
              'info': logging.INFO,
              'warning': logging.WARNING,
              'error': logging.ERROR,
              'critical': logging.CRITICAL}
    if value in LEVELS:
        return LEVELS[value]
    raise AssertionError('Invalid debug level: %s' % value)

class TaskWorker(threading.Thread):
    """
    The single thread responsible for doing inferences
    and managing internal structures of ETB.

    Tasks are callables taking no argument; they are run one at a time,
    in the order they were scheduled. The thread starts itself as soon
    as the object is created.
    """
    def __init__(self, logger, daemon=True):
        """Set the worker up and start it.

        ``logger`` is used to report failing tasks. An object that has no
        ``warning`` method but carries a logger in its ``log`` attribute
        (such as an ETB node) is accepted as well.
        """
        threading.Thread.__init__(self)
        if not hasattr(logger, 'warning') and hasattr(logger, 'log'):
            logger = logger.log
        self.log = logger
        # An Event rather than an attribute named `_stop`: that name
        # belongs to threading.Thread on Python 3 and overwriting it
        # breaks join().
        self._stop_requested = threading.Event()
        self.daemon = daemon
        self._queue = Queue()
        # start last, once everything run() relies on is in place
        self.start()

    def run(self):
        """The main loop of processing tasks"""
        while not self._stop_requested.is_set():
            try:
                # short timeout so a stop request is noticed quickly
                task = self._queue.get(timeout=.2)
            except Empty:
                continue
            try:
                task()
            except Exception as e:
                # a failing task must not kill the worker thread
                self.log.warning('error in event thread: {0}'.format(e))
                traceback.print_exc()

    def stop(self):
        """Stop the thread and wait for it to terminate"""
        self._stop_requested.set()
        self.join()

    def schedule(self, task):
        """Schedule the task to be processed later."""
        self._queue.put(task)
