# Copyright (C) 2007, 2010 Ian Zimmerman <itz@buug.org>

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the conditions spelled out in
# the file LICENSE are met.

from __future__ import with_statement , absolute_import

import os
import sys
import time
import logging
import re
from fnmatch import fnmatch
import subprocess
from subprocess import PIPE
import signal
from StringIO import StringIO
import shutil
import tempfile
import email
import email.Utils
import shelve
from .lock import unlocking
from contextlib import closing , contextmanager , nested
from socket import gethostname

# exceptions
class Error (Exception):
    """Root of the exception hierarchy specific to the sortmail module."""

class ParseError (Error):
    """Raised when the input handed to the Msg constructor cannot be parsed
    as a mail message."""

class DBFormatError (Error):
    """Error carrying a desired and a recorded database format.
    NOTE(review): nothing in the visible source raises this; presumably a
    leftover from the Berkeley DB based IDShelf - confirm before relying on it.
    """
    def __init__ (self, desired, recorded):
        """Keep both values as attributes and as the exception arguments,
        so that they show up in str() and in tracebacks.
        """
        Error.__init__ (self, desired, recorded)
        self.desired = desired
        self.recorded = recorded

class DBVersionError (Error):
    """Error carrying a database version and the range of versions
    that would have been acceptable.
    NOTE(review): nothing in the visible source raises this - confirm usage.
    """
    def __init__ (self, oldest_compatible, actual, newest_compatible):
        """Keep all three values as attributes and as the exception arguments."""
        Error.__init__ (self, oldest_compatible, actual, newest_compatible)
        self.oldest_compatible = oldest_compatible
        self.actual = actual
        self.newest_compatible = newest_compatible

class BadMatchKindError (Error):
    """Exception raised when an invalid match kind argument is passed to Msg.header_glob.
    """
    def __init__ (self, arg):
        """The exception constructor; arg is the invalid argument passed to Msg.header_glob.
        It is available as self.arg and is also handed to the base class, so that
        str() of the exception names the offending value.
        """
        Error.__init__ (self, arg)
        self.arg = arg

# internals
_PIPE_BUF = 512                 # buffer length used when spooling a long body to a temporary file
_max_body = 8 * _PIPE_BUF       # default size hint for the part of the body kept in memory
_ONE_DAY = 24 * 60 * 60         # seconds
_ONE_WEEK = 7 * _ONE_DAY        # seconds
_MY_VERSION = 20120108          # version stamp written into new dedupe records
_LAST_COMPAT = 20100801         # remember to update this!

# leading mbox 'From ' line; group 1 is the envelope sender
_re_from = re.compile ('From[ \t]+([^ \t]+)')
# start of a header line; group 1 is the tag, the match ends where the content begins
_re_header = re.compile ('([^\x00-\x1f\x7f-\xff :]+):[ \t]*')
# tags of headers naming the message destination
_re_dest_tag = re.compile ('(?is)(?:(?:original-)?(?:resent-)?(?:to|cc|bcc)|(?:x-envelope|apparently(?:-resent)?)-to)$')
# tags of headers naming the message sender
_re_sender_tag = re.compile ('(?is)(?:(?:resent-)?sender|resent-from|return-path)$')
# headers combined into the key used for duplicate detection
_id_headers = ('from', 'date', 'message-id')
range_3 = range (3)
range_256 = range (256)

# Receiving hosts normally put the sender's IP address in square brackets,
# as they should, but once in a while it shows up in parentheses instead.
# Yahoo has been seen doing this:
#   Received: from unknown (66.218.66.218)
#       by m19.grp.scd.yahoo.com with QMQP; 19 Dec 2003 04:06:53 -0000
# Hence both kinds of delimiter are accepted.  Group 1 is the whole dotted
# quad, groups 2 to 5 are its four components.
_re_received_ip = re.compile (r'[([]((\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3}))[])]')

