#!/usr/bin/env python
#
# Copyright John Reid 2012, 2013
#

"""
Uses STEME to scan sequences for PWMs.
"""

import logging
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s - %(levelname)s - %(message)s")
from optparse import OptionParser
from cookbook.lru_cache import lru_cache
from cookbook.script_basics import log_options
from xml.etree.ElementTree import ElementTree, Element, SubElement, tostring
from itertools import imap
from functools import partial
import stempy
import stempy.meme_parse
import numpy
import sys
import os
import time

try:
    from xml.etree.ElementTree import register_namespace
except ImportError:
    logging.warning(
        'Can not register CISML namespaces. '
        'Could use a newer version of xml.etree.')


def output_file(filename):
    """Return the path of `filename` inside the output directory.

    The directory is read from the module-level `options.output_dir`, so
    this can only be called once the command line has been parsed.
    """
    output_dir = options.output_dir
    return os.path.join(output_dir, filename)


class BedOutput(object):
    """Outputs scan results to BED files.

    One tab-separated line is written per instance: the sequence id, the
    start position, the end position (start plus the W-mer length), the
    motif name, 1000 * Z and the strand ('+' or '-').
    """

    def __init__(self, out_filename, ids):
        """Open the BED file for writing.

        out_filename: path of the file to create (truncated if it exists).
        ids: sequence identifiers, indexed by `info.seq` when instances
            are written.
        """
        self.out_filename = out_filename
        self.ids = ids

        logging.info('BED output will be written to: %s', self.out_filename)
        self.out = open(self.out_filename, 'w')

    def handle_motif(self, motif):
        """Remember the motif that subsequent instances belong to."""
        self.motif = motif

    def handle_instance(self, info):
        """Write one instance of the current motif as a BED line.

        handle_motif() must have been called first, as the motif's name is
        part of every line.
        """
        strand = '-' if info.instance.rev_comp else '+'
        # file.write() produces the same bytes as the print-chevron
        # statement but is also valid on Python 3.
        self.out.write('%s\t%d\t%d\t%s\t%f\t%s\n' % (
            self.ids[info.seq], info.pos, info.pos + len(info.wmer),
            self.motif.name, 1000 * info.instance.Z,
            strand,
        ))

    def finalise(self):
        """Close the BED file. Calling this more than once is harmless."""
        if not hasattr(self, 'out'):
            # Already finalised.
            return
        logging.info('Closing BED output: %s', self.out_filename)
        self.out.close()
        del self.out


class CsvOutput(object):
    """Outputs scan results to CSV files.

    Two files are produced: one listing every match (motif, W-mer,
    sequence index, position, strand, Z, score and p-value) and one
    listing the length and id of every sequence.
    """

    def __init__(self, scan_out_filename, seqs_out_filename, data, ids):
        """Open the matches file and write the sequence lengths file.

        scan_out_filename: path of the matches file; it is opened here and
            stays open until finalise().
        seqs_out_filename: path of the sequence lengths file; it is written
            completely here.
        data: object providing `num_sequences` and `seq_length(i)`.
        ids: sequence identifiers, indexed by sequence number.
        """
        self.scan_out_filename = scan_out_filename
        self.seqs_out_filename = seqs_out_filename
        self.data = data
        self.ids = ids

        logging.info('Matches will be written to: %s', self.scan_out_filename)
        self.out = open(self.scan_out_filename, 'w')
        self.out.write(
            '# Motif,W-mer,Seq,Position,Strand,Z,Score,p-value\n')

        logging.info('Writing sequence lengths to: %s', self.seqs_out_filename)
        with open(self.seqs_out_filename, 'w') as f:
            f.write('# Length,ID\n')
            for i in range(data.num_sequences):
                f.write('%d,%s\n' % (data.seq_length(i), ids[i]))

    def handle_motif(self, motif):
        """Remember the motif that subsequent instances belong to."""
        self.motif = motif

    def handle_instance(self, info):
        """Write one instance of the current motif as a CSV row.

        handle_motif() must have been called first, as the motif's name is
        the first field of every row.
        """
        strand = '-' if info.instance.rev_comp else '+'
        # file.write() produces the same bytes as the print-chevron
        # statement but is also valid on Python 3.
        self.out.write('%s,%s,%d,%d,%s,%f,%s,%s\n' % (
            self.motif.name, info.wmer, info.seq,
            info.pos, strand, info.instance.Z,
            info.score, info.pvalue
        ))

    def finalise(self):
        """Close the matches file. Calling this more than once is harmless."""
        if not hasattr(self, 'out'):
            # Already finalised.
            return
        logging.info('Closing CSV output: %s', self.scan_out_filename)
        self.out.close()
        del self.out


