"""
Base functionality for NetLogger 'action' modules.

The intent of these is to perform periodic actions like
rolling over the database, summarizing data, or refreshing views.
"""

## Imports

import imp
import pickle
import os
import re
import sys
import time
import traceback # for debugging

from netlogger.nllog import get_logger, DoesLogging
from netlogger.analysis import loader
from netlogger.util import ConfigError, DBConnectError, NL_HOME
from netlogger.util import IncConfigObj

## Logging

def get_log():
    """Return the logger object registered under the name "actions.base".
    """
    logger = get_logger("actions.base")
    return logger

## Classes

class BaseAction(DoesLogging):
    """Base class for action modules.

    Subclasses should override execute().
    Subclasses may override shouldExecute().

    finish(), getInfo(), setInfo(), getLastRun() and setLastRun() all
    read or write attributes of the 'state' object given to the
    constructor (e.g. a SavedState), so they need a real state object
    rather than the None default.
    """
    def __init__(self, conn=None, state=None, param=None):
        """Create with a database connection object 'conn',
        the persistent information (if any) that the module 
        saved from the last run, and any parameters.

        'param' is a dictionary of parameters; when omitted, each
        instance gets its own new empty dictionary.
        """
        DoesLogging.__init__(self)
        # A literal {} default would be one dictionary shared by every
        # instance created without parameters, so build it per call.
        if param is None:
            param = {}
        self._conn = conn
        self._param = param
        self._state = state
        # Creation time; finish() records this as the last-run time.
        self._timestamp = time.time()
        
    def execute(self):
        """Perform the action.

        Parameters passed in by configuration file are available
        as self._param.

        The return value is ignored. Any subclass of Exception will
        be caught by the caller.

        This base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            "This subclass of BaseAction did not properly "
            "override the execute method")

    def finish(self, status=0):
        """Anything necessary to do after the action executes, before
        it exits.

        Stores 'status' as the state's last_status and the timestamp
        taken in the constructor as the state's last_run.

        Subclasses can extend this, but should still call this
        base method unless there is a good reason not to.
        """
        # Set the status code
        self._state.last_status = status
        # Set the last-run time to the timestamp from the
        # constructor. This is therefore guaranteed to be before
        # the time of the actual execution.
        self._state.last_run = self._timestamp

    def getInfo(self):
        """Get persistent state information.
        """
        return self._state.info

    def getConnection(self):
        """Get the database connection object.
        """
        return self._conn

    def getParameters(self):
        """Get the action's parameters, as a dictionary.
        """
        return self._param

    def setInfo(self, value):
        """Set a new value for persistent state information.
        """
        self._state.info = value

    def getLastRun(self):
        """Get the time this action was last run, in
        seconds since the epoch (like time.time()).
        Will be zero(0) if it was not run before.
        """
        return self._state.last_run

    def setLastRun(self, time):
        """Set the time this action was last run, in
        seconds since the epoch (like time.time()).
        """
        # Note: the parameter name hides the 'time' module inside this
        # method; it is kept because callers may pass it by keyword.
        self._state.last_run = time

    def shouldExecute(self, sched, last_run):
        """Decide whether to run now based on the timestamp
        in 'last_run' and the value of sched, a Schedule instance.

        Default behavior is to simply let the schedule decide.

        Returns True or False.
        """
        now = time.time()
        return sched.match(now, previous=last_run)

class SavedStateFile(DoesLogging):
    """Saved state of modules, in a file.

    The file holds one pickled dictionary that maps module names to
    their state objects.
    """
    def __init__(self, path):
        """Create with the path to the state file.

        If nothing exists at 'path' an empty file is created first;
        the current contents are then loaded.

        Will raise an IOError if path isn't a readable file.
        """
        DoesLogging.__init__(self)
        # create file if it doesn't exist, and release the handle
        # right away instead of leaving it open
        if not os.path.exists(path):
            open(path, "w").close()
        self._path = path
        self._load()

    def getState(self, module_name):
        """Get state for module 'module_name'. If there is none,
        create a new one.

        A newly created state is also stored, so it is written out
        by the next save().
        """
        if module_name in self._sdata:
            result = self._sdata[module_name]
        else:
            result = SavedState()
            self._sdata[module_name] = result
        return result

    def save(self):
        """Save current contents of state to a file.
        """
        # Text mode is kept on purpose: _load() reads in text mode too,
        # and the two must stay consistent for existing state files.
        f = open(self._path, "w")
        try:
            self.log.debug("state.save", file=self._path)
            pickle.dump(self._sdata, f)
        finally:
            # close (and therefore flush) even if pickling fails
            f.close()

    def _load(self):
        """Read pickled dictionary {'module-name': State-instance}

        An empty file (EOFError from pickle) gives an empty dictionary.
        """
        # NOTE(review): unpickling can run arbitrary code, so the state
        # file must only be writable by trusted users.
        f = open(self._path)
        try:
            try:
                self._sdata = pickle.load(f)
            except EOFError:
                self._sdata = { }
        finally:
            f.close()
        self.log.debug("state.load", file=self._path)

class SavedState:
    """Persistent record kept for one module.

    Attributes:
      last_run    - time of the last run; starts at 0
      info        - module-defined persistent data; starts as None
      last_status - status recorded for the last run; starts at 0
    """
    def __init__(self):
        # every field begins at its "never run" value
        self.last_run, self.info, self.last_status = 0, None, 0

    def __str__(self):
        # only the module's own info is shown
        text = str(self.info)
        return text
    # repr() gives the same text as str()
    __repr__ = __str__

class ActionFactory:
    """Create new instances of the Action class for
    a given module, using shared values for the DB connection
    and state.
    """
    class InitializationError(Exception): pass

    def __init__(self, db_cfg, saved_state_file):
        """Create with the database configuration 'db_cfg' (its
        scheme, params, dsn and name attributes are used to connect)
        and 'saved_state_file', an object whose getState(name) method
        returns the saved state for a module.
        """
        self._cfg = db_cfg
        self._ssf = saved_state_file

    def create(self, mod, mod_name, mod_param, connect=True):
        """Create new {mod_name}.Action instance. The module
        object is 'mod', its name is 'mod_name', and the parameters
        for it are in 'mod_param'.

        Load the state for each instance from the state file given
        to the constructor.

        If connect is True, the default, then create and give each 
        instance a new connection to the database,
        using the configuration given to the constructor.

        Any Exception raised while looking up the Action class,
        connecting, or constructing the instance is re-raised as
        ActionFactory.InitializationError wrapping the original.

        Return pair (new-instance, last-run-time)
        """
        state = self._ssf.getState(mod_name)
        try:
            clazz = getattr(mod, 'Action')
            if connect:
                conn = self._connect()
            else:
                conn = None
            inst = clazz(conn=conn, param=mod_param, state=state)
            result = (inst, state.last_run)
        except Exception:
            # sys.exc_info() is used instead of the 'except X, E'
            # form, which only parses on Python 2.
            E = sys.exc_info()[1]
            raise ActionFactory.InitializationError(E)
        return result

    def _connect(self):
        """Connect to the database, using information from a parsed
        configuration.

        Exceptions raised are either ConfigError or DBConnectError.

        Return value is a connection object.
        """
        db = self._cfg # easier to read
        # Load the appropriate module
        try:
            db_mod = loader.DB_MODULES[db.scheme]
        except KeyError:
            raise ConfigError("Cannot find database module "
                              "for scheme '%s'" % db.scheme)     
        # Connect to database, raises DBConnectError on failure
        conn = loader.connect(dbmod=db_mod, conn_kw=db.params, 
                              dsn=db.dsn, dbname=db.name)
        return conn

class Schedule(DoesLogging):
    """Description of how often (when) to run something.

    After construction, 'sched_type' is IS_CRON or IS_INTERVAL and
    'values' holds either the five cron fields (-1 meaning '*') or a
    one-item list with the interval length in seconds.
    """
    # Schedule variations
    IS_CRON, IS_INTERVAL = 1, 2
    # Regular expressions for Format 1
    CRON_ITEM = r"(\d+|\*)"
    CRON_EXPR = r"\s+".join([CRON_ITEM] * 5)
    # Inclusive (low, high) bounds, in the same order as the fields
    CRON_RANGES = ((0,59), (0,23), (1,31), (1,12), (1,7))
    # Regular expressions for Format 2
    INTERVAL_EXPR = r"(\d+)\s+(\w+)" # e.g. '12  days'
    # Interval -> seconds mapping
    INTERVAL_KEYS = ('seconds', 'minutes', 'hours', 'days', 'weeks')
    INTERVALS = { 'seconds' : 1, 'minutes' : 60, 'hours' : 60 * 60,
                  'days' : 60 * 60 * 24, 'weeks' : 60 * 60 * 24 * 7 }
    def __init__(self, spec):
        """Initialize from crontab-like format or simple interval.

        Format 1: min(0-59) hour(0-23) dom(1-31) month(1-12) dow(1-7)
        '*' is a valid value. No ranges or steps.
        Format 2: <N> <seconds,minutes,hours,days,weeks>
        The unit may be shortened to any unambiguous prefix.

        Format 1 is tried first. Both patterns are matched only at
        the start of 'spec'; text after the match is ignored.

        Raises ValueError if the format is not valid.
        """
        DoesLogging.__init__(self)
        cron_mobj = re.match(self.CRON_EXPR, spec)
        ival_mobj = re.match(self.INTERVAL_EXPR, spec)
        if cron_mobj:
            try:
                self.values = self._parseCron(cron_mobj)
            except ValueError:
                E = sys.exc_info()[1]
                raise ValueError("Error parsing cron-like schedule '%s': %s" %
                                 (spec, E))
            self.sched_type = self.IS_CRON
        elif ival_mobj:
            try:
                self.values = self._parseInterval(ival_mobj)
            except ValueError:
                E = sys.exc_info()[1]
                raise ValueError("Error parsing interval schedule '%s': %s" %
                                 (spec, E))
            self.sched_type = self.IS_INTERVAL
        else:
            raise ValueError("Specified schedule invalid: %s" % spec)

    def _parseCron(self, mobj):
        """Parse match object for cron-like format and return values.

        Returns a list of five integers; '*' is stored as -1.

        Will raise ValueError if the format is wrong.
        """
        groups = mobj.groups()
        values = [ ]
        for i, item in enumerate(groups):
            if item == '*':
                values.append(-1)
            else:
                try:
                    v = int(item)
                except ValueError:
                    # Parenthesized: '%' binds tighter than '+', so
                    # "fmt % i+1" would try to add 1 to a string.
                    raise ValueError("item (%d) must be an integer or '*'" % 
                                      (i + 1))
                start, end = self.CRON_RANGES[i]
                if (v < start) or (v > end):
                    raise ValueError("item (%d) not in range (%d,%d)" % (
                            i+1, start, end))
                values.append(v)
        return values

    def _parseInterval(self, mobj):
        """Parse match object for interval format and return values.

        Returns a one-item list holding the interval in seconds.

        Will raise ValueError if the format is wrong.
        """
        groups = mobj.groups()
        try:
            n = int(groups[0])
        except ValueError:
            raise ValueError("Schedule interval is not an integer")
        unit = groups[1]
        try:
            key = self._findPrefix(self.INTERVAL_KEYS, unit)
        except KeyError:
            s = "Known units: %s" % (', '.join(self.INTERVAL_KEYS))
            raise ValueError("Unknown interval unit, '%s'. %s" % (unit, s))
        mult = self.INTERVALS[key]
        return [ n * mult ]

    def _findPrefix(self, strings, key):
        """Find item in strings for which key is a prefix.

        Raise a KeyError if there is no match or it is ambiguous.

        Return the full name of the matched item (from 'strings').
        """
        matched = None
        for s in strings:
            if s.startswith(key):
                if matched is not None:
                    raise KeyError("Ambiguous prefix: '%s' and '%s'" % 
                                   (matched, s))
                matched = s
        if matched is None:
            raise KeyError("Prefix '%s' does not match any of (%s)." % (
                    key, ', '.join(strings)))
        return matched

    def match(self, timestamp, previous=0):
        """Return whether the current time matches the schedule.
        The previous match time may also be considered (for intervals).

        For an interval schedule this is True once at least the
        interval has passed between 'previous' and 'timestamp'; for a
        cron schedule only 'timestamp' is examined.

        Returns True or False
        """
        if self.sched_type == self.IS_INTERVAL:
            delta = timestamp - previous
            return (delta >= self.values[0])
        else:
            return self._matchCron(timestamp)

    def _matchCron(self, t):
        """Check if time matches cron-style time specification.

        't' is in seconds since the epoch and is interpreted in
        local time.
        
        Return True or False.
        """
        now = time.localtime(t)
        # tm_wday is 0 for Monday; make Monday == 1 .. Sunday == 7
        wday = now.tm_wday + 1
        # make a list from current time, in same order as values
        current = (now.tm_min, now.tm_hour, now.tm_mday, now.tm_mon, wday)
        # check each target value against the list
        match = True
        for (tgt,cur) in zip(self.values, current):
            self.log.debug("cron.test", tgt=tgt, cur=cur)
            if tgt == -1: # any value will do
                continue
            if tgt != cur:
                match = False
                break
        return match

class Configuration:
    """Configuration for a group of actions.

    [global]
    home = /path/to/nl_home (default in env. or CWD)
    state = /path/to/state_file OR state_file {rel. to $home}
    modules = /path/to/modules_dir OR modules_dir {rel. to $home}

    [database]
    uri = mysql://localhost
    database = whatever
    user = foo
    passwd = foo
    <param> = <value>
    ...

    [<module1>]
    <param> = <value>
    ...
    SCHEDULE = # see Schedule docs
    [<module2>]
    <param> = <value>
    SCHEDULE = * * * * 7
    ...etc..

    Both 'uri' and 'database' are required in [database]. Every
    other section is treated as a module and must have a SCHEDULE.
    """
    WHEN_KEYWORD = 'SCHEDULE'

    DB_SECTION = 'database' # name of [database] section
    DBNAME_KW = 'database' # keyword in [database] for the database name 
    URI_KW = 'uri' # keyword in [database] for the scheme://host URI
    DB_REQUIRED = (DBNAME_KW, URI_KW) # these must be defined

    GLOBAL_SECTION = 'global' # name of [global] section
    HOME_KW = 'home'
    STATE_KW = 'state'
    STATE_DEFAULT = "state"
    MOD_KW = 'modules'
    MOD_DEFAULT = "modules"

    def __init__(self, path):
        """Initialize with path to configuration file.

        Parse the file into two dictionaries, self.module_params and
        self.when. Both have as keys the names of the modules.
        The value for module_params is a dictionary of parameter values.
        The value for when is an instance of Schedule.

        Also sets self.module_path, self.state_path and self.db
        (a DatabaseConfiguration).

        Raises a ConfigError if the file is not readable or
        there is a parse error.
        """
        _openConfig(path)
        try:
            cfg = IncConfigObj(path)
        except SyntaxError:
            # NOTE(review): this clause used to name
            # configobj.ConfigObjError, but 'configobj' is never
            # imported here, so any parse failure became a NameError.
            # configobj documents ConfigObjError as a subclass of
            # SyntaxError, so that built-in is caught instead; confirm
            # IncConfigObj raises nothing outside that hierarchy.
            E = sys.exc_info()[1]
            raise ConfigError(E)
        self._parseGlobal(cfg)
        self._parseDatabase(cfg)
        self._parseModules(cfg)

    def _parseGlobal(self, cfg):
        """Parse [global] section.

        Sets self.module_path and self.state_path. The section and
        each of its keywords are optional.
        """
        section = cfg.get(self.GLOBAL_SECTION)

        def setting(keyword, default):
            # value from [global] if present, otherwise the default
            if section and keyword in section:
                return section.get(keyword)
            return default

        # base directory; the fallback is only computed when needed
        if section and self.HOME_KW in section:
            base_dir = section.get(self.HOME_KW)
        else:
            base_dir = os.getenv(NL_HOME, os.getcwd())
        # path to modules
        mod_path = setting(self.MOD_KW, self.MOD_DEFAULT)
        # path to state file
        state_path = setting(self.STATE_KW, self.STATE_DEFAULT)
        # set state, module paths from absolute or relative to base_dir
        self.module_path = self._fromBase(base_dir, mod_path)
        self.state_path = self._fromBase(base_dir, state_path)

    def _fromBase(self, base_dir, path):
        """Return 'path' unchanged if it is already in absolute,
        normalized form, otherwise joined onto 'base_dir'.
        """
        if os.path.abspath(path) == path:
            return path
        return os.path.join(base_dir, path)

    def _parseDatabase(self, cfg):
        """Parse [database] section

        Sets self.db to a DatabaseConfiguration.

        Raises ConfigError if the section or a required keyword
        is missing.
        """
        section = cfg.get(self.DB_SECTION)
        if section is None:
            raise ConfigError("Missing [%s] section" % self.DB_SECTION)
        self.db_params = { }
        for k in self.DB_REQUIRED:
            if k not in section:
                raise ConfigError("Missing '%s' keyword in [%s] section" % (
                        k, self.DB_SECTION))
        self.db = DatabaseConfiguration(section)

    def _parseModules(self, cfg):
        """Parse the action modules.

        Fills self.module_params and self.when from every section
        other than [database] and [global].

        Raises ConfigError for a sub-section, an invalid schedule,
        or a module with no schedule.
        """        
        self.module_params = { } # modules and parameters
        self.when = { } # when to run, keys are module names
        # pull out and parse schedule for each module
        for mod_name, section in cfg.items():
            if mod_name in (self.DB_SECTION, self.GLOBAL_SECTION):
                continue
            self.module_params[mod_name] = { }
            had_schedule = False
            for k,v in section.items():
                if isinstance(v,dict):
                    raise ConfigError("Sub-section '%s' in module "
                                      "section '%s' is not allowed." % (
                            k, mod_name))
                if k == self.WHEN_KEYWORD:
                    # duplicate keyword will be caught at a lower level
                    # if had_schedule: ...
                    try:
                        self.when[mod_name] = Schedule(v)
                        had_schedule = True
                    except ValueError:
                        E = sys.exc_info()[1]
                        raise ConfigError("Invalid schedule for module "
                                          "%s: %s" % (mod_name, E))
                else:
                    self.module_params[mod_name][k] = v
            if not had_schedule:
                raise ConfigError("No schedule for module '%s'" % mod_name)

class DatabaseConfiguration:
    """Separate class for the database configuration section.

    Attributes set by the constructor:
      scheme - part of the URI before '://'
      dsn    - part of the URI after '://'
      name   - the database name
      params - copy of the section without the URI and name keywords
    """
    def __init__(self, section):
        """Create attributes from the section content.

        Raises a ConfigError if values are not correct.
        """
        self.params = section.copy()
        # Extract database URI into (scheme, DSN)
        uri = self.params[Configuration.URI_KW]
        parts = uri.split("://", 1)
        if len(parts) != 2:
            raise ConfigError("Invalid URI '%s'" % uri)
        self.scheme, self.dsn = parts
        del self.params[Configuration.URI_KW]
        # Extract database name from keywords
        try:
            self.name = self.params[Configuration.DBNAME_KW]
        except KeyError:
            # Report the keywords held by this object. The attribute
            # is 'params'; 'db_params' does not exist on this class
            # and would raise AttributeError instead of ConfigError.
            raise ConfigError("Internal error: No '%s' in "
                              "keywords: %s" % (
                    Configuration.DBNAME_KW, self.params))
        del self.params[Configuration.DBNAME_KW]

## Functions

def _openConfig(config_file):
    """Common code for connect() and Configuration constructor.
    """
    # try to open the file
    try:
        open(config_file)
    except IOError, E:
        raise ConfigError("[Errno %d] Configuration file '%s': %s" % (
            E.errno, E.filename, E.strerror))

def _impLoadAndClose(name, search_path=None):
    """Find and load module 'name' with the imp machinery.

    'search_path' is a list of directories, or None to search
    sys.path. The file object returned by imp.find_module() is
    always closed; imp leaves that to the caller.

    Returns the loaded module object.
    """
    fp, pathname, description = imp.find_module(name, search_path)
    try:
        return imp.load_module(name, fp, pathname, description)
    finally:
        # fp is None when the module is a package
        if fp is not None:
            fp.close()

def getModuleActions(mod_path=None, params=None, schedules=None, 
                     action_factory=None):
    """Return actions matching dictionary in 'params', 
    from modules found at 'mod_path', initializing them with the
    matching Schedule instance in 'schedules'.

    'params' maps module names to parameter dictionaries and
    'schedules' maps the same names to Schedule instances; either
    may be omitted, meaning empty.

    Use the ActionFactory instance given as 'action_factory' to 
    create the Action class instance from the module.

    Return, using the generator API, a pair (module-name,
    new Action instance) for each module named in 'params' that
    should run now. Because this is a generator, nothing happens,
    and no exception is raised, until iteration starts.

    Raises ConfigError if the module path doesn't exist or is not a
    directory, or if an action cannot be initialized.
    Raises ImportError if module can't be imported
    """
    log = get_log()
    # new dictionaries per call rather than shared mutable defaults
    if params is None:
        params = {}
    if schedules is None:
        schedules = {}
    # check that directory in mod_path exists
    if not os.path.exists(mod_path):
        raise ConfigError("Module path '%s' not found" % mod_path)
    if not os.path.isdir(mod_path):
        raise ConfigError("Module path '%s' is not a directory" % mod_path)
    # split mod_path into {mod_dir}/mod_name, ignoring trailing slashes
    parts = os.path.split(mod_path.rstrip('/'))
    if not parts[0]:
        raise ImportError("Module path '%s' must be an absolute path" %
                          mod_path)
    path, name = parts
    # save sys.path
    sys_path = sys.path[:]    
    try:
        # temporarily modify sys.path to be module's path
        sys.path = [path]
        # load 'name' as parent module
        log.debug("module.import.start", name=name, sys__path=sys.path)
        try:
            parent_mod = _impLoadAndClose(name)
        except ImportError:
            E = sys.exc_info()[1]
            log.error("module.import.end", status=-1, msg=E)
            raise ImportError("Cannot import module '%s' on path '%s'"
                              % (name, sys.path))
        log.debug("module.import.end", name=name, status=0)
    finally:
        # reset sys.path
        sys.path = sys_path
    # Load, initialize, and yield each runnable action
    for run_name, run_param in params.items():
        mod = _impLoadAndClose(run_name, parent_mod.__path__)
        try:
            schedule = schedules[run_name]
            action, last_run = action_factory.create(mod, run_name, 
                                                     run_param)
        except ActionFactory.InitializationError:
            E = sys.exc_info()[1]
            log.error("init", module=run_name, msg=E)
            raise ConfigError("Initialization error: %s" % E)
        if action.shouldExecute(schedule, last_run):
            yield run_name, action