_ord_A = ord ('A')
_ord_Z = ord ('Z')

# return True if no ascii uppercase character occurs in s (so also for the
# empty string), False otherwise; the test is independent of locale
def _my_islower (s):
    return not any (_ord_A <= ord (ch) <= _ord_Z for ch in s)

def _ascii_lower (i):
    """Return the character with code i, replaced by its lower case
    counterpart if it is an ASCII capital letter."""
    if i < _ord_A or i > _ord_Z:
        return chr (i)
    return chr (i - _ord_A + ord ('a'))

# translation table for str.translate: lower-cases ASCII letters, leaves
# every other character code alone
_ascii_lower_table = ''.join (map (_ascii_lower, range_256))

def _atsign_to_dot (i):
    """Return the character with code i, except that '@' becomes '.'."""
    ch = chr (i)
    if ch == '@':
        return '.'
    return ch

# translation table for str.translate: maps '@' to '.', everything else to itself
_atsign_to_dot_table = ''.join (map (_atsign_to_dot, range_256))

def _read_message_from_fh (fh, max_body = _max_body):
    """Parse one mail message from the file-like object fh.
    Return a tuple (envelope_from, headers, body, tf, size) where
    envelope_from is the address found in a leading 'From ' line, or None
      if the input does not start with one
    headers is a list of (tag, content) pairs; content keeps its line
      terminator and any continuation lines exactly as read
    body is the list of lines returned by fh.readlines (max_body), i.e.
      max_body is only a size hint for the part kept in memory
    tf is a temporary file holding whatever input remained after that,
      positioned at its start, or None if nothing remained
    size is the total number of characters read from fh
    ParseError is raised if the headers are followed by anything but an
    empty line or the end of input.
    """
    size = 0
    envelope_from = None
    headers = []
    tf = None
    line = fh.readline ()
    size += len (line)
    mobj = _re_from.match (line)
    if mobj is not None:
        envelope_from = mobj.group (1)
        line = fh.readline ()
        size += len (line)
    mobj = _re_header.match (line)
    while mobj is not None:
        tag = mobj.group (1)
        content = line [mobj.end () : ]
        line = fh.readline ()
        size += len (line)
        # lines starting with white space continue the current header
        while line.startswith (' ') or line.startswith ("\t"):
            content += line
            line = fh.readline ()
            size += len (line)
        headers.append ((tag, content))
        mobj = _re_header.match (line)
    # here line is the first line that is not part of the headers
    if line != '' and line != "\n":
        logging.getLogger ('sortmail').error ('Invalid email format')
        raise ParseError
    body = fh.readlines (max_body)
    size += sum ([len (l) for l in body])
    morebody = fh.readline ()
    if morebody != '':
        # the body is longer than we want to keep in memory: spool the rest
        tf = tempfile.TemporaryFile ()
        tf.write (morebody)
        shutil.copyfileobj (fh, tf, _PIPE_BUF)
        # the length of the temporary file already includes morebody, so it
        # must not be added to size a second time
        tf.seek (0, 2)
        size += tf.tell ()
        tf.seek (0, 0)
    return (envelope_from, headers, body, tf, size)

def _read_message_from_string (s, max_body = _max_body):
    """Parse the mail message contained in s by wrapping it in a StringIO
    object and handing that to _read_message_from_fh; the return value is
    that function's return value."""
    fh = StringIO (s)
    return _read_message_from_fh (fh, max_body)

def _has_address (address, content):
    """Search the addresses that email.Utils.getaddresses finds in content
    for one equal to address, ignoring ASCII case.  Return the matching
    (name, address) pair as parsed, or None if there is none."""
    for pair in email.Utils.getaddresses ([content]):
        candidate = pair [1]
        if address.translate (_ascii_lower_table) == candidate.translate (_ascii_lower_table):
            return pair
    return None

# exports
header_cxform = []

