#!/usr/bin/env python

# Copyright (c) 2012 Leif Johnson <leif@leifjohnson.net>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import bz2
import gzip
import logging
import multiprocessing as mp
import optparse
import re
import sys

import lmj.tagger

# Command-line interface. Defaults are repeated in parentheses inside each
# help string so --help shows them.
FLAGS = optparse.OptionParser()

FLAGS.add_option('-b', '--beam', type=int, metavar='N',
                 help='decode sequences using a beam search of width N (None)')
FLAGS.add_option('-c', '--context', type=int, default=3, metavar='N',
                 help='extract features using a context window of N words (3)')
FLAGS.add_option('-d', '--db', default='tagger', metavar='PREFIX',
                 help='store weights in sqlite3 files named PREFIX-*.db (tagger)')
FLAGS.add_option('-e', '--encoding', default='utf8', metavar='E',
                 help='read and write text using encoding E (utf8)')
FLAGS.add_option('-p', '--concurrency', type=int, metavar='N',
                 help='run concurrently with N processes (None)')
FLAGS.add_option('-t', '--tags', default='tags.txt', metavar='FILE',
                 help='read the tag dictionary from FILE (tags.txt)')
FLAGS.add_option('-v', '--verbose', action='store_true',
                 help='print out more debugging information')

# Mode flags: main() expects exactly one of --train / --test / --label;
# if several are given, the last one checked (--label) wins.
g = optparse.OptionGroup(FLAGS, 'Modes')
g.add_option('', '--train', metavar='FILE',
             help='train on labeled sequences from FILE (-)')
g.add_option('', '--test', metavar='FILE',
             help='test on labeled sequences from FILE')
g.add_option('', '--label', metavar='FILE',
             help='print labels for sequences from FILE')
FLAGS.add_option_group(g)


def parse_labeled_lines(handle, encoding='utf8'):
    '''Yield (words, tags) pairs parsed from lines of word/tag tokens.

    handle: an iterable of encoded byte strings, one sequence per line.
    encoding: the codec used to decode each line (utf8).

    Each whitespace-separated token is split on its *last* slash, so a
    word may itself contain slashes (e.g. "1/2/CD" yields word "1/2"
    with tag "CD") -- a plain split('/') would crash on such tokens.
    '''
    for line in handle:
        pairs = [token.rsplit('/', 1)
                 for token in line.decode(encoding).strip().split()]
        yield [w for w, _ in pairs], [t for _, t in pairs]


def parse_unlabeled_lines(handle, encoding='utf8'):
    '''Decode each line from handle and yield its whitespace-split words.

    handle: an iterable of encoded byte strings, one sequence per line.
    encoding: the codec used to decode each line (utf8).
    '''
    for raw in handle:
        text = raw.decode(encoding)
        yield text.strip().split()


def open_resource(name):
    '''Return a readable file handle for name.

    "-" means standard input; names ending in .gz or .bz2 are opened
    through the matching decompressor; anything else is opened plainly.
    '''
    if name == '-':
        return sys.stdin
    decompressors = (('.gz', gzip.open), ('.bz2', bz2.BZ2File))
    for suffix, opener in decompressors:
        if name.endswith(suffix):
            return opener(name)
    return open(name)


def create_tagger(opts):
    '''Build a Discriminative tagger configured from command-line opts.

    Weights are stored in sqlite3 databases named "<opts.db>-<table>.db";
    the tag inventory is the sorted set of whitespace-separated tokens
    found in the file named by opts.tags.
    '''
    def create_storage(table):
        path = '%s-%s.db' % (opts.db, table)
        return lmj.tagger.storage.Sqlite(path, encoding=opts.encoding)

    # read the tag dictionary, closing the handle (the original leaked it).
    with open(opts.tags) as handle:
        tags = sorted(set(filter(None, handle.read().split())))

    return lmj.tagger.Discriminative(
        opts.context, opts.beam, tags, create_storage)


def neps(x):
    '''True unless x is the epsilon placeholder tag "_".'''
    return not x == '_'


def average(q, opts):
    '''Worker loop: average the tagger's weights on demand.

    Blocks on q; every non-None message triggers one averaging pass
    over the model weights, and None shuts the loop down.
    '''
    tagger = create_tagger(opts)
    while True:
        signal = q.get()
        if signal is None:
            break
        tagger.weights.average(tagger.edge)


def train(q, aq, opts):
    '''Worker loop: train the tagger on labeled sequences from q.

    Pulls (words, tags) pairs until None arrives. After each sequence
    it nudges the averaging process via aq and tracks per-sequence tag
    accuracy (epsilon "_" tags excluded), logging the mean of every
    100 sequences.
    '''
    tagger = create_tagger(opts)

    scores = []

    while True:
        item = q.get()
        if item is None:
            break

        words, tags = item
        best = tagger.train(words, tags)
        logging.debug('tags %s, best %s', '-'.join(tags), '-'.join(best))

        # signal to the averaging process to average the current weights,
        # but only when it is idle (empty queue).
        if aq.empty():
            aq.put(True)

        hits = 0.
        seen = 0.
        for guess, truth in zip(filter(neps, best), filter(neps, tags)):
            if guess == truth:
                hits += 1
            seen += 1
        scores.append(hits / (seen or 1))

        if len(scores) == 100:
            mean = sum(scores) / len(scores)
            logging.info('training accuracy: %.2f', 100. * mean)
            scores = []


