#!python

"""Infers site-specific differential character preferences.

Written by Jesse Bloom."""


import sys
import os
import multiprocessing
import logging
import time
import dms_tools.utils
import dms_tools.parsearguments
import dms_tools.file_io
import dms_tools.mcmc


def _character_settings(chartype_arg, excludestop):
    """Translate the ``chartype`` command-line option into analysis settings.

    *chartype_arg* is the user-supplied character type (matched without regard
    to case) and *excludestop* says whether stop codons are excluded from the
    amino-acid alphabet.

    Returns the tuple *(chartype, transform_to_aa, charlength, countcharacters,
    characters)* where:

        * *chartype* is the character-type string handed to the counts-file
          reader and to the mutation-rate calculation.

        * *transform_to_aa* is *True* if codon counts and priors must be
          summed into amino acids before inference.

        * *charlength* is the number of positions compared between the
          wildtype character and each possible character.

        * *countcharacters* are the characters found in the counts files.

        * *characters* are the characters for which inference is performed.

    Raises *ValueError* if *chartype_arg* is not a recognized value.
    """
    key = chartype_arg.lower()
    if key == 'codon_to_aa':
        if excludestop:
            characters = dms_tools.aminoacids_nostop
        else:
            characters = dms_tools.aminoacids_withstop
        return ('codon', True, 3, dms_tools.codons, characters)
    if key == 'codon':
        # Codon counts files are read with the same 'codon' type string as in
        # the codon_to_aa case (this branch previously passed 'codons').
        # NOTE(review): confirm dms_tools does not treat 'codons' specially.
        return ('codon', False, 3, dms_tools.codons, dms_tools.codons)
    if key == 'dna':
        return ('DNA', False, 1, dms_tools.nts, dms_tools.nts)
    if key == 'aa':
        if excludestop:
            return ('aminoacids_nostop', False, 1,
                    dms_tools.aminoacids_nostop, dms_tools.aminoacids_nostop)
        return ('aminoacids_withstop', False, 1,
                dms_tools.aminoacids_withstop, dms_tools.aminoacids_withstop)
    raise ValueError("Invalid chartype of %s" % chartype_arg)


def _num_differences(wt, x, charlength):
    """Number of the first *charlength* positions at which *wt* and *x* differ."""
    return sum(1 for i in range(charlength) if wt[i] != x[i])