class CISMLOutput(object):
    """Outputs scan results to CISML file.

    The results are accumulated in an in-memory XML tree: one `pattern`
    element per motif, containing one `scanned-sequence` element per
    sequence, which in turn collects a `matched-element` per instance.
    The tree is only written to disk by finalise().
    """

    def __init__(self, cisml_file, data, ids, pattern_file, sequence_file):
        """Create the root of the CISML tree.

        cisml_file: path of the XML file written by finalise().
        data: the sequence data (stored, not otherwise used here).
        ids: sequence identifiers, indexed by sequence number.
        pattern_file: name recorded in the `pattern-file` parameter.
        sequence_file: name recorded in the `sequence-file` parameter.
        """
        self.cisml_file = cisml_file
        self.data = data
        self.ids = ids

        # Register namespaces. register_namespace is only defined when the
        # import at the top of the file succeeded, so its absence is
        # tolerated here.
        try:
            register_namespace('', cisml_ns.strip('{}'))
            register_namespace('steme', steme_ns.strip('{}'))
        except NameError:
            pass

        # set up tree
        attrib = {
            'xmlns:xsi': xmlns_xsi,
            'xsi:schemaLocation': xsi_schema_location,
            'xmlns': cisml_ns.strip('{}'),
        }
        self.root = Element('cis-element-search', attrib=attrib)
        program_name = SubElement(self.root, 'program-name')
        program_name.text = 'STEME'
        parameters = SubElement(self.root, 'parameters')
        SubElement(parameters, 'pattern-file').text = pattern_file
        SubElement(parameters, 'sequence-file').text = sequence_file

    def handle_motif(self, motif):
        """Start a new `pattern` element for `motif`.

        A `scanned-sequence` child is created for every sequence id so
        that handle_instance() can attach matches by sequence index.
        """
        pattern = SubElement(
            self.root, 'pattern', accession=motif.name, name=motif.name)
        self.scanned_sequences = []
        # Use the ids given to the constructor: there is no module-level
        # `ids` to fall back on.
        for seq_id in self.ids:
            fields = seq_id.split()
            # The accession is the first whitespace-separated token of the
            # id; an id without any token is used unchanged.
            accession = fields[0] if fields else seq_id
            self.scanned_sequences.append(
                SubElement(pattern, 'scanned-sequence',
                           accession=accession, name=seq_id))

    def handle_instance(self, info):
        """Add a `matched-element` for one instance of the current motif.

        `stop` is the position of the last base of the W-mer. For
        instances on the reverse complement strand `start` and `stop` are
        swapped. The `pvalue` attribute is omitted when no p-value is
        available.
        """
        start, stop = info.pos, info.pos + len(info.wmer) - 1
        if info.instance.rev_comp:
            start, stop = stop, start
        attrib = dict()
        # Compare against None explicitly: a p-value of 0 is still a value.
        if info.pvalue is not None:
            attrib['pvalue'] = str(info.pvalue)
        matched_element = SubElement(
            self.scanned_sequences[info.seq], 'matched-element',
            start=str(start), stop=str(stop), score=str(info.instance.Z),
            **attrib
        )
        sequence = SubElement(matched_element, 'sequence')
        sequence.text = info.wmer

    def finalise(self):
        """Write the accumulated tree to the CISML file."""
        # For debugging, prettify(self.root) gives an indented rendering.
        logging.info('Writing CISML to: %s', self.cisml_file)
        tree = ElementTree(self.root)
        # The file is opened in binary mode because the tree is written
        # already encoded as UTF-8, and inside `with` so that it is closed
        # (and flushed) whichever branch is taken.
        try:
            with open(self.cisml_file, 'wb') as out:
                tree.write(
                    out,
                    xml_declaration=True,
                    encoding='utf-8',
                    method="xml"
                )
        except TypeError:
            # Versions of xml.etree that do not accept the extra keyword
            # arguments: start the file again with the basic call.
            with open(self.cisml_file, 'wb') as out:
                tree.write(
                    out,
                    encoding='utf-8',
                )


