#!python

"""Infers site-specific character preferences.

Written by Jesse Bloom."""


import sys
import os
import multiprocessing
import logging
import time
import cPickle
import dms_tools.utils
import dms_tools.parsearguments
import dms_tools.file_io
import dms_tools.mcmc


def _dump_debugging_counts(r, rcounts):
    """Pickle *rcounts*, the counts used for site *r*, for later debugging.

    The data are written to ``_debugging_counts_<r>.pickle`` in the current
    working directory.
    """
    # binary mode is the safe mode for pickle data on every platform
    with open('_debugging_counts_%s.pickle' % r, 'wb') as f:
        cPickle.dump(rcounts, f)


def main():
    """Main body of script.

    Parses the command line arguments, reads the counts files, builds the
    prior parameters for every site, infers the preferences for each site
    using a ``multiprocessing`` pool, and writes the preferences to the file
    given by the ``outfile`` argument.

    Progress is written to the log file. If anything fails, the exception is
    logged and then re-raised so that the script does not exit as if it had
    succeeded.
    """

    # Parse command line arguments
    parser = dms_tools.parsearguments.InferPrefsParser()
    args = vars(parser.parse_args())
    # This literal is presumably the parser's placeholder default for the log
    # file name (confirm against InferPrefsParser); when it is seen, the log
    # file name is derived from the output file name.
    if args['logfile'] == 'Base name of "outfile" with extension ".log"':
        args['logfile'] = "%s.log" % os.path.splitext(args['outfile'])[0]
        if os.path.isfile(args['logfile']):
            os.remove(args['logfile'])
    prog = parser.prog

    # Set up to log everything to logfile.
    # Confusingly, pystan defines its own logger which I have NOT been able to silence.
    # So even though the logger set up in this script will only log to logfile, the
    # pystan one imported by both dms_tools.mcmc and the call to dms_tools.file_io.Versions()
    # will log to standard output. This is very confusing, but I have not been able to find
    # a better solution until I determine how to silence the pystan logger.
    versionstring = dms_tools.file_io.Versions()
    logging.shutdown()
    logger = logging.getLogger(prog)
    logger.setLevel(logging.INFO)
    logfile_handler = logging.FileHandler(args['logfile'])
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    # begin execution
    try:
        logger.info('Beginning execution of %s in directory %s\n' % (prog, os.getcwd()))
        logger.info('Progress is being logged to %s\n' % args['logfile'])
        logger.info(versionstring)
        logger.info('Parsed the following arguments:\n%s\n' % '\n'.join(['\t%s = %s' % tup for tup in args.iteritems()]))

        # Work out the character sets:
        #   chartype        - character type name passed to the dms_tools readers / utilities
        #   countcharacters - characters present in the counts files
        #   characters      - characters for which preferences are inferred
        #   charlength      - number of positions compared against the wildtype character
        if args['chartype'].lower() == 'codon_to_aa':
            chartype = 'codon'
            transform_to_aa = True
            charlength = 3 # length of a codon
            countcharacters = dms_tools.codons
            if args['excludestop']:
                characters = dms_tools.aminoacids_nostop
            else:
                characters = dms_tools.aminoacids_withstop
        elif args['chartype'].lower() == 'codon':
            # Same count character type as the codon_to_aa branch, since both
            # read codon counts (the former value 'codons' did not match it).
            # NOTE(review): confirm 'codon' is the name dms_tools expects.
            chartype = 'codon'
            transform_to_aa = False
            countcharacters = characters = dms_tools.codons
            charlength = 3 # length of a codon
        elif args['chartype'].upper() == 'DNA':
            chartype = 'DNA'
            charlength = 1 # length of a nucleotide
            transform_to_aa = False
            countcharacters = characters = dms_tools.nts
        elif args['chartype'].lower() == 'aa':
            transform_to_aa = False
            charlength = 1
            if args['excludestop']:
                chartype = 'aminoacids_nostop'
                characters = countcharacters = dms_tools.aminoacids_nostop
            else:
                chartype = 'aminoacids_withstop'
                characters = countcharacters = dms_tools.aminoacids_withstop
        else:
            raise ValueError("Invalid chartype of %s" % args['chartype'])

        # Choose the error model from the error-control arguments and collect
        # the counts files to read, keyed by the kind of counts they hold.
        counts_files = {'nrpre':args['n_pre'], 'nrpost':args['n_post']}
        if args['errpre'].lower() == args['errpost'].lower() == 'none':
            pystan_error_model = dms_tools.mcmc.StanModelNoneErr
            error_model = 'none'
        elif args['errpre'] == args['errpost']:
            pystan_error_model = dms_tools.mcmc.StanModelSameErr
            error_model = 'same'
            counts_files['nrerr'] = args['errpre']
        elif args['errpre'].lower() == 'none' or args['errpost'].lower() == 'none':
            raise ValueError("Either --errpre and --errpost must both be 'none', or they must both NOT be 'none'")
        else:
            pystan_error_model = dms_tools.mcmc.StanModelDifferentErr
            error_model = 'different'
            counts_files['nrerrpre'] = args['errpre']
            counts_files['nrerrpost'] = args['errpost']

        # read in the counts
        (sites, wts, counts) = dms_tools.file_io.ReadMultipleDMSCountsFiles(counts_files.values(), chartype)
        if args['sites']:
            # "<=" rather than "<": specifying every site in the files is valid
            if not (set(args['sites']) <= set(sites)):
                raise ValueError("The specified sites are not all present in the DMS counts files. The specified sites are:\n%s" % ' '.join(args['sites']))
            sites = args['sites']
            logger.info('We will perform inference ONLY for the following specified subset of sites:\n\t%s' % ' '.join(sites))
        logger.info('There are counts for %d sites, which are:\n%s\n' % (len(sites), '\n'.join(['\t%s (wildtype identity of %s)' % (r, wts[r]) for r in sites])))
        counts = dict([(name, counts[counts_files[name]]) for name in counts_files.iterkeys()]) # make names count type rather than file

        # compute mutation rates for priors
        nmutations = range(1, charlength + 1) # possible numbers of changes relative to wildtype
        premutrate = dms_tools.utils.AvgMutRate(counts['nrpre'], chartype)
        if error_model == 'none':
            avgmu = sum([premutrate[m] for m in nmutations])
            logger.info('Average per-site mutation rate: %g' % avgmu)
        elif error_model == 'same':
            avgepsilon = dms_tools.utils.AvgMutRate(counts['nrerr'], chartype)
            logger.info('Average per-site error rate:\n%s\n' % '\n'.join(['\t%g for characters that require %d nucleotide changes' % (avgepsilon[m], m) for m in nmutations]))
            avgmu = sum([premutrate[m] for m in nmutations])  - sum([avgepsilon[m] for m in nmutations])
            logger.info('Average per-site mutation rate: %g' % avgmu)
        elif error_model == 'different':
            avgepsilon = dms_tools.utils.AvgMutRate(counts['nrerrpre'], chartype)
            logger.info('Average per-site error rate in pre-selection control:\n%s\n' % '\n'.join(['\t%g for characters that require %d nucleotide changes' % (avgepsilon[m], m) for m in nmutations]))
            avgmu = sum([premutrate[m] for m in nmutations])  - sum([avgepsilon[m] for m in nmutations])
            logger.info('Average per-site mutation rate: %g\n' % avgmu)
            avgrho = dms_tools.utils.AvgMutRate(counts['nrerrpost'], chartype)
            logger.info('Average per-site error rate in post-selection control:\n%s\n' % '\n'.join(['\t%g for characters that require %d nucleotide changes' % (avgrho[m], m) for m in nmutations]))
        else:
            raise ValueError("Invalid error_model %s" % error_model)
        if not 0 < avgmu < 1:
            raise ValueError("Average mutation rate of %g is not > 0 and < 1. Perhaps there are no mutations or the error control has a higher mutation rate than the pre-selection library. Either one is an irresolvable problem." % avgmu)

        # set up error model
        if args['ratio_estimation']:
            logger.info('\nPreferences will be computed from enrichment ratios rather than MCMC. Each count will be augmented by a pseudocount of %g.\n' % args['ratio_estimation'])
            pystan_error_model = error_model
        else:
            logger.info('Compiling PyStan model...')
            pystan_error_model = pystan_error_model()
            logger.info('Completed compiling PyStan model.\n')

        # begin inferring preferences, using a multiprocessing pool to enable multiple CPUs
        if args['ncpus'] == -1:
            nprocesses = None # lets multiprocessing.Pool choose the number of processes
        else:
            nprocesses = args['ncpus']
            if nprocesses < 1:
                raise ValueError("Must specify at least one process")
        results = {} # pending / completed asynchronous result for each site
        retry = {} # arguments for the single retry still available to each site
        logged = {} # whether the final result for each site has been handled
        pi_means = {}
        pi_95credint = {}
        counts_r = {} # counts used for each site, kept for debugging dumps
        initialincreasetries = 4 # initial tries for MCMC
        retryincreasetries = 6 # retry tries for MCMC
        pool = multiprocessing.Pool(nprocesses)
        try:
            logger.info('Now beginning inference for each site...\n')
            for r in sites:
                rcounts = dict([(name, counts[name][r]) for name in counts.iterkeys()])
                rcounts = dms_tools.utils.AdjustErrorCounts(wts[r], rcounts, maxexcess=5) # avoid pathologically high error counts, which shouldn't happen if actual data conforms to analysis assumptions
                counts_r[r] = rcounts

                # number of positions at which each character differs from wildtype,
                # and how many characters there are with each such number
                ndiffs = dict([(x, len([i for i in range(charlength) if wts[r][i] != x[i]])) for x in countcharacters])
                nchars_with_m = dict([(m, 0) for m in range(charlength + 1)])
                for x in countcharacters:
                    nchars_with_m[ndiffs[x]] += 1

                # priors over the count characters; each should sum to one
                priors = {'mur_prior_params':{}}
                if error_model in ['different', 'same']:
                    priors['epsilonr_prior_params'] = {}
                if error_model == 'different':
                    priors['rhor_prior_params'] = {}
                for x in countcharacters:
                    m = ndiffs[x]
                    if m == 0: # wildtype
                        priors['mur_prior_params'][x] = 1.0 - avgmu
                    else:
                        priors['mur_prior_params'][x] = avgmu / float(len(countcharacters) - 1.0)
                    if error_model in ['different', 'same']:
                        priors['epsilonr_prior_params'][x] = avgepsilon[m] / float(nchars_with_m[m])
                    if error_model == 'different':
                        priors['rhor_prior_params'][x] = avgrho[m] / float(nchars_with_m[m])
                for (name, d) in priors.iteritems():
                    assert (not d) or (abs(sum(d.values()) - 1.0) < 1.0e-6), "Sum not close to one for %s: %g" % (name, sum(d.values()))

                if transform_to_aa:
                    # collapse codon-level wildtype, priors, and counts to amino acids
                    wts[r] = dms_tools.codon_to_aa[wts[r]]
                    for (name, d) in priors.items():
                        if not d:
                            continue
                        priors[name] = dms_tools.utils.SumCodonToAA(d, not args['excludestop'])
                        entrysum = sum(priors[name].values()) # rescale to sum to one
                        assert entrysum > 0, "All of sum was stop codons?"
                        priors[name] = dict([(aa, x / float(entrysum)) for (aa, x) in priors[name].iteritems()])
                    for (name, d) in rcounts.items():
                        rcounts[name] = dms_tools.utils.SumCodonToAA(d, not args['excludestop'])

                # scale the priors by the number of characters and the concentration parameter
                for (name, alpha) in [('mur_prior_params', 'mu_alpha'), ('epsilonr_prior_params', 'err_alpha'), ('rhor_prior_params', 'err_alpha')]:
                    if name in priors:
                        priors[name] = dict([(x, priors[name][x] * len(characters) * args[alpha]) for x in characters])
                priors['pir_prior_params'] = dict([(x, args['pi_alpha']) for x in characters])

                logged[r] = False
                if args['ratio_estimation']:
                    # no retry is registered for ratio estimation
                    results[r] = pool.apply_async(dms_tools.mcmc.InferSitePreferencesFromEnrichmentRatios, (characters, wts[r], pystan_error_model, rcounts, args['ratio_estimation']))
                else:
                    niter = {'none':2500, 'same':10000, 'different':20000}[error_model] # these tend to be good values to converge the first time
                    results[r] = pool.apply_async(dms_tools.mcmc.InferSitePreferences, (characters, wts[r], pystan_error_model, rcounts, priors, args['seed'], niter, initialincreasetries))
                    # the retry uses a different seed and more tries
                    retry[r] = (characters, wts[r], pystan_error_model, rcounts, priors, args['seed'] + 1, niter, retryincreasetries)

            # poll until the final result for every site has been handled
            while not all([logged[r] for r in sites]):
                time.sleep(1) # wait a bit before polling again
                for r in sites:
                    if logged[r] or not results[r].ready():
                        continue
                    logger.info('Getting results for site %s...' % r)
                    try:
                        (converged, pi, pi95, logstring) = results[r].get()
                    except Exception as exc:
                        # The attempt returned nothing, so build the summary from
                        # the exception rather than reusing another site's summary.
                        converged = False
                        logstring = 'Inference raised an exception: %r' % (exc,)
                        if r not in retry:
                            logger.error('Found an error even after a retry for site %s. Here is summary of attempt:\n%s\n' % (r, logstring))
                            _dump_debugging_counts(r, counts_r[r]) # information potentially useful for debugging
                            raise
                    if not converged and r in retry:
                        logger.warning('Problems with inference for site %s!!! Performing the retry. Here is summary of previous attempt:\n%s\n' % (r, logstring))
                        results[r] = pool.apply_async(dms_tools.mcmc.InferSitePreferences, retry[r])
                        del retry[r] # only one retry per site
                        continue
                    logged[r] = True
                    if not converged:
                        logger.error('Failed to converge for site %s. Here is a summary of the attempt:\n%s\n' % (r, logstring))
                        _dump_debugging_counts(r, counts_r[r]) # information potentially useful for debugging
                        raise RuntimeError("Preference inference MCMC failed to converge for site %s:\n%s" % (r, logstring))
                    logger.info('Finished inference for site %s. Here is the summary:\n%s\n' % (r, logstring))
                    pi_means[r] = pi
                    pi_95credint[r] = pi95
        finally:
            # stop the worker processes whether or not inference succeeded
            pool.terminate()
        logger.info('Now writing preferences to %s\n' % args['outfile'])
        dms_tools.file_io.WritePreferences(args['outfile'], sites, wts, pi_means, pi_95credint)
    except:
        logger.exception('Terminating %s at %s with ERROR' % (prog, time.asctime()))
        raise # do not let a failed run look like a success to the caller
    else:
        logger.info('Successful completion of %s at %s' % (prog, time.asctime()))
    finally:
        logging.shutdown()



if __name__ == "__main__":
    # entry point when this file is executed as a script
    main()