def main():
    """Main body of script.

    Parses the command-line arguments, reads the counts files, builds the
    prior parameters for each site, runs *dms_tools.mcmc.InferSiteDiffPrefs*
    for every site in a multiprocessing pool, and writes the results with
    *dms_tools.file_io.WriteDiffPrefs*. Progress is logged to the log file.

    Any exception raised during the analysis is written to the log file and
    then re-raised, so a failed run does not look like a successful one to
    the calling shell. The worker pool is always terminated on exit.
    """

    # Parse command line arguments
    parser = dms_tools.parsearguments.InferDiffPrefsParser()
    args = vars(parser.parse_args())
    if args['logfile'] == 'Base name of "outfile" with extension ".log"':
        # Placeholder value: derive the log name from outfile and start a fresh log.
        args['logfile'] = "%s.log" % os.path.splitext(args['outfile'])[0]
        if os.path.isfile(args['logfile']):
            os.remove(args['logfile'])
    prog = parser.prog

    # Set up to log everything to logfile.
    # Confusingly, pystan defines its own logger which I have NOT been able to silence.
    # So even though the logger set up in this script will only log to logfile, the
    # pystan one imported by both dms_tools.mcmc and the call to dms_tools.file_io.Versions()
    # will log to standard output. This is very confusing, but I have not been able to find
    # a better solution until I determine how to silence the pystan logger.
    versionstring = dms_tools.file_io.Versions()
    logging.shutdown()
    logger = logging.getLogger(prog)
    logger.setLevel(logging.INFO)
    logfile_handler = logging.FileHandler(args['logfile'])
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    # begin execution
    pool = None # created below; terminated in the finally clause if it exists
    try:
        logger.info('Beginning execution of %s in directory %s\n' % (prog, os.getcwd()))
        logger.info('Progress is being logged to %s\n' % args['logfile'])
        logger.info(versionstring)
        logger.info('Parsed the following arguments:\n%s\n' % '\n'.join(['\t%s = %s' % tup for tup in args.items()]))

        # read in the counts
        (chartype, transform_to_aa, charlength, countcharacters, characters) = _character_settings(args['chartype'], args['excludestop'])
        counts_files = {'nrstart':args['n_start'], 'nrs1':args['n_s1'], 'nrs2':args['n_s2']}
        if args['err'].lower() == 'none':
            pystan_error_model = dms_tools.mcmc.StanModelDiffPrefNoErr
            error_model = 'none'
        else:
            pystan_error_model = dms_tools.mcmc.StanModelDiffPrefSameErr
            error_model = 'same'
            counts_files['nrerr'] = args['err']
        (sites, wts, counts) = dms_tools.file_io.ReadMultipleDMSCountsFiles(list(counts_files.values()), chartype)
        if args['sites']:
            # non-strict subset: specifying every available site is legitimate
            if not (set(args['sites']) <= set(sites)):
                raise ValueError("The specified sites are not all present in the DMS counts files. The specified sites are:\n%s" % ' '.join(args['sites']))
            sites = args['sites']
            logger.info('We will perform inference ONLY for the following specified subset of sites:\n%s' % ' '.join(sites))
        logger.info('There are counts for %d sites, which are:\n%s\n' % (len(sites), '\n'.join(['\t%s (wildtype identity of %s)' % (r, wts[r]) for r in sites])))
        counts = dict([(name, counts[counts_files[name]]) for name in counts_files]) # make names count type rather than file

        # compute mutation rates for priors
        mutclasses = range(1, charlength + 1) # numbers of changes that make a character a mutation
        if error_model == 'none':
            startrates = dms_tools.utils.AvgMutRate(counts['nrstart'], chartype)
            avgfstart = sum([startrates[m] for m in mutclasses])
            logger.info('Average per-site mutation frequency in starting library: %g' % avgfstart)
        elif error_model == 'same':
            avgxi = dms_tools.utils.AvgMutRate(counts['nrerr'], chartype)
            logger.info('Average per-site error rate:\n%s\n' % '\n'.join(['\t%g for characters that require %d nucleotide changes' % (avgxi[m], m) for m in mutclasses]))
            startrates = dms_tools.utils.AvgMutRate(counts['nrstart'], chartype)
            # subtract the error rate from the apparent mutation rate of the starting library
            avgfstart = sum([startrates[m] for m in mutclasses]) - sum([avgxi[m] for m in mutclasses])
            logger.info('Average per-site mutation frequency in starting library: %g' % avgfstart)
        else:
            raise ValueError("Invalid error_model %s" % error_model)
        if not 0 < avgfstart < 1:
            raise ValueError("Average mutation frequency in starting library of %g is not > 0 and < 1. Perhaps there are no mutations or the error control has a higher mutation rate than the starting library. Either one is an irresolvable problem." % avgfstart)

        # begin inferring preferences, using a multiprocessing pool to enable multiple CPUs
        logger.info('Compiling PyStan model...')
        pystan_error_model = pystan_error_model()
        logger.info('Completed compiling PyStan model.\n')
        if args['ncpus'] == -1:
            nprocesses = None # let multiprocessing choose the number of processes
        else:
            nprocesses = args['ncpus']
            if nprocesses <= 0:
                raise ValueError("Must specify at least one process")
        pool = multiprocessing.Pool(nprocesses)
        results = {}
        logged = {}
        deltapi_means = {}
        pr_deltapi_gt0 = {}
        pr_deltapi_lt0 = {}
        niter = {'none':1500, 'same':8000}[error_model] # iteration count passed to the MCMC
        logger.info('Now beginning inference for each site...\n')
        for r in sites:
            rcounts = dict([(name, counts[name][r]) for name in counts])
            rcounts = dms_tools.utils.AdjustErrorCounts(wts[r], rcounts, maxexcess=5) # avoid pathologically high error counts, which shouldn't happen if actual data conforms to analysis assumptions
            # number of differences from wildtype for each character, and how many characters share each number
            nmuts = dict([(x, _num_differences(wts[r], x, charlength)) for x in countcharacters])
            nchars_with_m = dict([(m, 0) for m in range(charlength + 1)])
            for x in countcharacters:
                nchars_with_m[nmuts[x]] += 1
            priors = {'frstart_prior_params':{}, 'xir_prior_params':{}}
            for x in countcharacters:
                m = nmuts[x]
                if m == 0: # wildtype
                    priors['frstart_prior_params'][x] = 1.0 - avgfstart
                else:
                    priors['frstart_prior_params'][x] = avgfstart / float(len(countcharacters) - 1.0)
                if error_model == 'same':
                    priors['xir_prior_params'][x] = avgxi[m] / float(nchars_with_m[m])
            for (name, d) in priors.items():
                # internal sanity check: each non-empty prior vector sums to one
                assert (not d) or (abs(sum(d.values()) - 1.0) < 1.0e-6), "Sum not close to one for %s: %g" % (name, sum(d.values()))
            if transform_to_aa:
                # from here on (including the output file) the wildtype is the amino acid
                wts[r] = dms_tools.codon_to_aa[wts[r]]
                for (name, d) in list(priors.items()):
                    if not d:
                        continue
                    priors[name] = dms_tools.utils.SumCodonToAA(d, not args['excludestop'])
                    entrysum = sum(priors[name].values()) # rescale to sum to one
                    assert entrysum > 0, "All of sum was stop codons?"
                    priors[name] = dict([(aa, x / float(entrysum)) for (aa, x) in priors[name].items()])
                for (name, d) in list(rcounts.items()):
                    rcounts[name] = dms_tools.utils.SumCodonToAA(d, not args['excludestop'])
            # scale each non-empty prior vector by the number of characters and its concentration argument
            for (name, alpha) in [('frstart_prior_params', 'alpha_start'), ('xir_prior_params', 'alpha_err')]:
                if priors[name]:
                    priors[name] = dict([(x, priors[name][x] * len(characters) * args[alpha]) for x in characters])
            priors['pirs1_prior_params'] = dict([(x, args['alpha_pis1']) for x in characters])
            priors['deltapi_concentration'] = args['alpha_deltapi']
            logged[r] = False
            results[r] = pool.apply_async(dms_tools.mcmc.InferSiteDiffPrefs, (characters, wts[r], pystan_error_model, rcounts, priors, args['seed'], niter))

        # poll the asynchronous results, logging each site once as it finishes
        while not all([logged[r] for r in sites]):
            time.sleep(1) # wait a bit before polling again
            for r in sites:
                if results[r].ready() and not logged[r]:
                    logged[r] = True
                    logger.info('Getting results for site %s...' % r)
                    (converged, deltapi, pr_gt0, pr_lt0, logstring) = results[r].get()
                    logger.info('Finished inference for site %s. Here is the summary:\n%s\n' % (r, logstring))
                    if not converged:
                        raise RuntimeError("Differential preference inference MCMC failed to converge for site %s:\n%s" % (r, logstring))
                    deltapi_means[r] = deltapi
                    pr_deltapi_gt0[r] = pr_gt0
                    pr_deltapi_lt0[r] = pr_lt0
        logger.info('Now writing differential preferences to %s\n' % args['outfile'])
        dms_tools.file_io.WriteDiffPrefs(args['outfile'], sites, wts, deltapi_means, pr_deltapi_lt0, pr_deltapi_gt0)
    except BaseException:
        # Record the traceback in the log file, then propagate the error so the
        # process does not exit as though it had succeeded.
        logger.exception('Terminating %s at %s with ERROR' % (prog, time.asctime()))
        raise
    else:
        logger.info('Successful completion of %s at %s' % (prog, time.asctime()))
    finally:
        if pool is not None:
            pool.terminate() # stop worker processes on success and on failure alike
        logging.shutdown()



# Script entry point: only run the analysis when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
