#!/usr/bin/env python
#
# Copyright John Reid 2012
#

"""
Uses STEME to scan sequences for PWMs.
"""

import logging
logging.basicConfig(level=logging.INFO)
import stempy, numpy, sys
from Bio import Motif


def bio_matrix_to_numpy(bio_matrix):
    """Convert a biopython-style motif matrix into a numpy array.

    :param bio_matrix: iterable of per-position mappings, each keyed by the
        bases 'A', 'C', 'G' and 'T'.
    :returns: numpy array with one row per position and its columns in
        A, C, G, T order.
    """
    bases = 'ACGT'
    rows = [[column[base] for base in bases] for column in bio_matrix]
    return numpy.array(rows)



#
# Parse options and arguments from command line
#
# NOTE(review): re-encodes each argument with the terminal's encoding (falling
# back to ASCII); presumably stempy needs byte strings here (Python 2 era) —
# confirm before porting.
sys.argv = [ a.encode(sys.stdin.encoding or 'ascii') for a in sys.argv ]
options, args = stempy.parse_options(stempy.add_options)


#
# Check we have the correct number of arguments
#
if len(args) != 2:
    # Interpolate the program name into the message; passing it as a separate
    # exception argument left the '%s' placeholder unformatted.
    raise RuntimeError('USAGE: %s <options> motifs-file fasta-file' % sys.argv[0])
motifs_file = args.pop(0)
fasta_file = args.pop(0)


#
# Load the motifs
#
logging.info('Loading motifs from: %s', motifs_file)
with open(motifs_file) as motifs_handle:
    # Materialise the parsed motifs while the handle is still open, so the
    # file is closed deterministically afterwards.
    motifs = list(Motif.parse(motifs_handle, "MEME"))


#
# Load the sequences
#
input_sequences = stempy.SequenceSet(fasta_file, options)


#
# Initialise the background
#
# build_model_of() returns three values; only the model (mm) is used below,
# to compute the likelihoods of the input sequences.
mm, freqs, freqs_with_pseudo_counts = input_sequences.build_model_of()
input_sequences.calculate_likelihoods(mm)



#
# For each motif
#
# Threshold passed to stempy.FindInstances.
# NOTE(review): presumably the minimum Z for an instance to be reported —
# confirm against stempy.
INSTANCE_Z_THRESHOLD = .3

for i, motif in enumerate(motifs):
    logging.info('Motif %2d consensus: %s', i, motif.consensus())
    W = motif.length

    #
    # Create the model and load the motif's PWM into it as log probabilities
    #
    model = input_sequences.create_model(W)
    model.bs.pssm.log_probs.values()[:] = numpy.log(bio_matrix_to_numpy(motif.pwm()))

    #
    # Create the instance finder and report each instance it finds
    #
    instance_finder = stempy.FindInstances(input_sequences.data, model, INSTANCE_Z_THRESHOLD)
    instance_finder()
    instance_finder.instances.sort()

    for instance in instance_finder.instances:
        # Convert the global position into (sequence index, offset in sequence).
        seq, pos = input_sequences.data.pos_localise(instance.global_pos)
        W_mer = input_sequences.data.get_W_mer(W, instance.global_pos)
        if instance.rev_comp:
            # Report the W-mer as it reads on the strand the instance was found on.
            W_mer = stempy.reverse_complement(W_mer)
        strand = '-' if instance.rev_comp else '+'
        logging.info('seq=%5d; pos=%6d; strand=%s; W-mer=%s; Z=%4f', seq, pos, strand, W_mer, instance.Z)
    
