""" Code to derive model parameters
    
Context : SRP
Module  : SRPTNGPAOLOParamFit
Author  : Stefano Covino
Date    : 25/09/2013
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/utenti/covino
Purpose : Derive TNG polarimetric model parameters.

Usage   : SRPTNGPAOLOParamFit [-h] [-a n k] [-c n k detoff q0 u0 v0] [-d detoff]
                           [-f file] -i file [-o file] [-s syst] [-v] [--version]
                           [-z q0 u0 v0]
            -a Aluminium refractive and extinction coefficient multiplicative factors
            -c Fit choices for n, k, detoff, q0, u0, v0 (values > 0 mark free parameters)
            -d Detector offset (deg)
            -f Input FITS Stokes parameter file (supplies starting n, k, detoff, q0, u0, v0)
            -i Input normalized instrumental Stokes parameter file
            -o Output fit parameter FITS file
            -s Systematic error added in quadrature to the Q and U errors
            -v Fully describe operations
            -z Normalized instrumental polarization

    
History : (01/03/2012) First version.
        : (27/09/2012) Better output.
        : (29/11/2012) Correct sign for position angle.
        : (04/12/2012) Possibility to choose the fit parameters.
        : (31/03/2013) Better help message.
        : (24/07/2013) Total intensity in polarization computation.
        : (25/09/2013) Update for the latest atpy version.
"""

# Script version string, reported by the --version command-line option.
__version__ = '0.3.1'


import argparse
import atpy, numpy
from scipy.optimize import minimize
import SRP.stats as stats
from SRP.SRPStatistics.GenFitPars import GenFitPars
from SRP.SRPPolarimetry.AluminiumRefractiveIndex import AluminiumRefractiveIndex
import SRP.SRPTNG.PAOLO as STP
from SRP.SRPTNG.PAOLO.TNGMuellerMatrix import TNGMuellerMatrix
from SRP.SRPTNG.PAOLO.TNGMuellerMatrixPlate2 import TNGMuellerMatrixPlate2
from SRP.SRPTNG.PAOLO.TNGMuellerMatrixPlate4 import TNGMuellerMatrixPlate4
from SRP.SRPTNG.PAOLO.StokesOffsetVector import StokesOffsetVector

    


# Command-line interface. The -c help text now states the real default
# (v0 is not fitted by default: last choice is 0, not 1).
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--alum", action="store", nargs=2, type=float, help="Aluminium refractive and extinction coefficient multiplicative factors", metavar=('n','k'),default=(0.9,0.9))
parser.add_argument("-c", "--choice", action="store", nargs=6, type=int, help="Choices for fitting: n, k, detoff, q0, u0, v0 [default=(1 -1 1 1 1 0)]", metavar=('n','k','detoff','q0','u0','v0'), default=(1,-1,1,1,1,0))
parser.add_argument("-d", "--detoff", action="store", type=float, help="Detector offset (deg)", metavar='detoff',default=0.5)
parser.add_argument("-f", "--fitsfile", action="store", help="Input FITS Stokes parameter file", metavar='file')
parser.add_argument("-i", "--instrpolfile", action="store", help="Input normalized instrumental Stokes parameter file", metavar='file', required=True)
parser.add_argument("-o", "--outfile", action="store", help="Output fit parameter FITS file", metavar='file')
parser.add_argument("-s", "--syst", action="store", type=float, default=0.0, help="Systematic error to be added", metavar='syst')
parser.add_argument("-v", "--verbose", action="store_true", help="Fully describe operations")
parser.add_argument("--version", action="version", version=__version__)
parser.add_argument("-z", "--zero", action="store", nargs=3, type=float, help="Normalized instrumental polarization", metavar=('q0','u0','v0'),default=(0.,0.,0.))
options = parser.parse_args()


#
try:
    dtp = atpy.Table(options.instrpolfile, type='fits')
except IOError:
    parser.error("Invalid input instrumental FITS Stokes parameter file.")
