""" Code to derive instrumental Stokes parameters
    
Context : SRP
Module  : SRPTNGPAOLOInstrStokes
Author  : Stefano Covino
Date    : 09/04/2013
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/utenti/covino
Purpose : Derive instrumental Q and U Stokes parameters

Usage   : SRPTNGPAOLOInstrStokes [-h] [-a] [-c Q eQ U eU V eV] -f file -o file
            [-v] [--version] [-w wave]
            -a Append data to output
            -c Calibration Q, U and V values
            -f Input FITS photometry/spectroscopy file
            -o Output FITS file
            -v Fully describe operations
            -w Observation wavelength (micron)
    
History : (29/02/2012) First version.
        : (02/08/2012) Minor bug fix.
        : (27/09/2012) MJD in output.
        : (25/03/2013) Analyze spectra too.
        : (09/04/2013) Bug correction for photometric files.

"""

__version__ = '1.1.0'


import argparse, math, os, sys, warnings
import atpy, ephem, numpy
import SRP.SRPTNG as ST
import SRP.SRPTNG.PAOLO as STP
from SRP.SRPTNG.PAOLO.AverHourAngle import AverHourAngle
from SRP.SRPTNG.PAOLO.AverParallacticAngle import AverParallacticAngle
from SRP.SRPTNG.GetObj import GetObj
from SRP.SRPTNG.GetTNGSite import GetTNGSite
from SRP.SRPPhotometry.Mag2Counts import Mag2Counts



# Command-line interface. 'store' is argparse's default action, so it is not
# spelled out for the value-taking options.
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--append", action="store_true",
                    help="Append data to output")
parser.add_argument("-c", "--calibqu", type=float, nargs=6,
                    metavar=('Q', 'eQ', 'U', 'eU', 'V', 'eV'),
                    help="Calibration Q, U and V values")
parser.add_argument("-f", "--fitsphotfile", required=True, metavar='file',
                    help="Input FITS photometry/spectroscopy file")
parser.add_argument("-o", "--outfile", required=True, metavar='file',
                    help="Output FITS file")
parser.add_argument("-v", "--verbose", action="store_true",
                    help="Fully describe operations")
parser.add_argument("--version", action="version", version=__version__)
parser.add_argument("-w", "--wave", type=float, metavar='wave',
                    help="Observation wavelength (micron)")
options = parser.parse_args()


#
try:
    tphot = atpy.Table(options.fitsphotfile, type='fits')
except IOError:
    parser.error("Invalid input FITS file.")
if options.verbose:
    print "Input FITS photometry file: %s" % options.fitsphotfile
#
try:
    sequence = tphot.keywords[STP.SEQUENCE].split()
    filter = tphot.keywords[STP.FILTER]
    ra = tphot.keywords[STP.RA]
    dec = tphot.keywords[STP.DEC]
    date = tphot.keywords[STP.DATE]
    time = tphot.keywords[STP.TIME]
    expt = tphot.keywords[STP.EXPTIME]
    posang = tphot.keywords[STP.POSANG]
    object = tphot.keywords[STP.OBJECT]
    pstop = tphot.keywords[STP.PSTOP]
    polslide = tphot.keywords[STP.POLSLIDE]
    rot4 = tphot.keywords[STP.ROTLAM4]
    rot2 = tphot.keywords[STP.ROTLAM2]
    mjd = tphot.keywords[STP.MJD]
except Exception:
    parser.error("Invalid data in FITS table.")    
#
# Collect, for every position listed in `sequence`, the intensities Fl and
# their errors eFl (one array per position, in sequence order).
# - Spectroscopic tables provide STP.Flux/STP.eFlux columns; then `nophot`
#   is set.
# - Photometric tables provide STP.Mag/STP.eMag columns, converted with
#   Mag2Counts; then `nospec` stays set.
# Having both flags set means the columns are inconsistent, which is an error.
nospec = False
nophot = False
Fl = []
eFl = []
#
for ii in sequence:
    try:
        fl = tphot[STP.Flux+'_'+ii]
        efl = tphot[STP.eFlux+'_'+ii]
    except ValueError:
        # Missing column: the table lookup is expected to raise ValueError.
        nospec = True
    else:
        # Append only when both the value and the error column were read.
        Fl.append(fl)
        eFl.append(efl)
        nophot = True

