#!/usr/bin/env python
#
# Copyright John Reid 2012
#

"""
Finds occurrences of pairs of motifs of a given spacing.

    PRIMARY SECONDARY S U 12
    PRIMARY SECONDARY S D 12
    PRIMARY SECONDARY C U 12
"""

import logging
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s - %(levelname)s - %(message)s")
from optparse import OptionParser
from collections import defaultdict
import os
import pyicl
import numpy
import pylab
import cPickle
import sys
from Bio import SeqIO
from stempy.scan import parse_occurrences, load_occurrences
from stempy.spacing import footprint, handle_pairs, parse_spacings, \
    spacing_str, load_spacings, add_options, spacing_header, \
    PairOccurrence, make_pair_handler_from_primary_handler


def found_pair_occurrence(spacing, primary, secondary):
    """Record a pair of occurrences that matches one of the requested spacings.

    The match is logged, written as one line to the module-level ``output``
    file when one is open, and appended to the module-level
    ``pair_occurrences`` list. The sequence index, position and strand that
    are reported are those of the primary occurrence.

    When sequences were supplied (``options.seqs``), the slice of the
    sequence running from the start of the leftmost occurrence to the end
    of the W-mer of the other occurrence is written to ``pairs_fasta`` in
    FASTA format.

    :param spacing: The spacing that this pair matches.
    :param primary: The primary occurrence; its ``seq``, ``pos``,
        ``strand`` and ``wmer`` attributes are read.
    :param secondary: The secondary occurrence; its ``seq``, ``pos`` and
        ``wmer`` attributes are read.
    """
    pair_occ_str = '%s: %6d %7d %s' % (
        spacing_str(spacing), primary.seq, primary.pos, primary.strand
    )
    logging.info(pair_occ_str)
    # Log the W-mers of both occurrences. The primary W-mer used to be logged
    # twice and the secondary one not at all.
    logging.debug('%s %s', primary.wmer, secondary.wmer)
    if output is not None:
        # Same bytes as the former "print >>output" but also valid Python 3.
        output.write(pair_occ_str + '\n')
    pair_occurrences.append(
        PairOccurrence(
            spacing=spacing,
            seq=primary.seq,
            pos=primary.pos,
            strand=primary.strand
        )
    )
    if options.seqs:
        record = seqs[primary.seq]
        assert primary.seq == secondary.seq
        # Order the two occurrences by position so that the slice runs from
        # the start of the leftmost one.
        first, second = primary, secondary
        if first.pos > second.pos:
            first, second = second, first
        start = first.pos
        end = second.pos + len(second.wmer)
        assert start < end
        pair_seq = record[start:end]
        # Zero-padded coordinates keep the record ids a fixed width.
        pair_seq.id = '%s-%010d:%010d' % (record.id, start, end)
        SeqIO.write(pair_seq, pairs_fasta, "fasta")


def handle_primary(max_distance, primary, secondary, distance, upstream, same_strand):
    """Report the pair if it matches a spacing requested for its two motifs.

    Each spacing stored in the module-level ``spacings`` under the key
    (primary motif, secondary motif) is compared with the observed distance,
    strand relationship and upstream/downstream orientation. Every spacing
    that agrees on all three is passed to ``found_pair_occurrence`` together
    with the two occurrences. ``max_distance`` is accepted but not used here.
    """
    requested = spacings[primary.motif, secondary.motif]
    for candidate in requested:
        # Reject on the first criterion that differs, checking them in the
        # order distance, strand, orientation.
        if not (candidate.distance == distance):
            continue
        if not (candidate.same_strand == same_strand):
            continue
        if not (candidate.upstream == upstream):
            continue
        found_pair_occurrence(candidate, primary, secondary)


# Column titles for the per-pair part of each report line. The field widths
# (6 and 7) are the same as those used when a pair occurrence is formatted.
_pair_columns = ('Seq', 'Pos', 'Strand')
pair_header = ': %6s %7s %s' % _pair_columns