def test(q, _, opts):
    '''Worker loop: evaluate the tagger on labeled sequences from q.

    Pulls (words, tags) pairs until None arrives, then logs a confusion
    matrix plus tag- and sequence-level accuracy. "all" counts exact tag
    matches; "phone" counts matches after stripping digits (so numeric
    markers on tags -- presumably stress digits on phones, judging by
    the names -- are ignored). Epsilon "_" tags are excluded from the
    accuracy counts but not from the confusion matrix.
    '''
    T = create_tagger(opts)

    # counts for individual tag matches.
    all_tag = 0
    phone_tag = 0
    total_tag = 0

    # counts for complete sequence matches.
    all_seq = 0
    phone_seq = 0
    total_seq = 0

    # confusion[true tag][predicted tag] counts.
    idxs = {t: i for i, t in enumerate(T.tags)}
    confusion = [[0] * len(idxs) for _ in idxs]

    def ndig(x):
        # strip digits so e.g. AH0 and AH1 compare equal at the "phone" level.
        return re.sub(r'\d+', '', x)

    while True:
        item = q.get()
        if item is None:
            break

        words, tags = item
        best = T.label(words)
        logging.debug('tags %s, best %s', '-'.join(tags), '-'.join(best))

        for b, t in zip(best, tags):
            confusion[idxs[t]][idxs[b]] += 1

        aseq = True
        pseq = True
        for b, t in zip(filter(neps, best), filter(neps, tags)):
            if b == t:
                all_tag += 1
            else:
                aseq = False
            if ndig(b) == ndig(t):
                phone_tag += 1
            else:
                pseq = False
            total_tag += 1
        all_seq += int(aseq)
        phone_seq += int(pseq)
        total_seq += 1

    logging.error('tags\n%s', ' '.join(T.tags))
    logging.error('confusion\n%s',
                  '\n'.join(' '.join(str(n) for n in cs) for cs in confusion))
    # guard the denominators so an empty test set logs zeros instead of
    # raising ZeroDivisionError (same idiom as the accuracy math in train).
    logging.error('sequence accuracy of %d: all %d - %.2f, phone %d - %.2f',
                  total_seq,
                  all_seq, 100. * all_seq / (total_seq or 1),
                  phone_seq, 100. * phone_seq / (total_seq or 1))
    logging.error('tag accuracy of %d: all %d - %.2f, phone %d - %.2f',
                  total_tag,
                  all_tag, 100. * all_tag / (total_tag or 1),
                  phone_tag, 100. * phone_tag / (total_tag or 1))


def label(q, _, opts):
    '''Worker loop: print predicted labels for word sequences from q.

    Pulls word lists until None arrives and writes one
    "words ==> labels" line to standard output per sequence.
    '''
    tagger = create_tagger(opts)
    while True:
        words = q.get()
        if words is None:
            break
        guesses = tagger.label(words)
        sys.stdout.write('%s ==> %s\n' % (' '.join(words), ' '.join(guesses)))


def main(opts, args):
    '''Dispatch to train/test/label according to the mode flags in opts.

    Sequences from the mode's input file are fed through a queue either
    to a single in-process consumer (the default) or to
    opts.concurrency worker processes. During training, a separate
    persistent process periodically averages the model weights.
    '''
    q = mp.Queue()
    aq = mp.Queue()
    averager = None

    read = None
    consume = None
    logname = None
    if opts.train is not None:
        # pass opts.encoding through: the original lambdas dropped it, so
        # the --encoding flag was silently ignored when reading input.
        read = lambda: parse_labeled_lines(
            open_resource(opts.train), opts.encoding)
        consume = train
        logname = 'train'

        # start up a persistent process to average model weights.
        averager = mp.Process(target=average, args=(aq, opts))
        averager.start()

    if opts.test is not None:
        read = lambda: parse_labeled_lines(
            open_resource(opts.test), opts.encoding)
        consume = test
        logname = 'test'

    if opts.label is not None:
        read = lambda: parse_unlabeled_lines(
            open_resource(opts.label), opts.encoding)
        consume = label
        logname = 'label'

    if consume is None:
        # fail fast with a usage message instead of a TypeError later.
        FLAGS.error('please specify one of --train, --test, or --label')

    logging.basicConfig(
        stream=open('%s-%s.log' % (opts.db, logname), 'a'),
        format='%(levelname).1s %(asctime)s %(message)s',
        level=opts.verbose and logging.DEBUG or logging.INFO)

    if opts.concurrency is None:
        # process data serially if no concurrency flag is given. :(
        for item in read():
            q.put(item)
        q.put(None)
        consume(q, aq, opts)

    else:
        # process data in parallel ! one sentinel None per worker so each
        # worker's loop terminates.
        workers = [mp.Process(target=consume, args=(q, aq, opts))
                   for _ in range(opts.concurrency)]
        for w in workers:
            w.start()
        for item in read():
            q.put(item)
        for w in workers:
            q.put(None)
        for w in workers:
            w.join()

    # shut down the weight-averaging process.
    aq.put(None)
    if averager is not None:
        averager.join()


if __name__ == '__main__':
    # parse command-line flags and dispatch: parse_args() returns
    # (options, args), matching main's two parameters.
    main(*FLAGS.parse_args())
