#!python
"""
VCF processing for wormtable. 

Implementation Note: We use bytes throughout the parsing process here for
a few reasons. Mostly, this is because it's much easier to deal with bytes
values within the C module, as we would otherwise have to convert Unicode
objects before getting at the string data. At the same time, it's probably
quite a bit more efficient to work with bytes directly, so we win both ways.
It's a bit tedious making sure that all the literals have a 'b' in front of
them, but worth the effort.
"""
from __future__ import print_function
from __future__ import division 

import os
import sys
import gzip
import shutil 
import argparse
import tempfile
import multiprocessing

import wormtable as wt

# VCF Fixed columns

# Names given to the wormtable columns that hold the fixed VCF fields.
CHROM_NAME = b"CHROM"
POS_NAME = b"POS"
ID_NAME = b"ID"
REF_NAME = b"REF"
ALT_NAME = b"ALT"
QUAL_NAME = b"QUAL"
FILTER_NAME = b"FILTER"
INFO_NAME = b"INFO"

# The position of a name in this list is used by VCFReader.rows as the index
# of the corresponding field in a split VCF data line, so the order matters.
# INFO_NAME is deliberately absent: INFO is expanded into one column per
# key rather than being stored as a single fixed column.
VCF_FIXED_COLUMNS = [CHROM_NAME, POS_NAME, ID_NAME, REF_NAME, ALT_NAME, 
        QUAL_NAME, FILTER_NAME]

# TODO put in proper descriptions
# For now each description simply repeats the column name.
CHROM_DESCRIPTION = b"CHROM"
POS_DESCRIPTION = b"POS"
ID_DESCRIPTION = b"ID"
REF_DESCRIPTION = b"REF"
ALT_DESCRIPTION = b"ALT"
QUAL_DESCRIPTION = b"QUAL"
FILTER_DESCRIPTION = b"FILTER"
INFO_DESCRIPTION = b"INFO"

# Special values in VCF
# A fixed field equal to this value is left as None in the generated row.
MISSING_VALUE = b"."

# Strings used in the header for identifiers
# ID, DESCRIPTION, NUMBER and TYPE are the keys looked up in a parsed
# ##INFO/##FORMAT line; the remaining names are the values of the Type key
# that VCFReader.add_column recognises.
ID = b"ID"
INFO = b"INFO"
DESCRIPTION = b"Description"
NUMBER = b"Number"
TYPE = b"Type"
INTEGER = b"Integer"
FLOAT = b"Float"
FLAG = b"Flag"
CHARACTER = b"Character"
STRING = b"String"

