# Copyright (c) 2012 Leif Johnson <leif@leifjohnson.net>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import random
import time

import features
import storage

# Boundary markers: BEGIN pads the front of every word and tag sequence, END
# pads the back of the word sequence and is the tag forced at the final
# decoding step.
BEGIN, END = '<', '>'


class Discriminative(object):
    '''A tagger determines the best labels for a sequence of observations.

    This class implements a parametric, supervised model that learns weights
    for tagging decisions based on labeled training data. The model is
    described in Michael Collins' 2002 EMNLP paper, "Discriminative Training
    Methods for Hidden Markov Models: Theory and Experiments with Perceptron
    Algorithms".

    Two sets of weights are kept: self.edge, which train() updates and decodes
    with, and self.weights, which label() decodes with.
    '''

    def __init__(self, context, beam, tags, create_storage, extract_features=None):
        '''Initialize this tagger with some model parameters.

        context: Extract features using a window of up to this far from the
          central word being tagged.
        beam: If None, decode using the full Viterbi algorithm. If this is a
          positive integer, use an approximating beam search with this width.
        tags: A set of tags to use when decoding.
        create_storage: A callable that takes one string argument (the name of
          the set of weights to be stored) and returns a storage.Weights object.
        extract_features: A callable that takes a complete word sequence, a tag
          sequence being constructed, and a context width -- and generates a
          sequence of string features. Defaults to features.defaults.
        '''
        self.context = context
        self.beam = beam
        self.tags = tags
        self.extract = extract_features or features.defaults

        self.edge = create_storage('edge')
        self.weights = create_storage('weights')

        # make sure there's an iteration counter for this model.
        self.weights.update([(storage.ITERATION, 0)])

    def _pad(self, words):
        '''Return (begin, padded) for the given words.

        begin is a fresh list of self.context BEGIN markers ; padded is begin,
        followed by the words, followed by 1 + self.context END markers.
        '''
        begin = [BEGIN] * self.context
        padded = begin + list(words) + [END] * (1 + self.context)
        return begin, padded

    def beam_decode(self, words, weights):
        '''Generate a tag hypothesis for the given sequence of words.

        words: The sequence of words to tag.
        weights: An object whose sum() method scores a counter of features.

        This decoder uses a beam search, which is a fast approximation to the
        full Viterbi decoder. A beam search maintains a "beam" (an array) of tag
        hypotheses as it decodes each word ; at each decoding step, each
        hypothesis in the beam is expanded with all possible tags, and then the
        resulting array of expanded hypotheses is sorted by decreasing score and
        then chopped off at the beam width.

        Returns a (score, path, features) tuple for the best hypothesis. The
        path includes the leading BEGIN padding and one trailing END tag.
        '''
        words = list(words)
        begin, padded = self._pad(words)
        # The step just past the last word is always tagged END. Detecting it
        # by position keeps a literal '>' token from being taken for the end.
        last = self.context + len(words)
        # The beam is a list of (score, path, features) hypotheses.
        hypotheses = [(0, begin, features.Counter())]
        for index in range(self.context, last + 1):
            tags = [END] if index == last else self.tags
            expanded = []
            for score, path, feats in hypotheses:
                for tag in tags:
                    candidate = path + [tag]
                    fs = features.Counter()
                    fs.count(self.extract(padded, candidate, self.context))
                    # A zero score is replaced by a random value.
                    s = weights.sum(fs) or random.random()
                    expanded.append((score + s, candidate, feats + fs))
            # Rank on score, then path ; never compare the feature counters.
            expanded.sort(key=lambda h: (h[0], h[1]), reverse=True)
            hypotheses = expanded[:self.beam]
        return hypotheses[0]

    def viterbi_decode(self, words, weights):
        '''Decode a sequence of words using the Viterbi algorithm.

        words: The sequence of words to tag.
        weights: An object whose sum() method scores a counter of features.

        This is often slower than the beam decoder (depending on the beam width),
        but yields the optimal tag sequence, under the assumption that tags are
        chosen from left to right and follow the Markov assumption.

        Returns a (score, path, features) tuple for the best path ending in
        END. The path includes the leading BEGIN padding and the final END tag.
        '''
        words = list(words)
        begin, padded = self._pad(words)
        # See beam_decode: the final step is identified by position.
        last = self.context + len(words)
        # Each table is keyed by the most recent tag of the best path so far.
        scores, paths, feats = {BEGIN: 0}, {BEGIN: begin}, {BEGIN: features.Counter()}
        for index in range(self.context, last + 1):
            scores_, paths_, feats_ = {}, {}, {}
            tags = [END] if index == last else self.tags
            for tag in tags:
                scores_[tag], paths_[tag], feats_[tag] = -1e100, None, None
                # items() rather than iteritems() so this runs on Python 2 and 3.
                for source, path in paths.items():
                    candidate = path + [tag]
                    fs = features.Counter()
                    fs.count(self.extract(padded, candidate, self.context))
                    s = scores[source] + (weights.sum(fs) or random.random())
                    if s > scores_[tag]:
                        scores_[tag] = s
                        paths_[tag] = candidate
                        feats_[tag] = feats[source] + fs
            scores, paths, feats = scores_, paths_, feats_
        return scores[END], paths[END], feats[END]

    def decode(self, words, weights):
        '''Send a sequence of words to the appropriate decoder.

        Uses viterbi_decode when self.beam is None and beam_decode otherwise,
        and logs the elapsed decoding time at debug level.

        Returns the (score, path, features) tuple from the chosen decoder.
        '''
        words = list(words)
        start = time.time()
        try:
            if self.beam is None:
                return self.viterbi_decode(words, weights)
            return self.beam_decode(words, weights)
        finally:
            # Format each token individually so that a non-string token cannot
            # raise here and replace the decoder's result.
            logging.debug('decoded %s in %dms',
                          '-'.join('%s' % (w,) for w in words),
                          1000 * (time.time() - start))

    def label(self, words):
        '''Get the best tag sequence for the given word sequence.

        Decodes with self.weights and strips the BEGIN padding and the final
        END tag from the resulting path.
        '''
        _, best, _ = self.decode(words, self.weights)
        return best[self.context:-1]

    def train(self, words, tags):
        '''Train our model on a set of words and a correct set of tags.

        This method uses the current edge weights to generate a tag hypothesis.
        If the hypothesis differs from the true tags, the edge weights are
        updated with the feature counts of the true tag sequence minus those of
        the hypothesis.

        Returns the best tagging hypothesis for the given words.
        '''
        words = list(words)
        tags = list(tags)
        _, best, feats = self.decode(words, self.edge)
        best = best[self.context:-1]
        if not all(h == t for h, t in zip(best, tags)):
            begin, padded = self._pad(words)
            # Grow the true path one tag at a time, counting features at each
            # step in the same way the decoders do.
            gold = list(begin)
            true_feats = features.Counter()
            for t in tags + [END]:
                gold.append(t)
                true_feats.count(self.extract(padded, gold, self.context))
            self.edge.update(true_feats - feats)
        return best