if options.verbose:
    print "Input instrumental Stokes parameter file: %s" % options.instrpolfile
#
lambd2 = False
lambd4 = False
for i in dtp:
    try:
        if i[STP.POLSLIDE].upper().find(STP.LAMBDA2) >= 0:
            lambd2 = True
            break
        elif i[STP.POLSLIDE].upper().find(STP.LAMBDA4) >= 0:
            lambd4 = True
            break
    except IndexError:
        parser.error("Table %s format not corrected." % options.instrpolfile)
#        
if options.fitsfile:
    try:
        dt = atpy.Table(options.fitsfile, type='fits')
    except IOError:
        parser.error("Invalid input Stokes parameter file.")
    if options.verbose:
        print "Input FITS Stokes parameter file: %s" % options.fitsfile
    #
    try:
        nn = dt[STP.N][0]
        kk = dt[STP.K][0]
        offoff = dt[STP.DETOFF][0]
        q0q0 = dt[STP.Q0][0]
        u0u0 = dt[STP.U0][0]
        v0v0 = dt[STP.V0][0]
    except IndexError:
        parser.errpr("FITS Stokes parameter file without the expected entries.")
else:
    nn = options.alum[0]
    kk = options.alum[1]
    offoff = options.detoff
    q0q0 = options.zero[0]
    u0u0 = options.zero[1]
    v0v0 = options.zero[2]
#
# Reject unphysical starting values before fitting.
if any(c > 1 for c in (q0q0, u0u0, v0v0)) or q0q0**2 + u0u0**2 + v0v0**2 > 1:
    parser.error("Unrealistic instrumental polarization.")
if any(f <= 0 for f in (nn, kk)):
    parser.error("Multiplicative factors must be positive.")
if options.syst < 0.0:
    parser.error("Systematic error must be positive.")
#
# Add the systematic error in quadrature to the Q and U uncertainties.
for errcol in (STP.eQ, STP.eU):
    dtp[errcol] = numpy.sqrt(dtp[errcol]**2 + options.syst**2)
#
if options.verbose:
    print "Refractive index multiplicative factor      : %.3f" % nn
    print "Extinction coefficient multiplicative factor: %.3f" % kk
    print "Detector offset (deg)                       : %.2f" % offoff
    print "Instrumental polarization                   : Q0=%.3g, U0=%.3g, V0=%.3g" % (q0q0, u0u0, v0v0)
    if options.syst > 0:
        print "Systematic error                            : %.3g" % options.syst
    print "Fit rules                                   : ", str(options.choice).strip('[]')
#
def func (vars,pars,args):
    """Model Stokes column for a row observed without a retarder plate.

    vars : one row's values of STP.WAVE, STP.CalQ, STP.CalU, STP.PARANG and
           STP.POSANG, in that order (as built by chi2).
    pars, args : parameters and (choice, value) pairs expanded by GenFitPars
                 into (n factor, k factor, detector offset, Q0, U0).
    Returns TNGMuellerMatrix(...) * (1, q, u, 0)^T plus the instrumental
    offset vector; chi2 reads it as s[row, 0].
    """
    full = GenFitPars(pars,args)
    wave, cq, cu, pa, p = vars[:5]
    # Wavelength-dependent aluminium n and k, each scaled by its fitted factor.
    nfunc, kfunc = AluminiumRefractiveIndex()
    nval = nfunc(wave)
    kval = kfunc(wave)
    incoming = numpy.matrix([1.0, cq, cu, 0.0]).T
    # The effective offset is the fitted detector offset minus the row's
    # position angle.
    mueller = TNGMuellerMatrix(pa, full[0]*nval, full[1]*kval, full[2] - p)
    return mueller*incoming + StokesOffsetVector(full[3], full[4], 0.0)