class VCFReader(object):
    """
    A class for reading VCF files.

    The input is opened in binary mode, so every value produced by this
    class is a bytes object. Files whose name ends in ".gz" are read
    through gzip.
    """
    def __init__(self, vcf_file):
        """
        Opens the VCF file at the specified path for reading. Progress
        monitoring and REF truncation are off until they are switched on
        with set_progress and set_truncate_REF.
        """
        self.__genotypes = []
        if vcf_file.endswith(".gz"):
            self.__input_file = gzip.open(vcf_file, "rb")
            # Progress is measured against the size of the file on disc,
            # so we track the position of the underlying compressed file
            # and not the position in the decompressed stream.
            self.__progress_file = self.__input_file.fileobj
        else:
            self.__input_file = open(vcf_file, "rb")
            self.__progress_file = self.__input_file
        statinfo = os.stat(vcf_file)
        self.__input_file_size = statinfo.st_size
        # With no progress monitor we never want to hit the update
        # condition in rows(), so use a very large interval.
        self.__progress_update_rows = 2**32
        self.__progress_monitor = None
        self.__truncate = False

    def set_progress(self, progress):
        """
        If progress is True turn on progress monitoring for this VCF reader.
        The monitor is updated every 100 rows, or every 10000 rows for
        input files larger than 2**30 bytes.
        """
        if progress:
            self.__progress_monitor = wt.ProgressMonitor(
                    self.__input_file_size, "bytes")
            # NOTE(review): the original carried a "FIXME This doesn't work"
            # here and assigned the large-file interval to an unused local
            # variable. That assignment is fixed below; confirm there is
            # no other problem the FIXME was referring to.
            self.__progress_update_rows = 100
            if self.__input_file_size > 2**30:
                self.__progress_update_rows = 10000

    def set_truncate_REF(self, truncate):
        """
        If true, truncate REF values that are more than 255 characters long
        to 254 characters followed by a '+'. This is a temporary workaround
        until more sophisticated truncation across all columns is
        implemented.
        """
        self.__truncate = truncate

    def update_progress(self):
        """
        Reads the position we are at in the underlying file and uses this to
        update the progress bar, if used.
        """
        if self.__progress_monitor is not None:
            t = self.__progress_file.tell()
            self.__progress_monitor.update(t)

    def finish_progress(self):
        """
        Finishes up the progress monitor, if in use.
        """
        if self.__progress_monitor is not None:
            self.update_progress()
            self.__progress_monitor.finish()

    def close(self):
        """
        Closes any open files on this Reader.
        """
        self.__input_file.close()
        # For uncompressed input this is the same file object as above;
        # closing an already closed file is harmless.
        self.__progress_file.close()

    def parse_version(self, s):
        """
        Parse the VCF version number from the specified string and store
        it in self._version. The version is taken to be whatever follows
        the single 'v' in the string; if the string does not split into
        exactly two parts around 'v' the version is set to -1.0.
        """
        self._version = -1.0
        tokens = s.split(b"v")
        if len(tokens) == 2:
            self._version = float(tokens[1])

    def parse_header_line(self, s):
        """
        Processes the specified header string to get the genotype labels.
        These are all the whitespace separated tokens after the first nine.
        """
        self.__genotypes = s.split()[9:]

    def add_column(self, table, prefix, line):
        """
        Adds a VCF column using the specified metadata line with the specified
        name prefix to the specified table.

        The text between the first '<' and the last '>' of the line is
        parsed as three comma separated key=value pairs followed by a
        final pair that takes all of the remaining text, so that commas
        within the final value are preserved. The ID, Number, Type and
        Description keys must be present. Raises ValueError if the Type
        is not recognised.
        """
        d = {}
        # Use the last '>' so that a '>' character within the quoted
        # description does not cut the definition short.
        s = line[line.find(b"<") + 1: line.rfind(b">")]
        for _ in range(3):
            k = s.find(b",")
            tokens = s[:k].split(b"=")
            s = s[k + 1:]
            d[tokens[0]] = tokens[1]
        tokens = s.split(b"=", 1)
        d[tokens[0]] = tokens[1]
        name = d[ID]
        description = d[DESCRIPTION].strip(b"\"")
        number = d[NUMBER]
        num_elements = wt.WT_VAR_1
        try:
            # If we can parse it into a number, do so. If this fails then use
            # a variable number of elements.
            num_elements = int(number)
        except ValueError:
            pass
        # We can also have negative num_elements to indicate variable column
        if num_elements < 0:
            num_elements = wt.WT_VAR_1
        st = d[TYPE]
        if st == INTEGER:
            element_type = wt.WT_INT
            element_size = 4
        elif st == FLOAT:
            element_type = wt.WT_FLOAT
            element_size = 4
        elif st == FLAG:
            element_type = wt.WT_INT
            element_size = 1
        elif st == CHARACTER:
            element_type = wt.WT_CHAR
            element_size = 1
        elif st == STRING:
            # Strings are always variable length, whatever Number says.
            num_elements = wt.WT_VAR_1
            element_type = wt.WT_CHAR
            element_size = 1
        else:
            raise ValueError("Unknown VCF type:", st)

        table.add_column(prefix + b"_" + name, description, element_type,
                element_size, num_elements)

    def generate_schema(self, table):
        """
        Reads the header from the VCF file and adds the corresponding
        columns to the specified table: the row id and fixed columns, one
        column for each ##INFO line and, for each genotype label, one
        column for each ##FORMAT line. Raises ValueError if the version
        parsed from the first line is less than 4.0.
        """
        f = self.__input_file
        s = f.readline()
        info_descriptions = []
        genotype_descriptions = []
        self.parse_version(s)
        if self._version < 4.0:
            raise ValueError("VCF versions < 4.0 not supported")
        while s.startswith(b"##"):
            # skip FILTER values
            if s.startswith(b"##INFO"):
                info_descriptions.append(s)
            elif s.startswith(b"##FORMAT"):
                genotype_descriptions.append(s)
            s = f.readline()
        self.parse_header_line(s)
        # Add the fixed columns
        table.add_id_column(5)
        table.add_char_column(CHROM_NAME, CHROM_DESCRIPTION)
        table.add_uint_column(POS_NAME, POS_DESCRIPTION, 5)
        table.add_char_column(ID_NAME, ID_DESCRIPTION)
        table.add_char_column(REF_NAME, REF_DESCRIPTION)
        table.add_char_column(ALT_NAME, ALT_DESCRIPTION)
        table.add_float_column(QUAL_NAME, QUAL_DESCRIPTION, 4)
        table.add_char_column(FILTER_NAME, FILTER_DESCRIPTION)
        for description in info_descriptions:
            self.add_column(table, INFO_NAME, description)
        for genotype in self.__genotypes:
            for description in genotype_descriptions:
                self.add_column(table, genotype, description)

    def read_header(self):
        """
        Reads the VCF header and skips to the first line of the content.
        """
        f = self.__input_file
        s = f.readline()
        while s.startswith(b"##"):
            s = f.readline()
        self.parse_header_line(s)

    def rows(self, table_columns):
        """
        Returns an iterator over the rows in this VCF file. The
        table_columns argument maps column names to their positions in
        the table. Each row is a list with one entry per table column,
        holding the encoded bytes value for that column or None where the
        VCF line provides no value.

        Blank lines are skipped, and lines with no FORMAT column (that is,
        files with no genotype data) are supported.
        """
        self.read_header()
        # First we construct the mappings from the various parts of the
        # VCF row to the corresponding column index in the wormtable
        num_columns = len(table_columns)
        # weed out the fixed columns that are not in the table
        fixed_columns = [(j, table_columns[name])
                for j, name in enumerate(VCF_FIXED_COLUMNS)
                if name in table_columns]
        info_prefix = INFO + b"_"
        info_columns = {}
        genotype_columns = [{} for g in self.__genotypes]
        for k, v in table_columns.items():
            # Column 0 is the row id, which has no counterpart in the VCF.
            if b"_" in k and v != 0:
                if k.startswith(info_prefix):
                    # Everything after the prefix is the INFO key, which
                    # may itself contain underscores.
                    info_columns[k[len(info_prefix):]] = v
                else:
                    # Genotype labels may contain underscores, so the
                    # FORMAT key is taken to be the part after the last
                    # one. FORMAT keys containing underscores therefore
                    # cannot be resolved here.
                    g, name = k.rsplit(b"_", 1)
                    index = self.__genotypes.index(g)
                    genotype_columns[index][name] = v
        # Index of REF within a VCF line.
        ref_index = 3
        # Now we are ready to process the file.
        num_rows = 0
        for s in self.__input_file:
            l = s.split()
            if len(l) == 0:
                # Tolerate blank lines, e.g. at the end of the file.
                continue
            row = [None] * num_columns
            # Read in the fixed columns
            for vcf_index, wt_index in fixed_columns:
                if l[vcf_index] != MISSING_VALUE:
                    row[wt_index] = l[vcf_index]
                    if vcf_index == ref_index and self.__truncate:
                        # truncate the REF column if necessary; this is a
                        # temporary workaround until more sophisticated
                        # truncation on a per column basis is implemented.
                        if len(l[vcf_index]) > 255:
                            row[wt_index] = l[vcf_index][:254] + b'+'
            # Now process the info columns.
            for mapping in l[7].split(b";"):
                # Split on the first '=' only so that a value containing
                # '=' is kept whole rather than mistaken for a flag.
                tokens = mapping.split(b"=", 1)
                name = tokens[0]
                if name in info_columns:
                    col = info_columns[name]
                    if len(tokens) == 2:
                        row[col] = tokens[1]
                    else:
                        # This is a Flag column.
                        row[col] = b"1"
            # Process the genotype columns, if the line has any; files
            # without genotype data stop after the INFO column.
            if len(l) > 8:
                fmt = l[8].split(b":")
                for j, genotype_values in enumerate(l[9:]):
                    tokens = genotype_values.split(b":")
                    if len(tokens) == len(fmt):
                        for k, key in enumerate(fmt):
                            if key in genotype_columns[j]:
                                row[genotype_columns[j][key]] = tokens[k]
                    elif len(tokens) > 1:
                        # We can treat a genotype value on its own as
                        # missing values. We can have skipped columns at
                        # the end though, which we should deal with
                        # properly. So, put in a loud complaint here and
                        # fix later.
                        print("PARSING CORNER CASE NOT HANDLED!!! FIXME!!!!")
            yield row
            num_rows += 1
            if num_rows % self.__progress_update_rows == 0:
                self.update_progress()
        self.finish_progress()

