#!/usr/bin/env python

# RSD: The reciprocal smallest distance algorithm.
#   Wall, D.P., Fraser, H.B. and Hirsh, A.E. (2003) Detecting putative orthologs, Bioinformatics, 19, 1710-1711.
# Original Author: Dennis P. Wall, Department of Biological Sciences, Stanford University.
# Author: Todd F. DeLuca, Center for Biomedical Informatics, Harvard Medical School
# Contributors: I-Hsien Wu, Computational Biology Initiative, Harvard Medical School

import argparse
import os
import shutil

import rsd
import rsd.nested
import rsd.orthutil


class DivEvalueCollector(object):
    '''
    Callable intended to be passed as the ``type`` argument of an argparse option that takes
    (divergence, evalue) pairs, e.g. ``--de``.  Each call validates one command line string and
    stores it: calls alternate between expecting a divergence and expecting an evalue, starting
    with a divergence, so that ``divs[i]`` and ``evalues[i]`` form the i-th pair.
    '''

    def __init__(self, parser):
        '''
        parser: an object with an ``error(message)`` method (typically an argparse.ArgumentParser),
        used to report arguments that are invalid divergence or evalue thresholds.
        '''
        self.divs = []
        self.evalues = []
        # True when the next argument is expected to be a divergence, False when an evalue.
        self.looking_for_div = True
        self.parser = parser

    @staticmethod
    def _toFloat(arg):
        '''
        Returns arg converted to a float, or None if arg can not be converted.
        '''
        try:
            return float(arg)
        except (TypeError, ValueError):
            return None

    def __call__(self, arg):
        '''
        arg: a string argument from the command line.  Should be a divergence or evalue threshold.
        Appends a valid divergence to the divs list or a valid evalue to the evalues list.
        Calls parser.error() with a helpful message if arg is not a float or is not in the right
        range for a divergence (> 0.0 and < 1.0) or an evalue (>= 0.0) threshold.  Invalid values
        are never stored.
        returns: the parsed float (None if arg was not a number), so argparse stores real values
        in the namespace instead of None.
        '''
        value = self._toFloat(arg)
        # Explicit comparisons instead of assert: asserts are removed when python runs with -O,
        # which would silently disable the range checks.  NaN fails both comparisons.
        if self.looking_for_div:
            if value is not None and 0.0 < value < 1.0:
                self.divs.append(value)
            else:
                self.parser.error('argument --de: A divergence threshold must be a number > 0.0 and < 1.0.  You gave "{}" instead.'.format(arg))
        else:
            if value is not None and value >= 0.0:
                self.evalues.append(value)
            else:
                self.parser.error('argument --de: An evalue threshold must be a number >= 0.0.  You gave "{}" instead.'.format(arg))
        # alternate looking for div and looking for evalue b/c we want (div, evalue) pairs.
        self.looking_for_div = not self.looking_for_div
        return value

        
def _expandPath(path):
    '''
    Returns path as an absolute path, after expanding a leading "~" to the user's home directory.
    '''
    return os.path.abspath(os.path.expanduser(path))