def add_pseudo_counts(matrix):
    """Smooth `matrix.values` in place with a pseudo-count.

    1 / (matrix.nsites + 1) is added to every entry, after which each row
    is rescaled so that it sums to one. Nothing is returned.
    """
    pseudo_count = 1. / (matrix.nsites + 1)
    probs = matrix.values
    probs += pseudo_count
    # Divide every row by its own total.
    row_totals = probs.sum(axis=1)
    probs[:] = probs / row_totals[:, None]


def prettify(elem):
    """Render the Element `elem` as indented XML text.

    The element is serialised as UTF-8, re-parsed with minidom and
    pretty-printed using two spaces per nesting level.
    """
    from xml.dom.minidom import parseString
    document = parseString(tostring(elem, 'utf-8'))
    return document.toprettyxml(indent="  ")


start_time = time.time()

#
# Parse options and arguments from command line
#
# Re-encode the command line arguments using the encoding reported by
# stdin, falling back to ASCII when it reports none.
argv_encoding = sys.stdin.encoding or 'ascii'
sys.argv = [arg.encode(argv_encoding) for arg in sys.argv]

# Each entry pairs the option strings with the keyword arguments handed to
# add_option(); the options are registered in the order listed.
option_specs = (
    (("--google-profile",), dict(
        action="store_true",
        default=False,
        help="Profile with the google profiler.")),
    (("--back-dist-prior",), dict(
        default=1.,
        type='float',
        help="Pseudo-counts for Markov background model.")),
    (("--bg-model-order",), dict(
        dest="bg_model_order",
        default=2,
        type='int',
        help="Order of the background Markov model.")),
    (("--lambda",), dict(
        dest="lambda_",
        default=.001,
        type='float',
        help="Probability of a binding site in the model.")),
    (("--prediction-Z-threshold",), dict(
        default=.3,
        type='float',
        help="The threshold on Z used to find instances of motifs.")),
    (("-o", "--output-dir"), dict(
        default='.',
        help="Directory to write output files to.")),
    (("-b", "--bed"), dict(
        action='store_true',
        help="Write sites in BED format.")),
    (("--cache-index",), dict(
        action='store_true',
        help="Save the index to disk so it does not need to be rebuilt "
             "next time.")),
)
parser = OptionParser()
for option_strings, option_kwargs in option_specs:
    parser.add_option(*option_strings, **option_kwargs)
options, args = parser.parse_args()
stempy.ensure_dir_exists(options.output_dir)


#
# Now we know where we will write output to, start a log file
#
# Mirror INFO (and more severe) messages into a log file inside the output
# directory, using the same layout as the console. The file is overwritten
# on every run.
log_filename = os.path.join(options.output_dir, 'steme-pwm-scan.log')
log_file_handler = logging.FileHandler(log_filename, mode='w')
log_file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
log_file_handler.setFormatter(formatter)
logging.getLogger().addHandler(log_file_handler)


#
# Check we have the correct number of arguments
#
# Exactly two positional arguments are required: the motifs file followed
# by the FASTA file.
if len(args) != 2:
    # Interpolate the script name into the message; passing it as a second
    # argument to RuntimeError would leave the '%s' unformatted.
    raise RuntimeError(
        'USAGE: %s <options> motifs-file fasta-file' % sys.argv[0])
motifs_file = args.pop(0)
fasta_file = args.pop(0)


#
# Set up some constants
#
num_best_to_show = 30

# Names of the output files; they are joined onto the output directory by
# output_file() where they are used.
_output_stem = 'steme-pwm-scan'
cisml_file = _output_stem + '-cisml.xml'
scan_out_filename = _output_stem + '.out'
seqs_out_filename = _output_stem + '.seqs'
bed_out_filename = _output_stem + '.bed'

# XML namespaces, written as '{uri}', and the schema attributes used when
# building CISML output.
_cisml_uri = 'http://zlab.bu.edu/schema/cisml'
cisml_ns = '{%s}' % _cisml_uri
steme_ns = '{%s}' % 'http://sysbio.mrc-bsu.cam.ac.uk/steme'
xmlns_xsi = 'http://www.w3.org/2001/XMLSchema-instance'
xsi_schema_location = _cisml_uri + ' cisml.xsd'


#
# Load the motifs
#
logging.info('Loading motifs from: %s', motifs_file)
# Read the whole motifs file inside `with` so the handle is closed as soon
# as its contents have been handed to the parser.
with open(motifs_file) as motifs_handle:
    meme_info = stempy.meme_parse.do_parse_and_extract(motifs_handle.read())