class VCFWriter(object):
    """
    Writes encoded VCF rows directly to a wormtable in the calling process.
    """
    def __init__(self, table):
        """
        Loads the table's metadata from its own metadata path and opens
        the table for writing.
        """
        table.read_metadata(table.get_metadata_path())
        table.open("w")
        self.__table = table

    def append(self, row):
        """
        Appends the specified encoded row to the table.
        """
        self.__table.append_encoded(row)

    def close(self):
        """
        Closes the underlying table.
        """
        self.__table.close()

    
def writer_process(table, row_queue, error_event):
    """
    Opens the specified table for writing and appends every record taken
    from the specified queue until a None record is received, at which
    point the table is closed. If anything goes wrong the exception is
    printed and error_event is set.
    """
    try:
        table.read_metadata(table.get_metadata_path())
        table.open("w")
        while True:
            record = row_queue.get()
            if record is None:
                # None marks the end of the stream.
                break
            table.append_encoded(record)
        table.close()
    except Exception as e:
        print("Exception occured", e)
        error_event.set()

class VCFWriterProxy(object):
    """
    Offers the same append/close interface as VCFWriter, but passes the
    rows through a queue to a subprocess running writer_process.
    """
    def __init__(self, runner):
        """
        Creates the queue and error event and starts the writer
        subprocess on the runner's table.
        """
        queue = multiprocessing.Queue(runner.get_queue_size())
        failed = multiprocessing.Event()
        worker = multiprocessing.Process(
                target=writer_process,
                args=(runner.get_table(), queue, failed))
        self.__program_runner = runner
        self.__queue = queue
        self.__error_event = failed
        self.__writer_process = worker
        worker.start()

    def append(self, row):
        """
        Puts the specified row on the queue. If the subprocess has
        signalled an error, the queue is discarded, the subprocess is
        joined and the error is reported through the runner instead.
        """
        if not self.__error_event.is_set():
            self.__queue.put(row)
            return
        self.__queue.cancel_join_thread()
        self.__queue.close()
        self.__queue = None
        self.__writer_process.join()
        self.__program_runner.error("Subprocess failed")

    def close(self):
        """
        Tells the subprocess that there are no more rows by queueing None,
        unless the queue has already been discarded.
        """
        if self.__queue is None:
            return
        self.__queue.put(None)



