from abc import ABCMeta, abstractmethod
import beanstalkc
import json
from sys import exc_info, stdout
import logging
from os import getenv
import atexit
from argparse import ArgumentParser


# Status codes exchanged with the checking system. Implementations of
# FemidaChecker._push/_pull return one of these; FemidaChecker.push/pull
# substitute RESULT_CHECKER_ERROR when the implementation raises.
RESULT_OK = 101
RESULT_CORRUPT = 102
RESULT_MUMBLE = 103
RESULT_DOWN = 104

# Reserved for failures of the checker itself rather than of the service.
RESULT_CHECKER_ERROR = 110


class FemidaLogger(logging.Logger):
    """Logger that stamps every record with the worker process number.

    The number is stored on each LogRecord as ``worker_num`` so formatters
    can reference ``%(worker_num)d``. It is attached in ``makeRecord``, the
    factory every logging call goes through. That way all entry points
    (debug, info, warning, warn, error, exception, critical, fatal, log) are
    covered, the inherited ``exception`` keeps its ``exc_info=True`` default,
    and no wrapper frames distort the reported caller.

    Args:
        name: logger name.
        worker_num: value attached to every record as ``worker_num``
            (may be None when no worker number was given).
    """

    def __init__(self, name, worker_num):
        super(FemidaLogger, self).__init__(name)
        self.worker_num = worker_num

    def log(self, lvl, msg, *args, **kwargs):
        """Log ``msg`` at level ``lvl``, exactly like ``logging.Logger.log``.

        Kept only so callers using the historical ``lvl`` keyword still work.
        """
        super(FemidaLogger, self).log(lvl, msg, *args, **kwargs)

    def makeRecord(self, *args, **kwargs):
        """Build a LogRecord and attach this logger's ``worker_num`` to it.

        The attribute is set after construction. It therefore overrides any
        caller-supplied ``extra['worker_num']``, and the caller's ``extra``
        dict is never mutated.
        """
        record = super(FemidaLogger, self).makeRecord(*args, **kwargs)
        record.worker_num = self.worker_num
        return record