motifs = meme_info.motifs
logging.info('Loaded %d motifs from: %s', len(motifs), motifs_file)


#
# Some basic set up and logging
#
log_options(parser, options)
options.alphabet_size = 4
stempy.ensure_dir_exists(options.output_dir)
stempy.turn_on_google_profiling_if_asked_for(options)
try:
    #
    # Load the sequences
    #
    logging.info('Loading sequences.')
    if not motifs:
        # Fail with a clear message instead of the opaque error max()
        # raises when given an empty sequence.
        raise ValueError('No motifs were loaded from: %s' % motifs_file)
    # The widest motif is recorded on the options before they are passed
    # to the sequence set.
    options.max_w = max(motif.letter_probs.w for motif in motifs)
    seqs = stempy.SequenceSet(fasta_file, options)
    stempy.log_composition(seqs.occs)
    bg_model_mgr = stempy.get_background_manager(seqs, seqs.mm, options)

    #
    # Set up output formats
    #
    outputs = [
        CsvOutput(output_file(scan_out_filename),
                  output_file(seqs_out_filename), seqs.data, seqs.ids),
        # CISMLOutput(output_file(cisml_file), seqs.data, seqs.ids,
        #     motifs_file, fasta_file),
    ]
    if options.bed:
        outputs.append(BedOutput(output_file(bed_out_filename), seqs.ids))

    #
    # For each motif
    #
    for i, motif in enumerate(motifs):
        W = motif.letter_probs.w
        add_pseudo_counts(motif.letter_probs)
        logging.info('Motif %2d (%s) W=%d\n%s', i,
                     motif.name, W, motif.letter_probs.values)
        for output in outputs:
            output.handle_motif(motif)

        #
        # Create the model
        #
        logging.debug('Creating model.')
        bg = bg_model_mgr.get_bg_model(W)
        bs = stempy.PssmBindingSiteModel(
            stempy.initialise_uniform_pssm(W, options.alphabet_size))
        model = stempy.Model(seqs.data, bs, bg, _lambda=options.lambda_)
        # Overwrite the uniform PSSM with the log of the motif's
        # (pseudo-counted) letter probabilities.
        model.bs.pssm.log_probs.values()[
            :] = numpy.log(motif.letter_probs.values)
        model.bs.recalculate()

        #
        # Create log odds matrix for p-values
        #
        logging.debug('Calculating log odds.')
        # The comprehension variable is deliberately not called `i`: under
        # Python 2 it leaks into the enclosing scope and would overwrite
        # the motif index.
        base_freqs = numpy.array(
            [seqs.freqs.freq(base) for base in range(4)])
        scaled_log_odds, pssm_cdf = stempy.create_log_odds(
            base_freqs,
            numpy.exp(model.bs.pssm.log_probs.values())
        )

        #
        # Create the instance finder and execute it
        #
        logging.debug('Finding instances.')
        instance_finder = stempy.FindInstances(
            seqs.data, model, options.prediction_Z_threshold)
        instance_finder()
        instance_finder.instances.sort()

        #
        # Log what we found
        #
        num_instances = len(instance_finder.instances)
        num_w_mers = seqs.data.num_W_mers(W)
        if num_instances:
            density = ' (%.1f base pairs/instance)' % (
                float(num_w_mers) / num_instances)
        else:
            # Nothing found: leave the ratio out rather than divide by zero.
            density = ''
        logging.info(
            'Found %d instances with Z>=%.3f in %d W-mers%s',
            num_instances,
            options.prediction_Z_threshold,
            num_w_mers,
            density
        )

        #
        # Write instances to outputs
        #
        logging.debug('Writing instances to outputs')
        curried = partial(stempy.instance_info, data=seqs.data, ids=seqs.ids,
                          W=W, scaled_log_odds=scaled_log_odds,
                          pssm_cdf=pssm_cdf)
        for instance in instance_finder.instances:
            info = curried(instance)
            for output in outputs:
                output.handle_instance(info)
        logging.debug('Wrote instances to outputs')

    # Outputs are only finalised when every motif was scanned successfully.
    for output in outputs:
        output.finalise()

finally:
    stempy.turn_off_google_profiling_if_asked_for(options)
    logging.info('Took %.1f seconds to run', time.time() - start_time)