# Build the command line interface. add_options() is given the parser before
# the script-specific options are added; presumably it registers the shared
# options (e.g. results_dir, max_distance) read further down - TODO confirm.
parser = OptionParser(
    usage="%prog [options] <spacings file>",
    # The description is assembled line by line so that the trailing spaces
    # on two of its lines are explicit; the resulting text is unchanged.
    description=(
        "\n"
        "Finds pairs of motif occurrences in a PWM scan\n"
        "that are a predefined distance apart.\n"
        "The motifs and the distances are defined in the spacings file. Each line\n"
        "in this file gives one pair of motifs, their orientation and the distance\n"
        "they should be apart. The fields in each line are primary motif, \n"
        "secondary motif, whether the motifs are on the same strand, whether the\n"
        "secondary occurrence is upstream or downstream of the primary occurrence, \n"
        "and the distance between the occurrences. The fields are separated by\n"
        "whitespace.\n"
    ),
)
add_options(parser)

# File (relative to the results directory) that matching pairs are written to.
parser.add_option(
    "-o",
    dest='output',
    metavar="FILE",
    default=None,
    help="Write output to FILE.",
)

# FASTA file of the scanned sequences; enables output of the pair sequences.
parser.add_option(
    "-s", '--seqs',
    metavar="FASTA",
    help="Read sequences from FASTA and output parts corresponding to pairs.",
)

options, args = parser.parse_args()

# Exactly one positional argument, the spacings file, is required.
if len(args) != 1:
    parser.print_usage()
    sys.exit(-1)


#
# Now we know where we will write output to, start a log file
#
# Send messages of level INFO and above to a log file in the results
# directory as well. The file name records the maximum distance and the file
# is opened with mode 'w', so an existing log of the same name is overwritten.
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
log_file_handler = logging.FileHandler(
    os.path.join(
        options.results_dir,
        'steme-find-spacings-max=%d.log' % options.max_distance,
    ),
    mode='w',
)
log_file_handler.setLevel(logging.INFO)
log_file_handler.setFormatter(formatter)
# Attach to the root logger so every logging call in the script reaches it.
logging.getLogger().addHandler(log_file_handler)


#
# Load the occurrences and associated sequence lengths
#
occurrences, seq_infos, motifs = load_occurrences(options)

#
# Load the file containing the spacings to look for
#
spacings = load_spacings(args[0], motifs, options)
logging.info('')


#
# The output files are opened inside a try block so that whatever was opened
# is closed (and flushed) again even if the scan below raises.
#
output = None
pairs_fasta = None
try:
    #
    # Open output file
    #
    if options.output:
        output_filename = os.path.join(options.results_dir, options.output)
        logging.info('Writing pairs to %s', output_filename)
        output = open(output_filename, 'w')

    #
    # Read the sequences if asked to
    #
    if options.seqs:
        # Read all records up front and close the input file straight away;
        # found_pair_occurrence() indexes this list by sequence number.
        with open(options.seqs) as seqs_file:
            seqs = list(SeqIO.parse(seqs_file, 'fasta'))
        # NOTE(review): unlike the log and the -o output, this file is
        # created in the current directory rather than in
        # options.results_dir - confirm whether that is intended.
        pairs_fasta = open('pairs.fa', 'w')

    #
    # Iterate through the occurrences finding spacings
    #
    logging.info(
        'Examining spacings of up to %d b.p. between %d occurrences of '
        '%d motifs in %d sequences',
        options.max_distance, len(occurrences), len(motifs), len(seq_infos)
    )
    # Filled in by found_pair_occurrence() as matching pairs are found.
    pair_occurrences = []
    logging.info(spacing_header + pair_header)
    logging.info('*' * (len(spacing_header) + len(pair_header)))
    pair_handler = make_pair_handler_from_primary_handler(
        handle_primary, ignore_close_to_end=False, options=options)
    handle_pairs(occurrences, seq_infos, pair_handler, options)

finally:
    if output is not None:
        output.close()
    if pairs_fasta is not None:
        pairs_fasta.close()