#!/usr/bin/env python
#
# Copyright John Reid 2012
#

"""
Uses STEME to scan sequences for PWMs.
"""

import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
import stempy, stempy.meme_parse, numpy, sys
from cookbook.lru_cache import lru_cache



#def bio_matrix_to_numpy(bio_matrix):
#    """Maps a matrix in biopython motif format (a list of dicts mapping bases to entries)
#    into a numpy matrix.
#    """
#    return numpy.array([
#        [x['A'], x['C'], x['G'], x['T']]
#        for x in bio_matrix
#    ])



#
# Parse options and arguments from command line
#
# Re-encode every command-line argument with the encoding reported by stdin,
# falling back to ASCII when stdin does not report one, before handing the
# command line to stempy's option parser.
argv_encoding = sys.stdin.encoding or 'ascii'
sys.argv = [a.encode(argv_encoding) for a in sys.argv]
# parse_options yields the parsed options plus the leftover positional args.
options, args = stempy.parse_options(stempy.add_options)


#
# Check we have the correct number of arguments
#
if len(args) != 2:
    # Interpolate the script name into the message explicitly: unlike the
    # logging functions, exception constructors do not apply %-formatting
    # to their arguments, so passing it as a second argument would leave a
    # literal '%s' in the usage text.
    raise RuntimeError('USAGE: %s <options> motifs-file fasta-file' % sys.argv[0])
# Consume the two positional arguments in order: motifs file, then FASTA file.
motifs_file = args.pop(0)
fasta_file = args.pop(0)


#
# Load the motifs
#
logging.info('Loading motifs from: %s', motifs_file)
# Read the file inside a with-block so the handle is closed as soon as its
# contents have been read, even if the read itself fails.
with open(motifs_file) as motifs_handle:
    motifs_text = motifs_handle.read()
# Parse the text and keep the motifs that the parser extracted from it.
meme_info = stempy.meme_parse.do_parse_and_extract(motifs_text)
motifs = meme_info.motifs


#
# Load the sequences
#
logging.info('Loading sequences.')
num_bases, seqs, ids, index = stempy.read_sequences(fasta_file, options)

# Zero-order frequencies are built from the first four entries of the
# occurrence counts (presumably one count per base - confirm against stempy),
# then smoothed with the pseudo-counts given by options.back_dist_prior.
occs = stempy.occurrences_from_index(index)
leading_occs = list(occs[:4])
freqs = stempy.ZeroOrderFrequencies(leading_occs)
freqs_with_pseudo_counts = freqs.add_pseudo_counts(options.back_dist_prior)

# The data object is told the width of the widest motif we will scan with.
widest_motif_width = max(motif.letter_probs.w for motif in motifs)
data = stempy.Data(index, max_W=widest_motif_width)




@lru_cache()
def get_bg_model(W):
    """Build the background model used when scanning with motifs of width W.

    A Markov model is created from the sequence index (using the creation
    function that stempy selects from the command-line options and the
    ``back_dist_prior`` option) and then turned into a background model for
    W-mers using the pseudo-counted zero-order frequencies.

    Calls are memoised by the ``lru_cache`` decorator, so motifs that share a
    width can reuse a model (subject to whatever capacity the cache has).
    """
    logging.info("Creating background model for W=%d", W)
    create_markov_model = stempy.get_markov_model_create_fn(options)
    # The creation function returns a pair; only the first item is needed.
    markov_model, _ = create_markov_model(data.index, options.back_dist_prior)
    create_bg_model = stempy.get_bg_model_from_markov_model_fn(options)
    return create_bg_model(W, data, markov_model, freqs_with_pseudo_counts)



#
# For each motif
#
for i, motif in enumerate(motifs):
    W = motif.letter_probs.w
    logging.info('Motif %2d (%s) W=%d\n%s', i, motif.name, W, motif.letter_probs.values)

    # Build a model around a uniform PSSM of the right width, then overwrite
    # the PSSM's log-probabilities with the log of the motif's letter
    # probabilities and ask the binding-site model to recalculate itself.
    bg = get_bg_model(W)
    uniform_pssm = stempy.initialise_uniform_pssm(W, options.alphabet_size)
    bs = stempy.PssmBindingSiteModel(uniform_pssm)
    model = stempy.Model(data, bs, bg, _lambda=options.lambda_)
    model.bs.pssm.log_probs.values()[:] = numpy.log(motif.letter_probs.values)
    model.bs.recalculate()

    # Run the instance finder with the prediction threshold and sort what it
    # found in place.
    logging.info('Finding instances.')
    instance_finder = stempy.FindInstances(data, model, options.prediction_Z_threshold)
    instance_finder()
    instance_finder.instances.sort()
    num_instances = len(instance_finder.instances)

    # Individual instances are only listed when there are fewer than 100 of
    # them; the summary line below is always logged.
    if num_instances < 100:
        for instance in instance_finder.instances:
            seq, pos = data.pos_localise(instance.global_pos)
            W_mer = data.get_W_mer(W, instance.global_pos)
            if instance.rev_comp:
                # Report the W-mer as read on the strand the instance is on.
                W_mer = stempy.reverse_complement(W_mer)
            strand = '-' if instance.rev_comp else '+'
            logging.info('seq=%5d; pos=%6d; strand=%s; W-mer=%s; Z=%4f', seq, pos, strand, W_mer, instance.Z)

    # The density suffix is left out when nothing was found, which also
    # avoids dividing by zero.
    num_W_mers = data.num_W_mers(W)
    if num_instances:
        density = ' (%.1f base pairs/instance)' % (float(num_W_mers) / num_instances)
    else:
        density = ''
    logging.info(
        'Found %d instances with Z>=%.3f in %d W-mers%s',
        num_instances,
        options.prediction_Z_threshold,
        num_W_mers,
        density
    )
    
