#!/usr/bin/env python
from __future__ import print_function
from obliquity.distributions import Cosi_Distribution
from simpledist import distributions as dists

import argparse

import matplotlib.pyplot as plt

def main(argv=None):
    """Build a cos(i) distribution from command-line inputs and save it.

    Parses the radius (-R), rotation period (-P) and vsini (-V) inputs,
    constructs a ``Cosi_Distribution`` from them, writes it to the output
    file with ``save_hdf`` and, if ``--plotfile`` is given, saves the
    summary plot there.

    Parameters
    ----------
    argv : list of str, optional
        Argument list to parse (without the program name).  Defaults to
        ``None``, in which case argparse reads ``sys.argv[1:]``.

    Raises
    ------
    SystemExit
        Raised by argparse (exit status 2) on a usage error, including a
        missing ``-R`` or a non-numeric value for a numeric option.
    """
    parser = argparse.ArgumentParser(description='Calculate cosi distribution from provided Rstar, Prot, Vsini; write distribution to file')

    parser.add_argument('filename',help='output filename (should end in .h5)')
    parser.add_argument('-R',nargs='+',help='Radius distribution (can be .h5 file with saved distribution, R dR, or R dR- dR+')
    parser.add_argument('-P',nargs='+',help='Prot distribution (can be .h5 file with saved distribution, P dP, or P dP- dP+')
    parser.add_argument('-V',nargs='+',help='Vsini distribution (can be .h5 file with saved distribution, V dV, or V dV- dV+')

    parser.add_argument('--nsamples',type=int,default=1000,help='number of cosi samples to save')
    parser.add_argument('--N_veq_samples',type=int,default=1000,help='number of samples to use to sample v_eq distribution')
    parser.add_argument('--veq_bandwidth',type=float,default=0.1,help='bandwidth for v_eq distribution KDE')
    # type=float so that values given on the command line arrive as numbers,
    # like the defaults, instead of as raw strings.  argparse only applies
    # `type` to string defaults, so the numeric defaults are left untouched.
    parser.add_argument('--alpha',type=float,default=0.23,help='differential rotation parameter')
    parser.add_argument('--l0',type=float,default=20,help='mean active lattitude for spots')
    parser.add_argument('--sigl',type=float,default=20,help='standard deviation on spot active lattitude distribution')
    parser.add_argument('--vsini_corrected',action='store_true',help='if this is set, then vsini will not be divided by 1 - (alpha/2)')
    parser.add_argument('--vsini_upperlim',type=float,default=None,help='set if vsini measurement is just an upper limit.')
    parser.add_argument('--plotfile',default=None,help='file in which to save cosI summary plot')
    parser.add_argument('--plot_title',default='',help='title of plot')

    args = parser.parse_args(argv)

    # Without -R, args.R is None and the len() check below would raise a
    # bare TypeError; report it as a normal usage error instead.
    if args.R is None:
        parser.error('a radius distribution must be provided with -R')

    if args.vsini_upperlim is None:
        vsini_dist = args.V
    else:
        # Only an upper limit is known: use a flat distribution on
        # [0, upperlim] (any -V values are ignored in this case).
        upperlim = args.vsini_upperlim
        if not args.vsini_corrected:
            upperlim = upperlim / (1 - args.alpha/2)
        # NOTE(review): vsini_corrected is still forwarded below; confirm
        # Cosi_Distribution does not rescale this box distribution again.
        vsini_dist = dists.Box_Distribution(0,upperlim)

    # A lone -R value (e.g. an .h5 filename) is passed as a plain string
    # rather than a one-element list.
    # NOTE(review): -P and -V are not unwrapped this way, so their
    # single-argument form reaches Cosi_Distribution as a one-element
    # list -- confirm that is what it expects.
    if len(args.R)==1:
        args.R = args.R[0]

    cosi_dist = Cosi_Distribution(args.R,args.P,vsini_dist,nsamples=args.nsamples,
                                  N_veq_samples=args.N_veq_samples,
                                  veq_bandwidth=args.veq_bandwidth,
                                  vsini_corrected=args.vsini_corrected,
                                  alpha=args.alpha,l0=args.l0,sigl=args.sigl)
    cosi_dist.save_hdf(args.filename)

    if args.plotfile is not None:
        cosi_dist.summary_plot(title=args.plot_title)
        plt.savefig(args.plotfile)

# Run the command-line entry point only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()