def with_header_cxform (cxforms, fun):
    """Call the function like object fun with no arguments while the module
    level header transform list is set to cxforms, and return its result.
    The previous list is put back when fun returns or raises.
    Note that the matching methods of Msg read the cxform attribute of the
    instance, so this has an effect on a message only through whatever
    value its constructor stored there.
    """
    global header_cxform
    saved = header_cxform
    header_cxform = cxforms
    try:
        return fun ()
    finally:
        header_cxform = saved

# the IDShelf class is now stubbed out, because the functionality is implemented
# using python standard library shelves.  See Msg.deliver().
        
class IDShelf(object):
    """Stub kept for interface compatibility.  It once represented a Berkeley DB
    file of unique message identifiers; now it stores nothing and never
    reports a duplicate.
    """
    def __init__ (self, key, dbfile, id_headers, clean_period, ttl):
        """Accept and ignore all arguments."""

    def probe (self):
        """Always answer False."""
        return False

    def record (self):
        """Do nothing and return the shelf itself."""
        return self
    

class Msg(object):
    """The main class, representing an incoming email message.
    The header lines are stored in self._headers which is a list of pairs (t, c)
    where t is the tag (e.g. "From") and c is the contents.  No folding or unfolding
    is done on the contents, it is stored exactly as it came in the message, possibly
    split into multiple physical lines, each but the first beginning with white space.
    The body is stored in general in two pieces: the first part as a list of lines in
    self._body, and the second part in a temporary file self._tf.  If the body is
    short enough there is no temporary file and self._tf is None.
    Another field of note is self._envelope_from, which can be initialized either
    from the message 'From ' line (if one exists), or explicitly from an argument
    to the constructor.
    """

    def _make_from_line (self):
        """Return a properly formatted 'From ' line for delivery to Unix mbox files."""
        stime = time.strftime ('%a %b %d %H:%M:%S %Y', time.localtime ())
        return "From %s %s\n" % (self._envelope_from, stime)

    def _print_to_fh (self, fh):
        """Output the message to the file fh.  No transformation is done on the message body,
        in particular no From escaping.
        """
        for (tag, content) in self._headers:
            fh.write ('%s: %s' % (tag, content))
        fh.write ('\n')
        for l in self._body:
            fh.write (l)
        if self._tf is not None:
            self._tf.seek (0)
            for l in self._tf:
                fh.write (l)
        return self

    def __str__ (self):
        """Return the message as a single string.  This returns exactly the same string as the
        one output to fh by _print_to_fh (fh).
        """
        fh = StringIO ('')
        self._print_to_fh (fh)
        s = fh.getvalue ()
        fh.close ()
        return s

    def _append_to_box (self, fn):
        """Append the message to the file named fn.  This is similar to _print_to_fh,
        but additionally brackets the message with a leading From line and a trailing
        newline, and does From escaping on the body.
        """
        with open (fn, 'a') as fh:
            if not self._test_only:
                fh.write (self._make_from_line ())
                for (tag, content) in self._headers:
                    fh.write ('%s: %s' % (tag, content))
                fh.write ('\n')
                for l in self._body:
                    if l.startswith ('From '):
                        fh.write ('>')
                    fh.write (l)
                if self._tf is not None:
                    self._tf.seek (0)
                    for l in self._tf:
                        if l.startswith ('From '):
                            fh.write ('>')
                        fh.write (l)
                fh.write ('\n')
        return self

    def __init__ (self, source = sys.stdin, envelope_from = None, test_only = False,
                  cxform = None):
        """The constructor.  Arguments [all with defaults] are as follows:
        source [sys.stdin] - can be a file, a string, or any object with a __str__ method
        envelope_from [None] - overrides envelope sender address in the From line, if any
        test_only [False] - if True, all deliveries are stubbed out (but logging, locking still done)
        cxform [None] - list of transforms called on headers before matching; None
            means the module level header_cxform list as it is bound at the time
            of the call, so that with_header_cxform affects messages created
            inside it
        If neither envelope_from nor a From line is available, the envelope sender
        is made up from the LOGNAME environment variable.
        Raises ParseError if source cannot be read in any of the supported ways
        or does not contain a well-formed message.
        """
        self._test_only = test_only
        self._delivered = False
        self._logged_headers = set ()
        # ways of reading source, tried in this order; a TypeError or
        # AttributeError means source is not of the kind that reader expects
        readers = (lambda: _read_message_from_fh (source),
                   lambda: _read_message_from_string (source),
                   lambda: _read_message_from_string (str (source)))
        for reader in readers:
            try:
                sender, self._headers, self._body, self._tf, self._size = reader ()
            except ( TypeError, AttributeError ):
                continue
            break
        else:                           # everything failed
            logging.getLogger ('sortmail').error ('cannot initialize message')
            raise ParseError
        if envelope_from is not None:
            self._envelope_from = envelope_from
        elif sender is None:
            self._envelope_from = os.getenv ('LOGNAME', 'nobody') + '@localhost'
        else:
            self._envelope_from = sender
        # look the global up now rather than when the class was defined
        if cxform is None:
            cxform = header_cxform
        self.cxform = cxform

    @contextmanager
    def restoring_cxform(self, cxform):
        old_cxform = self.cxform
        self.cxform = cxform
        try:
            yield
        finally:
            self.cxform = old_cxform

    def body (self):
        """Return the message body as a single string.
        Note that using this method defeats the purpose of storing the body
        partly in a temporary file.  It should only be used when the body size
        is known not to be excessive, or else as a last resort.
        """
        body = ''.join (self._body)
        if self._tf is not None:
            self._tf.seek (0)
            body += self._tf.read ()
        return body

    def size (self):
        """Return the message size in bytes."""
        return self._size

    def header_glob (self, tag, pat = None, kind = ':matches'):
        """Return the list of headers matching tag (case-insensitive exact
        match) and pat (an exact, sub, or glob pattern for the header
        contents according to kind), as a list of pairs (i, mc) where i
        is an index into the header array, and mc is the header
        contents.
        """
        if kind != ':is' and kind != ':contains' and kind != ':matches':
            raise BadMatchKindError (kind)
        if pat is None:
            if kind == ':matches':
                pat = '*'
            else:
                pat = ''
        logger = logging.getLogger ('sortmail')
        retval = []
        for i in range (len (self._headers)):
            t, c = self._headers [i]
            cx = c
            for xf in self.cxform:
                cx = xf (cx)
            if t.lower () == tag:
                if ((kind == ':matches' and fnmatch (cx, pat))
                    or (kind == ':is' and cx == pat)
                    or (kind == ':contains' and cx.find (pat) != -1)):
                    retval.append ((i, cx))
                    line = '(header match) ' + t + ': ' + c.rstrip ()
                    if line not in self._logged_headers:
                        logger.info ('%s', line)
                        if logger.getEffectiveLevel () <= logging.INFO:
                            self._logged_headers.add (line)
        return retval

    def _header_match_common (self, re_tag, content_p):
        """Given a regexp object re_tag and a function (taking a string argument) content_p,
        return a list of all tuples (i, (mt, mc)) such that:
        i is an index of a header line
        mt is the match object from matching the tag of the i-th header line with re_tag
        mc is the result of calling content_p with the content of i-th header as argument,
        and mc is true under the usual boolean coercion rules.
        Each header's contents is transformed by calling all functions
        in self.cxform, feeding the result of each function to the
        next.
        """
        logger = logging.getLogger ('sortmail')
        retval = []
        for i in range (len (self._headers)):
            t, c = self._headers [i]
            mt = re_tag.match (t)
            cx = c
            for xf in self.cxform:
                cx = xf (cx)
            mc = content_p (cx)
            if mt is not None and mc:
                retval.append ((i, (mt, mc)))
                line = '(header match) ' + t + ': ' + c.rstrip ()
                if line not in self._logged_headers:
                    logger.info ('%s', line)
                    if logger.getEffectiveLevel () <= logging.INFO:
                        self._logged_headers.add (line)
        return retval

    def header_match (self, ptag, pcontent = r'.*', flags = re.S):
        """Match all headers against the regexp strings ptag and pcontent and
        return the list of tuples (i, (mt, mc)) for those that match both:
        i is the index of the header line
        mt is the match object for the tag and ptag
        mc is the match object for the content and pcontent
        Both patterns are applied with re.match, that is anchored at the start
        only; end ptag with '$' to match a tag exactly.  The tag is matched
        case-insensitively if ptag contains no ASCII uppercase characters,
        the content is matched according to flags.
        An empty list counts as false, so the method doubles as a simple test
        for the presence of a header, as in
        > if msg.header_match ('from', '.*spammer@example[.]com'):
        >     sys.exit (0)    # effectively delivers message to /dev/null
        """
        prefix = r'(?is)' if _my_islower (ptag) else r'(?s)'
        tag_re = re.compile (prefix + ptag)
        content_re = re.compile (pcontent, flags)
        return self._header_match_common (tag_re, content_re.match)

    def header_match_no_cxform (self, ptag, pcontent = r'.*', flags = re.S):
        """Same as header_match, except that self.cxform is an empty list while the
        matching is done, so that the raw header contents are matched and no
        useless or harmful transformation function is called.
        """
        untransformed = self.restoring_cxform ([])
        with untransformed:
            found = self.header_match (ptag, pcontent, flags)
        return found

    def get_header (self, index):
        """Return a pair (tag, content) - the item at index in the list of message headers."""
        return self._headers [index]

    def get_headers (self, indices):
        """Return a list of pairs (tag, content) - the items at indices in the list of message headers.
        This can be used to retrieve the matching headers after a successful header_match call.
        For example:
        > matches = msg.header_match ('received')
        > headers = msg.get_headers ([i for (i, _) in matches])
        """
        return [self._headers [i] for i in indices]

    def header_start (self, ptag, pcontent, flags = re.S):
        """A variant of header_match where pcontent has to match right at the start
        of the header content, after optional white space.
        """
        anchored = r'\s*' + pcontent
        return self.header_match (ptag, anchored, flags)

    def destination_match (self, pcontent, flags = re.S):
        """Like header_match, but instead of a tag pattern given by the caller
        the fixed pattern for headers naming the message destination is used.
        """
        content_re = re.compile (pcontent, flags)
        return self._header_match_common (_re_dest_tag, content_re.match)

    def destination_word (self, word, flags = re.S):
        """Like destination_match, with a content pattern that finds word anywhere
        in the content provided it starts at a word boundary.  word is inserted
        into the pattern as is, so it is itself a regexp.  This is similar to
        the ^TO shortcut in procmail configuration files.
        """
        pattern = r'.*\b' + word
        return self.destination_match (pattern, flags)

    def destination_address (self, address):
        """Return the matches, in the format of header_match, for the headers
        naming the message destination whose content contains the given
        address.  The content is parsed into addresses and compared
        rather than matched with a regexp; mc in the result is the
        parsed pair for the address found.  This is similar to the ^TO_
        shortcut in procmail configuration files.
        """
        def holds_address (content):
            return _has_address (address, content)
        return self._header_match_common (_re_dest_tag, holds_address)

    def list_match (self, address):
        """Return the list of matching headers for a message of the mailing list
        with the given posting address, in the same format as header_match.
        Four tests are tried in turn and the result of the first one that finds
        anything is returned: the destination headers contain the address,
        List-Post starts with <mailto:address>, List-Post consists of the
        address, List-Id contains <address> with the '@' replaced by a dot.
        """
        qaddress = re.escape (address)
        found = self.destination_address (address)
        if found:
            return found
        found = self.header_start ('list-post', '<mailto:' + qaddress + '>', re.I|re.S)
        if found:
            return found
        found = self.header_start ('list-post', qaddress + r'\s*$', re.I|re.S)
        if found:
            return found
        list_id = qaddress.translate (_atsign_to_dot_table)
        return self.header_match ('list-id', '.*<' + list_id + '>', re.I|re.S)

    def newsgroup_match (self, newsgroup):
        """Return the list of matching headers for a message posted to the given
        Usenet newsgroup, in the same format as header_match.  The name is
        looked for, case-insensitively, as one element of the Newsgroups header.
        """
        quoted = re.escape (newsgroup)
        pattern = '(?:.*[ \t,])?(' + quoted + ')(?:[ ,].*)?$'
        return self.header_match ('newsgroups', pattern, re.I|re.S)

    def sender_match (self, pcontent, flags = re.S):
        """Like header_match, but instead of a tag pattern given by the caller
        the fixed pattern for headers naming the message source is used.
        """
        content_re = re.compile (pcontent, flags)
        return self._header_match_common (_re_sender_tag, content_re.match)

    def add_header_at (self, tag, content, index):
        """Add a header at the specified index in the header array."""
        logger = logging.getLogger ('sortmail')
        logger.info ('adding %s header at %d', tag, index)
        self._headers [index : index] = [(tag, content)]
        self._size += len(tag) + 2 # colon, space
        self._size += len(content)
        return self

    def append_header (self, tag, content):
        """Append a header at the end of the header array."""
        logger = logging.getLogger ('sortmail')
        logger.info ('appending %s header', tag)
        self._headers.append ((tag, content))
        self._size += len(tag) + 2 # colon, space
        self._size += len(content)
        return self

    def append_header_if_absent (self, tag, content):
        """Append a header at the end of the header array
        unless a header with the same tag is already present.
        The tags are compared as literal strings, ignoring ASCII case.
        Return self.
        """
        lctag = tag.translate (_ascii_lower_table)
        # header_match treats its argument as a regexp anchored at the start
        # only; quote the tag and anchor the end, or a header like
        # X-Spam-Status would pass for X-Spam
        matches = self.header_match_no_cxform (re.escape (lctag) + '$')
        if not matches:
            self.append_header (tag, content)
        return self

    def add_header_before (self, newtag, content, tag, index):
        """Insert the header (newtag, content) immediately before the occurrence
        number index (counting from 0) of a header matching tag.  tag is a
        regexp as for header_match.  Nothing is done if there are not that
        many matching headers.  Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        if index < len (occurrences):
            position = occurrences [index] [0]
            self.add_header_at (newtag, content, position)
        return self

    def add_header_after (self, newtag, content, tag, index):
        """Insert the header (newtag, content) immediately after the occurrence
        number index (counting from 0) of a header matching tag.  tag is a
        regexp as for header_match.  Nothing is done if there are not that
        many matching headers.  Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        if index < len (occurrences):
            position = occurrences [index] [0]
            self.add_header_at (newtag, content, position + 1)
        return self

    def delete_header_at (self, index):
        """Delete a header at the specified index in the header array."""
        logger = logging.getLogger ('sortmail')
        logger.info ('deleting header at %d', index)
        tag, content = self._headers[index]
        self._headers [index : index + 1] = []
        self._size -= len(tag) + 2
        self._size -= len(content)
        return self

    def delete_header_tag (self, tag, index):
        """Delete occurrence number index (counting from 0) of a header matching
        tag, a regexp as for header_match.  Nothing is done if there are not
        that many matching headers.  Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        if index < len (occurrences):
            position = occurrences [index] [0]
            self.delete_header_at (position)
        return self

    def delete_header_tag_all (self, tag):
        """Delete every header matching tag, a regexp as for header_match.
        Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        # every deletion moves the headers behind it up by one place
        for removed, (position, _) in enumerate (occurrences):
            self.delete_header_at (position - removed)
        return self

    def replace_header_at (self, tag, content, index):
        """Replace a header at the specified index in the header array."""
        logger = logging.getLogger ('sortmail')
        logger.info ('replacing header at %d with %s', index, tag)
        otag, ocontent = self._headers[index]
        self._headers [index] = (tag, content)
        self._size += len(tag) - len(otag)
        self._size += len(content) - len(ocontent)
        return self

    def replace_header_tag (self, newtag, content, tag, index):
        """Replace occurrence number index (counting from 0) of a header matching
        tag, a regexp as for header_match, by the header (newtag, content).
        Nothing is done if there are not that many matching headers.
        Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        if index < len (occurrences):
            position = occurrences [index] [0]
            self.replace_header_at (newtag, content, position)
        return self

    def transform_header_at (self, xform, index):
        """Transform a header at the specified index in the header array,
        by calling the function-like object xform with the (tag, content)
        pair as an argument, and replacing the header with the result.
        """
        logger = logging.getLogger ('sortmail')
        logger.info ('transforming header at %d', index)
        otag, ocontent = self._headers[index]
        tag, content = xform((otag, ocontent))
        self._headers [index] = (tag, content)
        self._size += len(tag) - len(otag)
        self._size += len(content) - len(ocontent)
        return self

    def transform_header_tag (self, xform, tag, index):
        """Apply transform_header_at with xform to occurrence number index
        (counting from 0) of a header matching tag, a regexp as for
        header_match.  Nothing is done if there are not that many matching
        headers.  Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        if index < len (occurrences):
            position = occurrences [index] [0]
            self.transform_header_at (xform, position)
        return self

    def rename_header_at (self, newtag, index):
        """Rename a header at the specified index in the header array."""
        logger = logging.getLogger ('sortmail')
        logger.info ('renaming header at %d to %s', index, newtag)
        (tag, content) = self._headers [index]
        self._headers [index] = (newtag, content)
        self._size += len(newtag) - len(tag)
        return self

    def rename_header_tag (self, newtag, tag, index):
        """Rename to newtag occurrence number index (counting from 0) of a header
        matching tag, a regexp as for header_match.  Nothing is done if there
        are not that many matching headers.  Return self.
        """
        occurrences = self.header_match_no_cxform (tag)
        if index < len (occurrences):
            position = occurrences [index] [0]
            self.rename_header_at (newtag, position)
        return self

    def rename_header_tag_all (self, newtag, tag):
        """Rename to newtag every header matching tag, a regexp as for
        header_match.  Return self.
        """
        for hit in self.header_match_no_cxform (tag):
            (position, _) = hit
            self.rename_header_at (newtag, position)
        return self

    def append_header_and_rename (self, tag, content):
        """Append a header at the end of the header array,
        renaming all existing occurrences of the tag by adding the prefix
        'X-Original-'.  The tags are compared as literal strings, ignoring
        ASCII case.  Return self.
        """
        lctag = tag.translate (_ascii_lower_table)
        # rename_header_tag_all takes a regexp anchored at the start only;
        # quote the tag and anchor the end so that only headers with exactly
        # this tag are renamed (not e.g. Topic when the tag is To)
        self.rename_header_tag_all ('X-Original-' + tag, re.escape (lctag) + '$')
        self.append_header (tag, content)
        return self

    def uniquify_tag_first (self, tag):
        """Keep only the first of the headers matching tag, a regexp as for
        header_match, by deleting all later ones.  Return self.
        """
        later = self.header_match_no_cxform (tag) [1 : ]
        # every deletion moves the headers behind it up by one place
        for removed, (position, _) in enumerate (later):
            self.delete_header_at (position - removed)
        return self

    def uniquify_tag_last (self, tag):
        """Keep only the last of the headers matching tag, a regexp as for
        header_match, by deleting all earlier ones.  Return self.
        """
        earlier = self.header_match_no_cxform (tag) [ : -1]
        # every deletion moves the headers behind it up by one place
        for removed, (position, _) in enumerate (earlier):
            self.delete_header_at (position - removed)
        return self

    def received_ip_generate (self):
        """A generator function yielding all IPv4 addresses found in
        Received headers, in square brackets or in parentheses.  Each
        address is yielded as the result of mapping the four components
        of the dotted quad (as strings) out of the match.  Useful for
        passing them to RBL checking code.
        """
        # the tag pattern is anchored at the end as well, so that headers
        # whose tag merely starts with Received (such as Received-SPF)
        # are not scanned
        for (_, (_, content_match)) in self.header_match ('received$'):
            content_string = content_match.group (0)
            for match in _re_received_ip.finditer (content_string):
                yield map (match.group, range (2, 6))

    def filtermsg (self, argv):
        """Transform the message by passing it on standard input to
        an external program given by the argument list argv.  The message
        is reconstructed from the standard output of the program, including
        the envelope sender if the output starts with a From line.
        An error is raised if the message cannot be written to the program,
        or if the program exits with a nonzero status or is killed by a
        signal other than SIGPIPE.  Return self.
        """
        logger = logging.getLogger ('sortmail')
        logger.info ('(filter) %s', ' '.join (argv))
        # The output of the program is collected in a temporary file rather
        # than read from a pipe: the whole message is written before any
        # output is read, so with a pipe a program that writes while it
        # reads would block on a full pipe, and so would we.
        with closing (tempfile.TemporaryFile ()) as output:
            pipe = subprocess.Popen (argv, stdin = PIPE, stdout = output)
            try:
                with closing (pipe.stdin):
                    self._print_to_fh (pipe.stdin)
            except:
                logger.error ('cannot filter with %s', ' '.join (argv))
                raise
            status = pipe.wait ()
            output.seek (0)
            from_line, self._headers, self._body, self._tf, self._size =\
                _read_message_from_fh (output)
            if from_line is not None:
                self._envelope_from = from_line
        if status > 0 or (status < 0 and -status != signal.SIGPIPE):
            logger.error ('filter subprocess exited with status %d', status)
            # NOTE(review): ProcessError is not defined in the part of the
            # module visible here - confirm that it exists
            raise ProcessError (argv, status)
        return self

    def deliver (self, dests, dbfile = None, id_headers = _id_headers,
                 clean_period = _ONE_DAY, ttl = _ONE_WEEK):
        """Deliver the message to each destination in the argument list.  Each
        destination is an instance of MBox, MPipe or MForward class, or
        other object providing the same interface.  If the the dbfile
        optional argument is present, it is used as a filename of a
        shelf file used for eliminating duplicate messages.

        The clean_period and ttl arguments are now ignored.
        """
        logger = logging.getLogger ('sortmail')
        logger.debug ('(deliver) %d destinations', len (dests))
        if dbfile is None:
            for d in dests:
                d.accept(self)
        else:
            key = ''
            for h in id_headers:
                matches = self.header_match ('^' + h + '$')
                if matches == []:
                    continue
                (_, (mt, mc)) = matches [0]
                key += mt.group (0) + ': ' + mc.group (0)
            with unlocking(dbfile):
                with closing(shelve.open(dbfile)) as db:
                    logger.debug ('(dedupe) probing ID')
                    shelved = False
                    if key in db:
                        record = db[key]
                        v = record['version']
                        if _LAST_COMPAT <= v and v <= _MY_VERSION:
                            logger.info ('(dedupe) detected duplicate ID')
                            logger.debug('(dedupe) ' + str(record))
                            shelved = True
                    if not shelved:
                        for d in dests:
                            d.accept(self)
                        record = {'version': _MY_VERSION, 'time': time.time(),
                                  'host': gethostname()}
                        logger.debug ('(dedupe) recording new ID')
                        db[key] = record
        logger.debug ('(deliver) success')
        return self
