#!/usr/bin/env python
#
# Copyright John Reid 2012
#

"""
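Reads aligned DNA sequences (one per line) from each file named on the command
line and builds a per-position base count matrix whose columns are ordered
A, C, G, T. The resulting motif is also written in MEME format to
<filename>.meme. Example input: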

    CCCCTAATCCGCT
    CCCCTAATCCGCT
    CCCCTAATCCGCT
    CCCCTAATCCGCT
    CCCCTAATCCGCT
    CCCCTAATCCGCT
    CCCCTAATCCGCT
    CCCCTAATCCGCT
    ACCCTAATCCGTT
    ACGCTAATCCGTT
    ACGCTAATCCGTT
    ACGCTAATCCGTA
    ACGCTAATCCGTA
    ACGCTAATCCCTA
    ACGCTAATCCCTA
    AAGCTAATCCCTA
    GAGTTAATCCCAA
    GAGTTAATCCCAG
    GAATTAATCCCAG
    GAATTAATCCCAG
    GAATTAATCCCAG
    GGATTAATCCCAG
    TGAATAAGCCCGC
    TGAATAAGCCTGC
    TGTATAAGCATGC
    TTTATAAGCATGC
    TTTGTAAGCGAGC

would produce:

    2012-07-02 12:13:44,092 - INFO - Using pseudo-count of 1.0
    2012-07-02 12:13:44,092 - INFO - Reading sequences from filename: otd.seqs
    2012-07-02 12:13:44,095 - INFO - Counts (without pseudo-counts):
    [[ 8  8  6  5]
     [ 6 15  4  2]
     [ 6  9  9  3]
     [ 4 16  1  6]
     [ 0  0  0 27]
     [27  0  0  0]
     [27  0  0  0]
     [ 0  0  5 22]
     [ 0 27  0  0]
     [ 2 24  1  0]
     [ 1 10 13  3]
     [ 6  8  5  8]
     [ 6  5  5 11]]
    2012-07-02 12:13:44,096 - INFO - Probabilities (including pseudo-counts):
    [[ 0.29032258  0.29032258  0.22580645  0.19354839]
     [ 0.22580645  0.51612903  0.16129032  0.09677419]
     [ 0.22580645  0.32258065  0.32258065  0.12903226]
     [ 0.16129032  0.5483871   0.06451613  0.22580645]
     [ 0.03225806  0.03225806  0.03225806  0.90322581]
     [ 0.90322581  0.03225806  0.03225806  0.03225806]
     [ 0.90322581  0.03225806  0.03225806  0.03225806]
     [ 0.03225806  0.03225806  0.19354839  0.74193548]
     [ 0.03225806  0.90322581  0.03225806  0.03225806]
     [ 0.09677419  0.80645161  0.06451613  0.03225806]
     [ 0.06451613  0.35483871  0.4516129   0.12903226]
     [ 0.22580645  0.29032258  0.19354839  0.29032258]
     [ 0.22580645  0.19354839  0.19354839  0.38709677]]
    2012-07-02 12:13:44,097 - INFO - Log probabilities (including pseudo-counts):
    [[-1.23676263 -1.23676263 -1.48807706 -1.64222774]
     [-1.48807706 -0.66139848 -1.82454929 -2.33537492]
     [-1.48807706 -1.13140211 -1.13140211 -2.04769284]
     [-1.82454929 -0.60077386 -2.74084002 -1.48807706]
     [-3.4339872  -3.4339872  -3.4339872  -0.10178269]
     [-0.10178269 -3.4339872  -3.4339872  -3.4339872 ]
     [-0.10178269 -3.4339872  -3.4339872  -3.4339872 ]
     [-3.4339872  -3.4339872  -1.64222774 -0.29849299]
     [-3.4339872  -0.10178269 -3.4339872  -3.4339872 ]
     [-2.33537492 -0.21511138 -2.74084002 -3.4339872 ]
     [-2.74084002 -1.03609193 -0.79492987 -2.04769284]
     [-1.48807706 -1.23676263 -1.64222774 -1.23676263]
     [-1.48807706 -1.64222774 -1.64222774 -0.94908055]]
"""


import logging
import sys

import numpy

# Pseudo-count added to every cell of the count matrix before normalising.
pseudo_count = 1.

# Column of the count matrix used for each accepted base character.
_BASE_COLUMN = {
    'A': 0, 'a': 0,
    'C': 1, 'c': 1,
    'G': 2, 'g': 2,
    'T': 3, 't': 3,
}

# Template for the MEME motif file.  It is spelled out line by line because
# three of the lines consist of four spaces only; writing them as explicit
# literals keeps the emitted bytes stable even if an editor strips trailing
# whitespace from this source file.
_MEME_TEMPLATE = (
    '\n'
    'MEME version 4\n'
    '    \n'
    'ALPHABET= ACGT\n'
    '    \n'
    'strands: + -\n'
    '    \n'
    'Background letter frequencies\n'
    'A 0.25 C 0.25 G 0.25 T 0.25\n'
    '\n'
    'MOTIF motif-name\n'
    'letter-probability matrix: alength= 4 w= %d nsites= %d\n'
    '%s\n'
)


def count_bases(lines):
    """
    Count the bases seen at each position of a set of equal-length sequences.

    ``lines`` is any iterable of strings (for example an open file); each
    string is stripped of surrounding whitespace and blank lines are skipped.

    Returns a ``(width, 4)`` array of unsigned integer counts whose columns
    are ordered A, C, G, T.

    Raises ``ValueError`` if the sequences differ in length or if there are
    no sequences at all, and ``RuntimeError`` on a character that is not one
    of A, C, G, T (upper or lower case).
    """
    counts = None
    for line_no, line in enumerate(lines, 1):
        seq = line.strip()
        if not seq:
            # Tolerate blank lines, e.g. a trailing newline at end of file.
            continue
        if counts is None:
            # The first sequence fixes the width of the motif.  An identity
            # test is required here: ``None == counts`` would be evaluated
            # element-wise once counts is an array.
            counts = numpy.zeros((len(seq), 4), dtype=numpy.uint)
        elif len(counts) != len(seq):
            raise ValueError(
                'Sequence on line %d has length %d, expected %d'
                % (line_no, len(seq), len(counts)))
        for position, base in enumerate(seq):
            column = _BASE_COLUMN.get(base)
            if column is None:
                raise RuntimeError('Unknown base: %s' % base)
            counts[position, column] += 1
    if counts is None:
        raise ValueError('No sequences found')
    return counts


def probabilities(counts, pseudo):
    """
    Turn a count matrix into per-position base probabilities.

    ``pseudo`` is added to every count, then each row is divided by its sum
    so that every row of the returned float array sums to one.
    """
    smoothed = counts.astype(float) + pseudo
    # Transpose so the row sums broadcast along the base axis.
    return (smoothed.T / smoothed.sum(axis=1)).T


def format_meme(probs, num_seqs):
    """
    Render a probability matrix as the text of a MEME motif file.

    ``probs`` is the matrix returned by ``probabilities`` and ``num_seqs``
    the number of sequences it was built from (reported as ``nsites``).
    """
    matrix = '\n'.join(' '.join(map(str, row)) for row in probs)
    return _MEME_TEMPLATE % (len(probs), num_seqs, matrix)


def process_file(filename, pseudo):
    """
    Read sequences from ``filename``, log the matrices and write the motif.

    The counts, probabilities and log probabilities are logged at INFO level
    and the motif is written in MEME format to ``<filename>.meme``.
    """
    with open(filename) as seq_file:
        counts = count_bases(seq_file)
    # Every sequence contributes exactly one base to the first position, so
    # the first row sums to the number of sequences read.
    num_seqs = int(counts[0].sum())
    logging.info('Read %d sequences from filename: %s', num_seqs, filename)
    logging.info('Counts (without pseudo-counts):\n%s', counts)
    probs = probabilities(counts, pseudo)
    logging.info('Probabilities (including pseudo-counts):\n%s', probs)
    logging.info(
        'Log probabilities (including pseudo-counts):\n%s', numpy.log(probs))
    meme_filename = '%s.meme' % filename
    logging.info('Writing motif to: %s', meme_filename)
    with open(meme_filename, 'w') as meme_out:
        meme_out.write(format_meme(probs, num_seqs))
        # Final newline that the former print statement appended.
        meme_out.write('\n')


def main(argv=None):
    """
    Process every sequence file named in ``argv`` (default: ``sys.argv[1:]``).
    """
    if argv is None:
        argv = sys.argv[1:]
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s")
    logging.info('Using pseudo-count of %.1f', pseudo_count)
    for filename in argv:
        process_file(filename, pseudo_count)


if __name__ == '__main__':
    main()