class ProgramRunner(object):
    """
    Class responsible for running the vcf2wt program.
    """
    def __init__(self, args):
        """
        Stores the settings from the parsed command line arguments. No
        files are touched until run is called.
        """
        self.__source = args.SOURCE
        self.__destination = args.DEST
        self.__cache_size = args.cache_size
        self.__force = args.force
        self.__generate_schema = args.generate_schema
        self.__progress = args.progress
        self.__schema = args.schema
        self.__truncate = args.truncate
        self.__queue_size = args.queue_size
        # Temporary directories and files to be removed by cleanup.
        self.__tmp_dirs = []
        self.__tmp_files = []
        self.__table = None
        self.__column_map = None
        self.__reader = None
        self.__writer = None

    def get_table(self):
        """
        Returns the table being written, or None if it has not been
        created yet.
        """
        return self.__table

    def get_queue_size(self):
        """
        Returns the queue size requested for interprocess communication.
        """
        return self.__queue_size

    def generate_schema(self):
        """
        Reads the header of the input VCF and generates a schema file.
        The schema is written to a temporary file, which then becomes the
        schema used by this runner. The temporary file and the temporary
        table directory are registered for removal by cleanup.
        """
        fd, schema_file = tempfile.mkstemp(suffix=".xml", prefix="vcf2wt_")
        self.__tmp_files.append(schema_file)
        os.close(fd)
        tmpdir = tempfile.mkdtemp(suffix=".wt", prefix="vcf2wt_")
        self.__tmp_dirs.append(tmpdir)
        table = wt.Table(tmpdir)
        reader = VCFReader(self.__source)
        try:
            reader.generate_schema(table)
            table.write_metadata(schema_file)
        finally:
            # Make sure the input file is closed even if the header
            # cannot be parsed.
            reader.close()
        self.__schema = schema_file

    def initialise_writer(self):
        """
        Initialises the writer object and starts the writer process if it
        is being used. A queue size of 0 means that rows are written
        directly from this process.
        """
        if self.__queue_size == 0:
            self.__writer = VCFWriter(self.__table)
        else:
            self.__writer = VCFWriterProxy(self)

    def create_table(self):
        """
        Creates the table and reads the column information for the VCF reader.
        """
        os.mkdir(self.__destination)
        self.__table = wt.Table(self.__destination)
        self.__table.read_metadata(self.__schema)
        self.__table.set_cache_size(self.__cache_size)
        self.__table.open("w")
        # The reader works in bytes, so the column names are encoded.
        self.__column_map = {}
        for c in self.__table.columns():
            self.__column_map[c.get_name().encode()] = c.get_position()
        # The table is reopened by whichever writer is used.
        self.__table.close()

    def write_table(self):
        """
        Writes the table, assuming that we have created a directory with
        a table ready for writing.
        """
        self.__reader = VCFReader(self.__source)
        self.__reader.set_progress(self.__progress)
        self.__reader.set_truncate_REF(self.__truncate)
        self.initialise_writer()
        for r in self.__reader.rows(self.__column_map):
            self.__writer.append(r)
        # Reset each reference once closed so that cleanup does not try
        # to close it a second time.
        self.__reader.close()
        self.__reader = None
        self.__writer.close()
        self.__writer = None

    def run(self):
        """
        Top level entry point. Generates a schema if none was supplied,
        deals with an existing destination (removing it if forced,
        otherwise reporting an error) and then either copies the schema
        to the destination or creates and writes the table.
        """
        if self.__schema is None:
            self.generate_schema()
        if os.path.exists(self.__destination):
            if self.__force:
                if os.path.isdir(self.__destination):
                    shutil.rmtree(self.__destination)
                else:
                    os.unlink(self.__destination)
            else:
                s = "'{0}' exists; use -f to overwrite".format(self.__destination)
                self.error(s)
        if self.__generate_schema:
            # copy the schema and we're done.
            shutil.copyfile(self.__schema, self.__destination)
        else:
            self.create_table()
            self.write_table()

    def error(self, s):
        """
        Prints the specified error message to stderr and exits with a
        non-zero status, so that callers of the program can detect the
        failure.
        """
        print("Error:", s, file=sys.stderr)
        sys.exit(1)

    def cleanup(self):
        """
        Cleans up any temporary files and closes the reader and writer if
        they are still open.
        """
        for f in self.__tmp_dirs:
            shutil.rmtree(f)
        for f in self.__tmp_files:
            os.unlink(f)
        if self.__reader is not None:
            self.__reader.close()
        if self.__writer is not None:
            self.__writer.close()


