#!/usr/bin/env python
# -*- python -*-
from __future__ import division

__version__ = "$Revision: 1.5 $"

# Copyright 2006-2007 Michael M. Hoffman <hoffman+software@ebi.ac.uk>

import itertools
import sys

import tabdelim
from textinput import open

def get_key(row, col_indexes):
    """
    Build the join key for row.

    Returns a tuple holding the value of row at each position listed
    in col_indexes, in the order the positions are given.  An
    IndexError from indexing row (for example, a row that is too
    short) propagates to the caller.
    """
    key_fields = [row[position] for position in col_indexes]
    return tuple(key_fields)

def read_lookup(iterable, col_indexes):
    """
    Load a lookup table into a dict keyed on its join columns.

    iterable yields rows (indexable sequences of fields) and
    col_indexes holds the zero-based positions of the join columns.

    Returns a dict that maps each row's join key, as built by get_key,
    to a list of that row's remaining fields in their original order.
    Rows for which get_key raises IndexError (too short to contain
    every join column) are skipped.  When several rows share a key,
    the last one read replaces the earlier ones.
    """
    # build the set once so the per-field membership test below is O(1)
    key_positions = set(col_indexes)
    res = {}

    for row in iterable:
        try:
            key = get_key(row, col_indexes)
        except IndexError:
            # row lacks one of the join columns; leave it out of the lookup
            continue

        # exclude the joined columns
        res[key] = [field for position, field in enumerate(row)
                    if position not in key_positions]

    return res

def write_joined(readerwriter, col_indexes, lookup):
    """
    Append the matching lookup fields to each row of readerwriter.

    readerwriter yields mutable rows, which are extended in place;
    presumably the reader/writer emits each row after it has been
    modified (TODO confirm against tabdelim.ListReaderWriter).
    col_indexes holds the zero-based join columns of those rows and
    lookup is a dict as returned by read_lookup.

    Reads the module-level global options:

    * options.empty is None: rows with a match get the lookup fields
      appended; rows without a match are left untouched.
    * otherwise: every row is extended by the same number of fields
      (the length of the longest lookup row).  Short lookup rows are
      padded with options.empty, and rows without a match receive
      only padding.  Padding is applied to the lists stored in lookup,
      so lookup itself is modified.

    An IndexError raised by get_key for a short row is not caught here.
    """
    # duplication with two loops to save time
    # use None because options.empty could be ""
    if options.empty is None:
        for row in readerwriter:
            try:
                lookup_row = lookup[get_key(row, col_indexes)]
            except KeyError:
                continue
            row.extend(lookup_row)
    else:
        empty_text = options.empty

        # an empty lookup table contributes zero columns; calling max()
        # on an empty sequence would raise ValueError
        row_lengths = [len(lookup[key]) for key in lookup]
        len_longest_row = max(row_lengths or [0])

        # the padding for an unmatched row never changes, so build it once
        missing_row = [empty_text] * len_longest_row

        for row in readerwriter:
            try:
                lookup_row = lookup[get_key(row, col_indexes)]
            except KeyError:
                row.extend(missing_row)
                continue

            len_difference = len_longest_row - len(lookup_row)
            if len_difference > 0:
                # pad in place so later hits on this key are already
                # full width
                lookup_row.extend([empty_text] * len_difference)
            row.extend(lookup_row)

def parse_fields(text):
    """
    Convert a comma-separated list of one-based field numbers into a
    list of zero-based column indexes, sorted in ascending order.

    >>> parse_fields("37")
    [36]
    >>> parse_fields("40,37")
    [36, 39]
    """
    indexes = [int(field_number) - 1 for field_number in text.split(",")]
    indexes.sort()
    return indexes

def parse_options(args):
    """
    Parse the command-line arguments in args.

    Stores the parsed option values in the module-level global options
    and returns the list of positional arguments left over after
    option parsing.
    """
    from optparse import OptionParser

    global options

    parser = OptionParser(usage="%prog [OPTION]... MAIN-TABLE LOOKUP-TABLE",
                          version="%%prog %s" % __version__)

    # no default is given, so options.empty is None unless -e is supplied
    parser.add_option("-e", "--empty", metavar="EMPTY",
                      help="replace missing input fields with EMPTY")

    # join columns are given as one-based field numbers; both default to
    # the first field
    parser.add_option("-1", "--main-fields", default="1",
                      metavar="FIELDS",
                      help="join on these FIELDS of MAIN-TABLE")
    parser.add_option("-2", "--lookup-fields", default="1",
                      metavar="FIELDS",
                      help="join on these FIELDS of LOOKUP-TABLE")

    options, positional_args = parser.parse_args(args)

    return positional_args

def innerjoin(main_filename, lookup_filename):
    """
    Join the table in main_filename against the one in lookup_filename.

    The lookup table is read completely into a dict first; the main
    table is then passed through write_joined one row at a time.  The
    join columns come from options.lookup_fields and
    options.main_fields.  Note that open is the function imported from
    textinput, not the builtin.
    """
    # load the whole lookup table before touching the main table
    lookup_table = read_lookup(tabdelim.ListReader(open(lookup_filename)),
                               parse_fields(options.lookup_fields))

    write_joined(tabdelim.ListReaderWriter(open(main_filename)),
                 parse_fields(options.main_fields),
                 lookup_table)

def main(args):
    """
    Command-line entry point.

    args is the argument list without the program name.  After option
    parsing exactly two positional arguments must remain: MAIN-TABLE
    and LOOKUP-TABLE.  Otherwise a message is written to stderr and
    SystemExit is raised with status 2, the same status optparse uses
    for usage errors.

    Returns the result of innerjoin.
    """
    args = parse_options(args)

    # explicit check rather than assert: asserts are stripped under -O,
    # which would let a wrong argument count reach innerjoin
    if len(args) != 2:
        sys.stderr.write("error: expected MAIN-TABLE and LOOKUP-TABLE, "
                         "got %d argument(s)\n" % len(args))
        sys.exit(2)

    return innerjoin(*args)

def _test(*args, **kwargs):
    import doctest
    doctest.testmod(sys.modules[__name__], *args, **kwargs)

if __name__ == "__main__":
    # run the doctests on every start-up unless Python was started with -O
    if __debug__:
        _test()

    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