if nospec:
    # Start from empty lists so that a flux column read above for some
    # position cannot shift the indices of the magnitude-derived intensities.
    Fl = []
    eFl = []
    for ii in sequence:
        try:
            fl, efl = Mag2Counts(tphot[STP.Mag+'_'+ii], tphot[STP.eMag+'_'+ii])
        except ValueError:
            # Neither flux nor magnitude columns for this position. Stop here
            # so that nothing undefined or stale is appended; `ii` names the
            # offending position in the error below.
            nophot = True
            break
        Fl.append(numpy.array(fl))
        eFl.append(numpy.array(efl))
#
if nospec and nophot:
    parser.error("Invalid columns %s,%s in FITS table" % (STP.Mag+'_'+ii, STP.eMag+'_'+ii))
# Q and U are computed from the intensities at positions 0-3.
if len(Fl) < 4:
    parser.error("At least four positions are needed in the observing sequence.")
#
# Normalised Stokes parameters from the intensities at the four positions
# (indices 0/2 give Q, 1/3 give U):
#     Q = (F0 - F2) / (F0 + F2),    U = (F1 - F3) / (F1 + F3)
def _norm_diff(fa, fb, efa, efb):
    """Return the ratio (fa - fb)/(fa + fb) and its propagated error.

    The error is |ratio| * sqrt(V/D**2 + V/S**2), with D = fa - fb,
    S = fa + fb and V = efa**2 + efb**2. This is the expression the script
    has always used. It is evaluated here in the form
    sqrt(V) * sqrt(D**2 + S**2) / S**2, which is algebraically identical
    whenever D and S are non-zero.

    Unlike the original form, it stays finite when D == 0 (ratio exactly 0).
    The original form produced 0*inf = NaN there, so such rows were discarded.
    A zero denominator still yields inf/NaN. Numpy floating-point warnings are
    suppressed, and the caller filters those rows out.
    """
    with numpy.errstate(all='ignore'):
        diff = fa - fb
        tot = fa + fb
        ratio = diff / tot
        err = numpy.sqrt(efa**2 + efb**2) * numpy.sqrt(diff**2 + tot**2) / tot**2
    return ratio, err

Q, eQ = _norm_diff(Fl[0], Fl[2], eFl[0], eFl[2])
U, eU = _norm_diff(Fl[1], Fl[3], eFl[1], eFl[3])
#
# Keep only rows where both parameters and both errors are finite.
QUflag = numpy.isfinite(Q) & numpy.isfinite(U) & numpy.isfinite(eQ) & numpy.isfinite(eU)
#
Qf = Q[QUflag]
Uf = U[QUflag]
eQf = eQ[QUflag]
eUf = eU[QUflag]
#if polslide.upper().find(STP.LAMBDA2) >= 0:
#    if options.verbose:
#        print "Lambda/2 correction applied..."
#    nQ = []
#    nU = []
#    neQ = []
#    neU = []
#    for el in range(len(Q)):
#        sto = numpy.matrix([1.,Q[el],U[el],0.0]).transpose()
#        nsto = MuellerHalfWavePlateMatrix(math.radians(rot2)).I*sto
#        nQ.append(nsto[1,0])
#        nU.append(nsto[2,0])
#        lQ = GenGaussSet(Q[el],eQ[el],1000)
#        lU = GenGaussSet(U[el],eU[el],1000)
#        lsQ = []
#        lsU = []
#        for i in range(1000):
#            sto = numpy.matrix([1.,lQ[i],lU[i],0.0]).transpose()
#            nsto = MuellerHalfWavePlateMatrix(math.radians(rot2)).I*sto
#            lsQ.append(nsto[1,0])
#            lsU.append(nsto[2,0])
#        neQ.append(ScoreatPercentile(lsQ)[3])
#        neU.append(ScoreatPercentile(lsU)[3])
#    Q = nQ
#    eQ = neQ
#    U = nU
#    eU = neU
#
# Assemble the output table; rows are restricted to QUflag (finite Q, U and
# errors). Columns are added in a fixed order.
tnew = atpy.Table()
first = sequence[0]
# Identification and position columns are copied from the first position.
for col in (STP.Id, STP.X, STP.Y):
    tnew.add_column(col, tphot[col+'_'+first][QUflag])