#
def func2 (vars,pars,args):
    """Model Stokes column for a row observed with the lambda/2 plate.

    vars : one row's values of STP.WAVE, STP.CalQ, STP.CalU, STP.PARANG,
           STP.POSANG and STP.ROTLAM2, in that order (as built by chi2).
    pars, args : parameters and (choice, value) pairs expanded by GenFitPars
                 into (n factor, k factor, detector offset, Q0, U0).
    Returns TNGMuellerMatrixPlate2(...) * (1, q, u, 0)^T plus the
    instrumental offset vector; chi2 reads it as s[row, 0].
    """
    full = GenFitPars(pars,args)
    wave, cq, cu, pa, p, rot = vars[:6]
    # Wavelength-dependent aluminium n and k, each scaled by its fitted factor.
    nfunc, kfunc = AluminiumRefractiveIndex()
    nval = nfunc(wave)
    kval = kfunc(wave)
    incoming = numpy.matrix([1.0, cq, cu, 0.0]).T
    # The effective offset is the fitted detector offset minus the row's
    # position angle; rot is the plate rotation read from the table.
    mueller = TNGMuellerMatrixPlate2(pa, full[0]*nval, full[1]*kval, rot, full[2] - p)
    return mueller*incoming + StokesOffsetVector(full[3], full[4], 0.0)
#
def func4 (vars,pars,args):
    """Model Stokes column for a row observed with the lambda/4 plate.

    vars : one row's values of STP.WAVE, STP.CalV, STP.PARANG, STP.POSANG and
           STP.ROTLAM4, in that order (as built by chi2).
    pars, args : parameters and (choice, value) pairs expanded by GenFitPars
                 into (n factor, k factor, detector offset, V0).
    Returns TNGMuellerMatrixPlate4(...) * (1, 0, 0, v)^T plus the
    instrumental offset vector (only V0 is non-zero); chi2 reads it as
    s[row, 0].
    """
    pari = GenFitPars(pars,args)
    wave = vars[0]
    cv = vars[1]
    pa = vars[2]
    p = vars[3]
    rot = vars[4]
    # Was `parsi[0]`: an undefined name, so every lambda/4 fit raised NameError.
    fn = pari[0]
    fk = pari[1]
    # Effective offset: fitted detector offset minus the row's position angle.
    off = -p+pari[2]
    v0 = pari[3]
    #
    # Wavelength-dependent aluminium n and k, scaled by the fitted factors.
    nf, kf = AluminiumRefractiveIndex()
    n = nf(wave)
    k = kf(wave)
    sto = [1.0, 0.0, 0.0, cv]
    Stokes = numpy.matrix(sto).transpose()
    s = TNGMuellerMatrixPlate4(pa,fn*n,fk*k,rot,off)*Stokes+StokesOffsetVector(0.0,0.0,v0)
    return s
#
def chi2 (pars, args):
    """Total chi-square of the instrumental model over every row of dtp.

    pars : parameter vector handed over by scipy.optimize.minimize (or the
           expanded parameters after the fit).
    args : the (choice, value) pairs forwarded to func/func2/func4.
    The model is chosen by the global lambd2/lambd4 flags. A missing table
    column (IndexError) aborts through parser.error.
    """
    chiq = chiu = chiv = 0.0
    for row in dtp:
        try:
            if lambd2:
                s = func2((row[STP.WAVE], row[STP.CalQ], row[STP.CalU],
                           row[STP.PARANG], row[STP.POSANG], row[STP.ROTLAM2]),
                          pars, args)
            elif lambd4:
                s = func4((row[STP.WAVE], row[STP.CalV], row[STP.PARANG],
                           row[STP.POSANG], row[STP.ROTLAM4]),
                          pars, args)
                # NOTE(review): model V/I (s[3,0]/s[0,0]) is compared with the
                # table's Q column and eQ error -- confirm that is where the
                # lambda/4 measurement is stored.
                chiv += ((s[3,0]/s[0,0] - row[STP.Q])/row[STP.eQ])**2
                continue
            else:
                s = func((row[STP.WAVE], row[STP.CalQ], row[STP.CalU],
                          row[STP.PARANG], row[STP.POSANG]),
                         pars, args)
            # Normalized model Q/I and U/I against the measured values.
            chiq += ((s[1,0]/s[0,0] - row[STP.Q])/row[STP.eQ])**2
            chiu += ((s[2,0]/s[0,0] - row[STP.U])/row[STP.eU])**2
        except IndexError:
            parser.error("Table %s format not corrected." % options.instrpolfile)
    return chiq + chiu + chiv