class FemidaChecker(object):
    """Checker interface. You should implement _push and _pull methods.

    An instance parses ``--worker-num`` from the command line, sets up a
    stdout logger and, via ``run()``, processes push/pull jobs taken from
    beanstalkd.
    """
    __metaclass__ = ABCMeta

    def __init__(self):
        """Parse command-line arguments and configure logging.

        Uses ``ArgumentParser.parse_args()``, so unrecognised arguments make
        the process exit with a usage error.
        """
        parser = ArgumentParser()
        parser.add_argument('--worker-num', metavar='worker_num', type=int,
                            help='worker process number')
        args = parser.parse_args()
        self.worker_num = args.worker_num
        self.setup_logging()

    def setup_logging(self):
        """Create ``self.logger``: an INFO-level FemidaLogger writing to stdout.

        The worker number is included in every line only when one was given.
        """
        console_handler = logging.StreamHandler(stdout)
        if self.worker_num is not None:
            log_format = '[%(asctime)s] Worker %(worker_num)d - '\
                '%(levelname)s - %(message)s'
        else:
            log_format = '[%(asctime)s] - %(levelname)s - %(message)s'
        formatter = logging.Formatter(log_format)
        console_handler.setFormatter(formatter)
        self.logger = FemidaLogger(__name__, self.worker_num)
        self.logger.setLevel(logging.INFO)
        self.logger.addHandler(console_handler)

    def push(self, endpoint, flag_id, flag):
        """Call ``_push`` and shield the caller from its exceptions.

        Returns:
            Whatever ``_push`` returns (expected: a ``(status, flag_id)``
            tuple), or ``(RESULT_CHECKER_ERROR, flag_id)`` if it raised.
        """
        try:
            return self._push(endpoint, flag_id, flag)
        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must still be
            # able to stop the worker.
            self.logger.exception('An exception occurred', exc_info=exc_info())
            return RESULT_CHECKER_ERROR, flag_id

    def pull(self, endpoint, flag_id, flag):
        """Call ``_pull`` and shield the caller from its exceptions.

        Returns:
            Whatever ``_pull`` returns (expected: a RESULT_* constant), or
            ``RESULT_CHECKER_ERROR`` if it raised.
        """
        try:
            return self._pull(endpoint, flag_id, flag)
        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must still be
            # able to stop the worker.
            self.logger.exception('An exception occurred', exc_info=exc_info())
            return RESULT_CHECKER_ERROR

    @abstractmethod
    def _push(self, endpoint, flag_id, flag):
        """Push <flag> into the <endpoint> service. Return tuple containing
        result from one of RESULT_* constants and flag_id"""
        pass

    @abstractmethod
    def _pull(self, endpoint, flag_id, flag):
        """Check if the flag that can be pulled from <endpoint> service by
        using some data in <flag_id> equals given <flag>. Return result from
        one of RESULT_* constants"""
        pass

    def run(self):
        """This method will be final. Don't override it.

        Serve jobs forever. Configuration comes from the environment:
        BEANSTALKD_URI (``host:port``), TUBE_PUSH / TUBE_PULL (tubes to
        watch) and TUBE_REPORT_PUSH / TUBE_REPORT_PULL (tubes results are
        put into).

        Each job body is JSON with ``endpoint``, ``flag_id`` and ``flag``
        keys. The suffix of the job's tube name after the last '.' selects
        the operation (``push`` or ``pull``). Every reserved job is deleted
        after processing, even if processing failed, so a malformed job is
        not retried forever.

        Raises:
            RuntimeError: if a required environment variable is unset/empty.
        """
        required = ('BEANSTALKD_URI', 'TUBE_PUSH', 'TUBE_PULL',
                    'TUBE_REPORT_PUSH', 'TUBE_REPORT_PULL')
        env = {name: getenv(name) for name in required}
        missing = [name for name in required if not env[name]]
        if missing:
            # Fail fast instead of watching/using a tube named "None".
            raise RuntimeError('Missing required environment variable(s): '
                               + ', '.join(missing))

        host, port = env['BEANSTALKD_URI'].rsplit(':', 1)
        beanstalk = beanstalkc.Connection(host=host, port=int(port))
        self.logger.info('Established connection with beanstalk server')

        @atexit.register
        def close_beanstalk():
            beanstalk.close()
            self.logger.info('Closed connection to beanstalk server')

        beanstalk.watch(env['TUBE_PUSH'])
        beanstalk.watch(env['TUBE_PULL'])
        # A fresh beanstalkd connection also watches the 'default' tube.
        # Unless it is one of ours, stop watching it; otherwise foreign jobs
        # from it would be reserved here and deleted below.
        if 'default' not in (env['TUBE_PUSH'], env['TUBE_PULL']):
            beanstalk.ignore('default')

        while True:
            job = beanstalk.reserve()
            try:
                data = json.loads(job.body)
                tube = job.stats()['tube']
                # Tube names are expected to be '<service_id>.<operation>'.
                operation = tube.rpartition('.')[2]
                if operation == 'push':
                    status, new_flag_id = self.push(data['endpoint'],
                                                    data['flag_id'],
                                                    data['flag'])
                    res = dict(status=status,
                               flag=data['flag'],
                               flag_id=new_flag_id,
                               endpoint=data['endpoint'])
                    self.logger.info(
                        'PUSH flag %s to %s: result %s, flag_id %s',
                        data['flag'], data['endpoint'], status, new_flag_id)
                    beanstalk.use(env['TUBE_REPORT_PUSH'])
                    beanstalk.put(json.dumps(res))
                elif operation == 'pull':
                    status = self.pull(data['endpoint'],
                                       data['flag_id'],
                                       data['flag'])
                    res = dict(status=status,
                               flag=data['flag'],
                               endpoint=data['endpoint'])
                    self.logger.info(
                        'PULL flag %s from %s with flag_id %s: result %s',
                        data['flag'], data['endpoint'], data['flag_id'],
                        status)
                    beanstalk.use(env['TUBE_REPORT_PULL'])
                    beanstalk.put(json.dumps(res))
                else:
                    self.logger.warning(
                        'Unknown operation in tube %r; dropping job', tube)
            except Exception:
                self.logger.exception('An exception occurred',
                                      exc_info=exc_info())
            # Deleted even after a failure so a bad job is not retried forever.
            job.delete()