def main():
    '''
    Command line entry point.  Parses the arguments, computes orthologs between the query and
    subject genomes for every unique (divergence, evalue) pair using the rsd package, and writes
    the orthologs to the output file in the requested output format.
    Exits via parser.error() if the arguments are invalid.
    '''
    parser = argparse.ArgumentParser(description='Compute orthologs using the reciprocal smallest distance (RSD) algorithm between the query genome and the subject genome.  See "Detecting putative orthologs", Wall DP, Fraser HB, Hirsh AE, Bioinformatics, 2003, http://bioinformatics.oxfordjournals.org/content/19/13/1710 for a description of the algorithm.')
    parser.add_argument('-q', '--query-genome', required=True, help='FASTA format protein sequence file, with unique ids on each nameline either in the form ">id" or ">namespace|id|...".')
    parser.add_argument('-s', '--subject-genome', required=True, help='FASTA format protein sequence file, with unique ids on each nameline either in the form ">id" or ">namespace|id|...".')
    parser.add_argument('-o', '--outfile', required=True, help='File in which to write orthologs.')
    parser.add_argument('-f', '--forward-hits', help='File containing forward blast hits.  If not given, hits will be computed.  Using rsd_blast to create this file can save time if running rsd_search multiple times, especially for larger genomes.')
    parser.add_argument('-r', '--reverse-hits', help='File containing reverse blast hits.  If not given, hits will be computed.  Using rsd_blast to create this file can save time if running rsd_search multiple times, especially for larger genomes.')
    parser.add_argument('--ids', help='Path to file containing seq ids (one per line) in query_genome for which to compute orthologs.  If you only have one or a few sequences of interest it can be much faster to limit computation to those sequences.  The default is to compute orthologs for all sequences in query_genome.  The sequence ids in the file must correspond to ids on the fasta namelines of query_genome.')
    parser.add_argument('--no-blast-cache', default=False, action='store_true', help='If this option is given, blast hits will not be precomputed for every sequence in each genome.  Using this option can be faster if computing orthologs for only a few sequences.  Consider using in conjunction with --ids.')
    parser.add_argument('--no-format', default=False, action='store_true', help='If this option is given, genome fasta files will not be formatted for blast.  This is useful if blast formatted indices already exist and are located in the same directory as the fasta genome files.')
    parser.add_argument('--workdir', default='.', help='Directory under which to work.  will create a subdirectory under this dir in which to write temporary files, etc.  This subdirectory will be removed when rsd finishes.  Default is %(default)s')
    parser.add_argument('-v', '--verbose', default=False, action='store_true')
    parser.add_argument('--outfmt', type=int, default=-1, choices=(-1, 1, 2, 3), help='''Output format.  Default: %(default)s.  Format -1 is synonymous with the highest format number.  Format 1 outputs one ortholog per line, as subject_sequence_id (aka sid), query_sequence_id (aka qid), and maximum likelihood distance (aka dist), separated by tabs.  This was the original output format of RSD from the code referenced in the (Wall et al. 2003) paper cited above.  Format 2 outputs one ortholog per line, as qid, sid, dist, separated by tabs.  By convention, Roundup (http://roundup.hms.harvard.edu), a large RSD-based orthology database, orders the query genome before the subject genome, making the columns of format 2 consistent with that ordering.  In format 3, inspired by Uniprot dat files, a set of orthologs starts with a line listing the parameters (query genome, subject genome, divergence, and evalue) used to compute the orthologs, then has 0 or more ortholog lines listing the qid, sid, and dist of each ortholog, and ends with a closing line.  Unlike formats 1 and 2, format 3 can both represent a set of parameters that have no detected orthologs and serialize orthologs for multiple parameter combinations.  Example: PA\\tLACJO\\tYEAS7\\t0.2\\t1e-15\\nOR\\tQ74IU0\\tA6ZM40\\t1.7016\\nOR\\tQ74K17\\tA6ZKK5\\t0.8215\\n//\\n  For these reasons, format 3 is recommended.  Formats 1 and 2 are available for backward compatibility.  It is an error to specify output format 1 or 2 and multiple parameter combinations with --de.''')
    # The collector validates each --de value as argparse converts it and accumulates the
    # thresholds on itself; the pairs are read back from de.divs and de.evalues below.
    de = DivEvalueCollector(parser)
    parser.add_argument('--de', nargs=2, type=de, metavar=('DIVERGENCE', 'EVALUE'), help="Specify a divergence and an evalue threshold for ortholog computation.  Default: '--de 0.8 1e-5'.  This option can be used multiple times, and orthologs will be computed for each unique divergence and evalue pair.  Orthologs are computed in one pass, so this is significantly faster than running RSD once for each pair.  DIVERGENCE is a number > 0.0 and < 1.0, e.g. 0.2 or 0.5, which is the threshold for the maximum divergence allowed between a query and subject sequence.  EVALUE is a number >= 0.0, e.g. 1e-20 or 1.0, which is the threshold for the maximum BLAST e-value allowed between a query and subject sequence.")
    args = parser.parse_args()

    # paranoid check: if the lengths are different, we somehow got more evalues or divergences, even though the nargs parameter to the --de argument
    # guarantees we get pairs and DivEvalueCollector guarantees they alternate like div, evalue, div, evalue, etc.
    assert len(de.divs) == len(de.evalues)

    # Make a unique list of divEvalue pairs, using the numbers, so '.9' == '0.9'.
    # Sorted b/c set() does not preserve the original order anyway.
    # Falls back to the documented default pair when --de was not given.
    divEvalues = sorted(set(zip(de.divs, de.evalues) if de.divs else [(0.8, 1e-5)]))
    # The largest evalue threshold is the one passed to the blast hit computations below.
    maxEvalue = max(evalue for _, evalue in divEvalues)

    if len(divEvalues) > 1 and args.outfmt in (1, 2):
        parser.error('It is an error to specify output format 1 or 2 and multiple parameter combinations with --de.  Consider using "--outfmt 3"')

    queryGenome = _expandPath(args.query_genome)
    subjectGenome = _expandPath(args.subject_genome)
    outfile = _expandPath(args.outfile)

    if args.ids:
        with open(_expandPath(args.ids)) as fh:
            # one id per line.  ignore blank lines and comment lines
            ids = [i for i in (line.strip() for line in fh) if i and not i.startswith('#')]
    else:
        ids = None

    # Verbose messages are built as a single string so that print(...) produces the same
    # output whether it is parsed as a python 2 print statement or a python 3 function call.
    if args.verbose:
        print('query sequence ids: {0}'.format(ids))
        print('divergence and evalue pairs: {0}'.format(divEvalues))

    # Overview of what follows:
    # - unless --no-format is given, the fasta files are copied to the temp dir and formatted
    #   there; otherwise the given files are used directly.
    # - with --no-blast-cache, orthologs are computed using on-the-fly hits; otherwise given hits
    #   files are used, and missing ones are computed in the temp dir.
    with rsd.nested.NestedTempDir(dir=_expandPath(args.workdir), nesting=0) as tmpDir:

        # format fasta files if needed.
        if args.no_format:
            # assume blast formatted index files coexist with the fasta files
            queryFastaPath = queryGenome
            subjectFastaPath = subjectGenome
        else:
            if args.verbose:
                print('copying fasta files')
            queryFastaPath = rsd.copyFastaArg(queryGenome, tmpDir)
            subjectFastaPath = rsd.copyFastaArg(subjectGenome, tmpDir)
            if args.verbose:
                print('formatting fasta files')
            rsd.formatFastaArg(queryFastaPath)
            rsd.formatFastaArg(subjectFastaPath)

        if args.no_blast_cache: # compute orthologs on-the-fly (i.e. without computing blast hits for every sequence)
            # NOTE(review): the two results below are never used in this function.  The calls
            # are kept in case rsd.makeGetHitsOnTheFly has side effects -- confirm and remove.
            getForwardHits = rsd.makeGetHitsOnTheFly(subjectFastaPath, maxEvalue, tmpDir)
            getReverseHits = rsd.makeGetHitsOnTheFly(queryFastaPath, maxEvalue, tmpDir)
            if args.verbose:
                print('computing orthologs')
            divEvalueToOrthologs = rsd.computeOrthologsUsingOnTheFlyHits(queryFastaPath, subjectFastaPath, divEvalues, ids, tmpDir)
        else: # compute orthologs using computed blast hits.
            if args.forward_hits:
                forwardHitsPath = _expandPath(args.forward_hits)
            else: # compute blast hits
                if args.verbose:
                    print('computing forward blast hits')
                forwardHitsPath = os.path.join(tmpDir, 'query_subject.blast.hits.pickle')
                rsd.computeBlastHits(queryFastaPath, subjectFastaPath, forwardHitsPath, maxEvalue, workingDir=tmpDir)
            if args.reverse_hits:
                reverseHitsPath = _expandPath(args.reverse_hits)
            else: # compute blast hits
                if args.verbose:
                    print('computing reverse blast hits')
                reverseHitsPath = os.path.join(tmpDir, 'subject_query.blast.hits.pickle')
                rsd.computeBlastHits(subjectFastaPath, queryFastaPath, reverseHitsPath, maxEvalue, workingDir=tmpDir)

            if args.verbose:
                print('computing orthologs')
            divEvalueToOrthologs = rsd.computeOrthologsUsingSavedHits(queryFastaPath, subjectFastaPath, divEvalues, forwardHitsPath, reverseHitsPath, ids, tmpDir)

        # Genome names written in the parameter line of output format 3: the fasta file names.
        # Computed once, outside the loop, without re-binding the absolute genome paths.
        queryName = os.path.basename(queryFastaPath)
        subjectName = os.path.basename(subjectFastaPath)

        # write out orthologs, one set per (divergence, evalue) pair, in sorted pair order.
        with open(outfile, 'w') as fh:
            for divEvalue in divEvalues:
                orthologs = divEvalueToOrthologs[divEvalue]
                if args.verbose:
                    print('writing {0} orthologs to outfile'.format(len(orthologs)))
                if args.outfmt in (1, 2):
                    # format 1 writes sid, qid, dist.  format 2 writes qid, sid, dist.
                    rsd.orthutil.orthologsToStream(orthologs, fh, args.outfmt)
                else:
                    # format 3, or -1 which is a synonym for the highest format number.
                    # argparse choices guarantee no other value reaches this point.
                    div, evalue = divEvalue
                    orthDatas = [((queryName, subjectName, div, evalue), orthologs)]
                    rsd.orthutil.orthDatasToStream(orthDatas, fh)
                    

# Run the command line interface only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

   
# last line