#
# Starting vector and (choice, value) pairs for GenFitPars. With the lambda/4
# plate the only instrumental-polarization parameter is V0 (choice[5]);
# otherwise Q0 and U0 (choice[3], choice[4]).
if lambd4:
    inizio = [nn,kk,offoff,v0v0]
    ags = [(options.choice[0],inizio[0]),(options.choice[1],inizio[1]),(options.choice[2],inizio[2]),(options.choice[5],inizio[3])]
else:
    inizio = [nn,kk,offoff,q0q0,u0u0]
    ags = [(options.choice[0],inizio[0]),(options.choice[1],inizio[1]),(options.choice[2],inizio[2]),(options.choice[3],inizio[3]),(options.choice[4],inizio[4])]
#
# Degrees of freedom: chi2 adds one term per row for lambda/4 data and two
# (Q and U) otherwise; choices > 0 count as free parameters. Checked before
# the minimization so a hopeless configuration fails fast instead of
# dividing by zero (or reporting a negative dof) afterwards.
npr = sum(iii[0] > 0 for iii in ags)
if lambd4:
    ndf = len(dtp)-npr
else:
    ndf = 2*len(dtp)-npr
if ndf <= 0:
    parser.error("Too few measurements for the number of fitted parameters (dof=%d)." % ndf)
#
parst = minimize (chi2, inizio, args=(ags,), method='Nelder-Mead', options={'disp':False}, tol=1e-4)
pars = GenFitPars(parst.x,ags)
tchi = chi2(pars,ags)
#
if options.verbose:
    print "Fit reduced CHI2, dof, CHI2: %.2f %d %.2f" % ((tchi/ndf), ndf, tchi)
    print "Fit probability            : %.2f%%" % (100*stats.chisqprob(float(tchi),ndf))
#
if options.verbose:
    print "Fit refractive index multiplicative factor      : %.3f" % pars[0]
    print "Fit extinction coefficient multiplicative factor: %.3f" % pars[1]
    print "Fit Detector offset (deg)                       : %.2f" % pars[2]
    if lambd4:
        print "Fit instrumental polarization                   : V0=%.3g" % (pars[3])
    else:
        print "Fit instrumental polarization                   : Q0=%.3g, U0=%.3g" % (pars[3], pars[4])
else:
    if lambd4:
        print "%.3f %.3f %.2f %.2g %.3f" % (pars[0], pars[1], pars[2], pars[3], (tchi/ndf))
    else:
        print "%.3f %.3f %.2f %.2g %.2g %.3f" % (pars[0], pars[1], pars[2], pars[3], pars[4], (tchi/ndf))
#
if options.outfile:
    tout = atpy.Table(name=options.outfile)
    tout.add_column(STP.N,numpy.array([pars[0]]))
    tout.add_column(STP.K,numpy.array([pars[1]]))
    tout.add_column(STP.DETOFF,numpy.array([pars[2]]))
    if lambd4:
        tout.add_column(STP.Q0,numpy.array([0.0]))
        tout.add_column(STP.U0,numpy.array([0.0]))
        tout.add_column(STP.V0,numpy.array(pars[3]))
    else:
        tout.add_column(STP.Q0,numpy.array([pars[3]]))
        tout.add_column(STP.U0,numpy.array([pars[4]]))
        tout.add_column(STP.V0,numpy.array([0.0]))
    tout.add_column(STP.CHI2,numpy.array([tchi/ndf]))    
    tout.write(options.outfile,type='fits',overwrite=True)
    if options.verbose:
        print "Fit parameters saved in file %s" % options.outfile
    else:
        print "%s" % options.outfile   
#