def main():
    """
    Command line entry point. Parses the arguments, runs the conversion
    and always cleans up afterwards, whether or not the run succeeded.
    """
    parser = argparse.ArgumentParser(
        description="Convert VCF file to Wormtable format.")
    parser.add_argument("SOURCE", help="VCF file to convert")
    parser.add_argument(
        "DEST",
        help="""Output wormtable home directory, or schema file 
            if we are generating a candidate schema using the 
            --generate-schema option""")
    parser.add_argument(
        "--progress", "-p", default=False, action="store_true",
        help="Show progress monitor")
    parser.add_argument(
        "--force", "-f", default=False, action="store_true",
        help="Force over-writing of existing wormtable")
    parser.add_argument(
        "--truncate", "-t", default=False, action="store_true",
        help="""Truncate values that are too large for a column and store
            the maximum the column will allow. Currently this option 
            only supports truncating REF column values more than 255 
            characters long. REF values are truncated to 254 characters
            and suffixed with a '+' to indicate that truncation has 
            occured""")
    parser.add_argument(
        "--queue-size", "-q", type=int, default=65536,
        help="""queue size for interprocess communication; set to 0 to 
            disable multiprocessing (recommended for single core systems)""")
    parser.add_argument(
        "--cache-size", "-c", default="64M",
        help="cache size in bytes; suffixes K, M and G also supported.")
    # Generating a schema and supplying one make no sense together.
    schema_group = parser.add_mutually_exclusive_group()
    schema_group.add_argument(
        "--generate-schema", "-g", default=False, action="store_true",
        help="""Generate a schema for the source VCF file and 
            write to DEST. Only reads the header of the VCF file.""")
    schema_group.add_argument(
        "--schema", "-s", default=None,
        help="""Use schema from the file SCHEMA rather than default 
                generated schema""")
    runner = ProgramRunner(parser.parse_args())
    try:
        runner.run()
    finally:
        runner.cleanup()

# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

