#!/usr/bin/env python
# coding: utf-8

"""
Converts a phrasetable to a csv file

Usage:
    yalign-phrasetable-csv [options] <input_file> <output_csv_file>

Options:
  -h --help        Show this screen.
"""

import os
import csv
import gzip
import codecs
import tempfile
import subprocess
from docopt import docopt


def _open_phrasetable(filepath):
    """
    Opens a phrasetable file using the
    necesary method, either gzip or plain open.
    """

    if filepath.endswith(".gz"):
        return gzip.open(filepath)
    else:
        return codecs.open(filepath, "r", encoding="utf-8")


def _pre_filter(filepath):
    """
    Uses egrep command to pre-filter things
    """

    _, output_filepath = tempfile.mkstemp()
    _, error_filepath = tempfile.mkstemp()
    outfile = open(output_filepath, "w")
    errfile = open(error_filepath, "w")

    if filepath.endswith(".gz"):
        read_cmdline = "gzip -cd {filepath}"
    else:
        read_cmdline = "cat {filepath}"
    read_cmdline = read_cmdline.format(filepath=os.path.abspath(filepath))
    filter_cmdline = "grep -E '^\S+\s\|\|\|\s\S+\s\|\|\|'"
    cmdline = "{read} | {filter}".format(read=read_cmdline,
                                         filter=filter_cmdline)
    status = subprocess.call(cmdline,
                             stdout=outfile,
                             stderr=errfile,
                             shell=True)

    if status != 0:
        message = "Error precompiling: program returned {}.\n" \
                  "Check error output at: {}"
        raise Exception(message.format(status, error_filepath))
    return output_filepath


def filter_phrasetable(in_filepath):
    """
    Given a phrasetable file returns an iterator over the
    important entries on the table, as (src, tgt, prob) tuples.

    src and tgt are the stripped, lowercased first and second fields of
    the entry and prob is the average of its inverse and direct
    probabilities.

    The unnecessary entries are the ones with:
        * more than a 1-gram
        * low word count
        * low combined probability

    The pre-filtered temporary copy of the table is deleted once the
    iterator is exhausted or closed.

    Raises IndexError or ValueError while iterating if a line lacks the
    expected fields or holds non-numeric counts or probabilities.
    """

    count_limit = 1
    prob_limit = 0.001

    filtered_filepath = _pre_filter(in_filepath)
    try:
        with _open_phrasetable(filtered_filepath) as filehandler:
            for line in filehandler:
                fields = line.split("|||")
                src = fields[0].strip().lower()
                tgt = fields[1].strip().lower()

                if len(src.split()) > 1 or len(tgt.split()) > 1:
                    continue

                counts = fields[4].split()
                tgt_count = int(float(counts[0]))
                src_count = int(float(counts[1]))
                if src_count < count_limit or tgt_count < count_limit:
                    continue

                probs = fields[2].split()
                inverse_prob = float(probs[0])
                direct_prob = float(probs[2])
                # The resulting prob is a combination of both direct
                # and inverse probability
                prob = 0.5 * inverse_prob + 0.5 * direct_prob
                if prob < prob_limit:
                    continue

                yield src, tgt, prob
    finally:
        # Best-effort cleanup of the temporary pre-filtered table; a
        # failure to delete it must not mask an error raised above.
        try:
            os.remove(filtered_filepath)
        except OSError:
            pass


def save_csv_file(translations, outfile):
    """
    Given a iterator that returns (src, tgt, prob) it saves it
    into a csv file, one row per entry, encoded as UTF-8.

    src and tgt are expected to be text (unicode) strings. The file at
    outfile is created or overwritten, and closed before returning even
    if the iterator raises half way through.
    """

    # The csv module wants byte strings and a binary file on Python 2,
    # but text strings and a text file opened with newline="" on
    # Python 3. "str is bytes" is only true on Python 2.
    python2 = str is bytes
    if python2:
        handler = open(outfile, "wb")
    else:
        handler = open(outfile, "w", encoding="utf-8", newline="")

    with handler:
        writer = csv.writer(handler)
        for src, tgt, prob in translations:
            if python2:
                src, tgt = src.encode("utf-8"), tgt.encode("utf-8")
            writer.writerow([src, tgt, prob])


if __name__ == "__main__":
    # Parse the command line against the usage text in the module
    # docstring.
    args = docopt(__doc__)

    input_filepath = args["<input_file>"]
    output_filepath = args["<output_csv_file>"]

    try:
        # The filter is lazy: the table is read while the csv is written.
        translations_iterator = filter_phrasetable(input_filepath)
        save_csv_file(translations_iterator, output_filepath)
    except Exception as error:
        # Top-level boundary: report any failure as a one-line message
        # on stderr with a non-zero exit status. SystemExit is raised
        # directly because the exit() builtin is injected by the site
        # module for interactive use and is not guaranteed to exist.
        raise SystemExit("Error: {}".format(error))