tnew.add_column(STP.OBJECT, object, dtype=numpy.dtype('|S25'))
for col, values in ((STP.Q, Qf), (STP.eQ, eQf), (STP.U, Uf), (STP.eU, eUf)):
    tnew.add_column(col, values)
# Instrument setup and epoch, taken from the input header keywords.
for col, value in ((STP.POLSLIDE, polslide), (STP.ROTLAM4, rot4),
                   (STP.ROTLAM2, rot2), (STP.MJD, mjd)):
    tnew.add_column(col, value)
# Total brightness: magnitudes for photometry, fluxes for spectroscopy.
if nospec:
    totcols = (STP.TotMag, STP.eTotMag)
elif nophot:
    totcols = (STP.TotFlux, STP.eTotFlux)
else:
    totcols = ()
for col in totcols:
    tnew.add_column(col, tphot[col][QUflag])
#
warnings.resetwarnings()
warnings.filterwarnings('ignore', category=DeprecationWarning, append=True)
site = GetTNGSite()
nb = GetObj(ra,dec)
site.date = ephem.Date(date+' '+time)
warnings.resetwarnings()
warnings.filterwarnings('always', category=DeprecationWarning, append=True)
#
hourangle = AverHourAngle(nb,site,expt)
tnew.add_column(STP.HOURANG,numpy.array(len(Qf)*[hourangle]))
if options.verbose:
    print "Observation hour angle: %.1f" % hourangle
#
parangle = AverParallacticAngle(nb,site,expt)
tnew.add_column(STP.PARANG,numpy.array(len(Qf)*[parangle]))
if options.verbose:
    print "Observation parallactic Angle: %.1f" % parangle
#
if nospec:
    if options.wave:
        wave = options.wave
    else:
        try:
            wave = ST.LRSFiltCentrWaveDict[filter]
        except KeyError:
            wave = 0.55
    tnew.add_column(STP.WAVE,numpy.array(len(Qf)*[wave]))
elif nophot:
    if options.wave:
        wave = options.wave
    else:
        try:
            wave = tphot[STP.WAVE]*1e-4
        except KeyError:
            wave = 0.55
    tnew.add_column(STP.WAVE,wave[QUflag])
if options.verbose:
    if nospec:
        print "Observation wavelength: %.3f" % wave
    elif nophot:
        print "Observayion wavelength: spectral range"
#
tnew.add_column(STP.POSANG,posang)
if options.verbose:
    print "Derotator offset: %.1f" % posang
#
if options.verbose:
    print "Pupil stop: %s" % pstop
    print "Plate     : %s" % polslide
    print "Lambda/4  : %.1f" % rot4
    print "Lambda/2  : %.1f" % rot2
#

if options.calibqu:
    tnew.add_column(STP.CalQ,options.calibqu[0])
    tnew.add_column(STP.eCalQ,options.calibqu[1])
    tnew.add_column(STP.CalU,options.calibqu[2])
    tnew.add_column(STP.eCalU,options.calibqu[3])
    tnew.add_column(STP.CalV,options.calibqu[4])
    tnew.add_column(STP.eCalV,options.calibqu[5])
    if options.verbose:
        print "Calibrated Stokes parameters added."
#
if options.append and os.path.exists(options.outfile):
    tnew.write(STP.tempfile,type='fits',overwrite=True)
    tnew2 = atpy.Table(STP.tempfile,type='fits')
    os.remove(STP.tempfile)
    #
    try:
        tapp = atpy.Table(options.outfile, type='fits')
    except IOError:
        parser.error("Invalid FITS file to append.")
    try:
        tapp.append(tnew2)
    except ValueError:
        parser.error("Tables to be appended are not compatible.")
    tapp.write(options.outfile,type='fits',overwrite=True)
else:
    tnew.write(options.outfile,type='fits',overwrite=True)
#
if options.verbose:
    print "%d (new) entries saved in file %s" % (len(tnew), options.outfile)
else:
    print "%d %s" % (len(tnew), options.outfile)
#
