#!/usr/bin/env python
from __future__ import print_function
import logging
import argparse
import numpy as np
import numpy.random as rand

from obliquity import kappa_inference as ki
from obliquity.distributions import Cosi_Distribution_FromH5
from simpledist import distributions as dists

def _load_cosi_samples(files, nsamples):
    """Collect ``nsamples`` Cos(I) samples from each .h5 file in ``files``.

    Each file is opened with ``Cosi_Distribution_FromH5``.  If the stored
    sample array does not have exactly ``nsamples`` entries, ``nsamples``
    values are drawn from it at random (with replacement); otherwise the
    stored samples are used as they are.  A file that raises any ordinary
    exception while being loaded is logged and skipped.

    Parameters
    ----------
    files : sequence of str
        Paths of the .h5 files to read.
    nsamples : int
        Number of samples to keep per file.

    Returns
    -------
    numpy.ndarray
        Array with ``nsamples`` columns and one row per file that loaded
        successfully (possibly zero rows).
    """
    all_samples = np.zeros((len(files), nsamples))

    bad_inds = []
    for i, f in enumerate(files):
        print('{} of {}'.format(i+1, len(files)))
        try:
            dist = Cosi_Distribution_FromH5(f)
            if len(dist.samples) != nsamples:
                # Bring every distribution to a common length by drawing
                # random indices (with replacement) into its samples.
                inds = rand.randint(len(dist.samples), size=nsamples)
                all_samples[i, :] = dist.samples[inds]
            else:
                all_samples[i, :] = dist.samples
        except Exception as err:
            # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
            # and SystemExit propagate; the reason is logged so a skipped
            # file can be diagnosed.
            logging.warning('error with %s; skipping. (%s: %s)',
                            f, type(err).__name__, err)
            bad_inds.append(i)

    # Drop the rows reserved for files that could not be loaded.
    return np.delete(all_samples, bad_inds, axis=0)


def main():
    """Command-line entry point: compute and save a kappa posterior.

    Reads a text file listing .h5 files with Cos(I) distributions, gathers
    ``--nsamples`` samples from each, evaluates the kappa posterior with
    ``ki.kappa_posterior`` using ``--kmin``, ``--kmax`` and ``--npoints``,
    saves the resulting distribution to an .h5 file and prints the values
    of its ``ppf`` at 0.5 and at 0.5 -/+ 0.683/2.

    Exits through ``parser.error`` (status 2) when none of the listed files
    could be loaded.
    """
    parser = argparse.ArgumentParser(description='Calculate kappa posterior based on Cos(I) samples provided in file')

    parser.add_argument('filename',help='file with list of .h5 files containing Cos(I) distributions (one filename per line)')
    parser.add_argument('--nsamples',default=1000,type=int,help='number of samples to draw from each distribution')
    parser.add_argument('--kmin',default=0.01,type=float,help='minimum value of kappa to calculate.')
    parser.add_argument('--kmax',default=200,type=float,help='maximum value of kappa to calculate.')
    parser.add_argument('--npoints',default=300,type=int,help='number of kappa points to calculate posterior')
    parser.add_argument('--savefile',default=None,help='.h5 file to which to save distribution.  By default, just the filename argument with _kappa.h5 appended.')

    args = parser.parse_args()

    # np.loadtxt returns a 0-d array when the list holds a single filename;
    # atleast_1d keeps len() and iteration working in that case.
    files = np.atleast_1d(np.loadtxt(args.filename, dtype=str))

    all_samples = _load_cosi_samples(files, args.nsamples)

    if len(all_samples) == 0:
        # Nothing to infer from: stop with a clear message instead of
        # handing an empty sample array to the posterior calculation.
        parser.error('no usable Cos(I) distributions could be loaded '
                     'from {}'.format(args.filename))

    ks, post = ki.kappa_posterior(all_samples, kmin=args.kmin,
                                  kmax=args.kmax, npts=args.npoints)

    dist = dists.Empirical_Distribution(ks, post, name='kappa')

    if args.savefile is None:
        savefile = args.filename + '_kappa.h5'
    else:
        savefile = args.savefile

    dist.save_hdf(savefile)

    # Median and the bounds of the central 68.3% probability interval.
    med = dist.ppf(0.5)
    lo = dist.ppf(0.5 - .683/2)
    hi = dist.ppf(0.5 + .683/2)

    print('kappa posterior: {:.1f} ({:.1f},{:.1f})'.format(med,lo,hi))
    print('posterior saved to {}'.format(savefile))

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__=='__main__':
    main()
