""" Code to average FITS frames

Context : SRP
Module  : SRPAdvAverage.py
Version : 1.3.2
Author  : Stefano Covino
Date    : 25/03/2014
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/~covino
Purpose : Manage the average of frame FITS files.

Usage   : SRPAdvAverage [-v] [-h] [-e] -i arg1 -o arg2 [-s arg3 arg4] [-x arg5]
            -e Weight for exposure time
            -i Input FITS file list
            -s Sigma-clipping levels (left right)
            -x Input FITS exposure map file list
            -o Output FITS file

            Exposure maps, if available, make it possible to compensate for less-exposed areas.

History : (13/11/2008) First version.
        : (16/11/2008) Management of different exposure times.
        : (24/04/2009) Management of lacking keywords.
        : (11/09/2009) Minor correction.
        : (20/08/2010) New sigma clipping average and exposure maps.
        : (14/10/2010) Better import style.
        : (07/08/2011) Better cosmetics.
        : (12/10/2011) Bug correction in Fits header reading.
        : (27/08/2012) Faster sigma-clipping.
        : (27/03/2013) Minor correction in case of exposure maps.
        : (25/03/2014) Deal with non standard FITS headers.
"""



import os, os.path, string, types
from optparse import OptionParser
import SRP.SRPConstants as SRPConstants
import SRP.SRPFiles as SRPFiles
import SRP.SRPUtil as SRPUtil
#import SRP.SRPAstro as SRPAstro
import numpy
import SRP.stats as stats
import pyfits
from SRP.SRPStatistics.AverSigmaClippFrameFast import AverSigmaClippFrameFast
from SRP.SRPStatistics.WeightedMeanFrame import WeightedMeanFrame


# Command-line interface of SRPAdvAverage.
# -i and -o are required in practice (the script prints the help text when
# either is missing); -e, -s, -v and -x are optional.
parser = OptionParser(usage="usage: %prog [-v] [-h] [-e] -i arg1 -o arg2 [-s arg3 arg4] [-x arg5]", version="%prog 1.3.2")
# "-e" has to switch exposure-time weighting ON.  It used to be declared with
# action="store_false" and no default, which left options.expweight at None
# (flag absent) or False (flag present), so the weighting could never be
# enabled.  store_true with an explicit False default gives the documented
# behaviour.
parser.add_option("-e", "--expweight", action="store_true", dest="expweight", default=False, help="Weight for exposure time")
parser.add_option("-i", "--inputlist", action="store", nargs=1, type="string", dest="fitsfilelist", help="Input FITS file list")
# Two floats: lower ("left") and upper ("right") clipping levels, in sigma.
parser.add_option("-s", "--sigmaclip", action="store", nargs=2, type="float", dest="sigmaclip", help="Sigma clipping levels (left right)")
parser.add_option("-v", "--verbose", action="store_true", dest="verbose", help="Fully describe operations")
parser.add_option("-x", "--expmaplist", action="store", nargs=1, type="string", dest="expmaplist", help="Input FITS exposure map file list")
parser.add_option("-o", "--outfile", action="store", nargs=1, type="string", dest="outfitsfile", help="Output FITS file")
def _read_name_list(listname):
    """Return the file names stored in the list file *listname*.

    The first whitespace-separated token of every non-blank line is taken as
    a file name; anything after it on the line is ignored.  Reading stops
    when SRPReadFile() returns an empty string, which this script treats as
    end-of-file.
    """
    lfile = SRPFiles.SRPFile(SRPConstants.SRPLocalDir, listname, SRPFiles.ReadMode)
    lfile.SRPOpenFile()
    names = []
    while True:
        line = lfile.SRPReadFile()
        if line == '':
            break
        fields = line.split()
        # A blank line carries no file name: skip it rather than fail on [0].
        if fields:
            names.append(fields[0])
    lfile.SRPCloseFile()
    return names


(options, args) = parser.parse_args()


# Main script: average the frames listed in the input list (optionally with
# sigma clipping, exposure-time weighting and exposure-map normalisation) and
# write the result, prefixed by the SRP session name, to the output file.
if options.fitsfilelist and options.outfitsfile:
    sname = SRPFiles.getSRPSessionName()
    if options.verbose:
        print("Session name %s retrieved." % sname)
    # parser.error() prints the message and terminates the program, so the
    # checks below act as guard clauses.
    if not os.path.isfile(options.fitsfilelist):
        parser.error("Input FITS file list %s not found" % options.fitsfilelist)
    # A missing exposure-map list used to go unnoticed here and surface later
    # as a NameError; report it as a usage error instead.
    if options.expmaplist and not os.path.isfile(options.expmaplist):
        parser.error("Input exposure map file list %s not found" % options.expmaplist)
    #
    flist = _read_name_list(options.fitsfilelist)
    if options.verbose:
        print("Input FITS file list is: %s." % options.fitsfilelist)
    xflist = []
    if options.expmaplist:
        xflist = _read_name_list(options.expmaplist)
        if options.verbose:
            print("Exposure map file list is: %s." % options.expmaplist)
    #
    for name in flist:
        if not os.path.isfile(name):
            parser.error("Input FITS file %s not found" % name)
        if options.verbose:
            print("FITS file selected: %s" % name)
    for name in xflist:
        if not os.path.isfile(name):
            parser.error("Input exposure map file %s not found" % name)
        if options.verbose:
            print("Exposure map file selected: %s" % name)
    #
    # Nothing to average: stop here instead of failing on thead[0] later.
    if not flist:
        parser.error("Input FITS file list %s is empty" % options.fitsfilelist)
    # Both lists are read completely, so a longer exposure-map list is now
    # detected as well (it used to be silently truncated).
    if options.expmaplist and len(flist) != len(xflist):
        parser.error("FITS file list and exposure maps do not correspond.")
    #
    if options.sigmaclip:
        if options.sigmaclip[0] <= 0.0 or options.sigmaclip[1] <= 0.0:
            parser.error("Sigma clipping parameters must be positive.")
    #
    if options.verbose:
        print("Computing average...")
    tdata = []
    thead = []
    etime = []
    obstm = []
    for name in flist:
        hdr = pyfits.open(name)
        tdata.append(hdr[0].data)
        thead.append(hdr[0].header)
        # Frames lacking EXPTIME are given unit exposure time.
        try:
            etime.append(hdr[0].header['EXPTIME'])
        except KeyError:
            etime.append(1.0)
        # MJD-OBS falls back to 0.0 when missing or when a non-standard
        # header stores it as something other than a float.
        try:
            mjd = hdr[0].header['MJD-OBS']
        except KeyError:
            mjd = 0.0
        if not isinstance(mjd, float):
            mjd = 0.0
        obstm.append(mjd)
    #
    # Common area: the smallest extent along each axis among the science frames.
    shapey = min(frame.shape[0] for frame in tdata)
    shapex = min(frame.shape[1] for frame in tdata)
    #
    xtdata = []
    for name in xflist:
        xhdr = pyfits.open(name)
        xtdata.append(xhdr[0].data)
    #
    # Stack the frames, cropped to the common area, along a new first axis.
    tard = numpy.array([frame[:shapey, :shapex] for frame in tdata])
    if options.expmaplist:
        xtard = numpy.array([frame[:shapey, :shapex] for frame in xtdata])
    #
    tottime = stats.sum(etime)
    # Per-frame weights: one constant plane per frame holding that frame's
    # fraction of the total exposure time.  Built only when requested.
    timearray = None
    if options.expweight:
        if tottime == 0.0:
            parser.error("Total exposure time is zero: cannot weight for exposure time.")
        timearray = numpy.array([numpy.ones((shapey, shapex))*exptime/tottime for exptime in etime])
    #
    if options.sigmaclip:
        if options.expweight:
            res = AverSigmaClippFrameFast(tard, timearray, downsig=options.sigmaclip[0], upsig=options.sigmaclip[1])
        else:
            res = AverSigmaClippFrameFast(tard, downsig=options.sigmaclip[0], upsig=options.sigmaclip[1])
        newdata = res[0]
        # NOTE(review): res[3] is handed to WeightedMeanFrame as the weights
        # for the exposure maps; presumably it flags the pixels that survived
        # the clipping -- confirm against AverSigmaClippFrameFast.
        ncond = res[3]
        if options.expmaplist:
            xnewdata = WeightedMeanFrame(xtard, ncond)[0]
    else:
        if options.expweight:
            newdata = WeightedMeanFrame(tard, timearray)[0]
        else:
            newdata = WeightedMeanFrame(tard)[0]
        #
        if options.expmaplist:
            if options.expweight:
                xnewdata = WeightedMeanFrame(xtard, timearray)[0]
            else:
                xnewdata = WeightedMeanFrame(xtard)[0]
    #
    if options.expmaplist:
        # Normalise by the averaged exposure map; pixels where the map is
        # zero are not treated specially here.
        newdata = numpy.divide(newdata, xnewdata)
    else:
        newdata = numpy.multiply(newdata, 1.0)
    #
    outname = sname+options.outfitsfile
    if options.verbose:
        print("Saving average file: %s" % outname)
    #
    if options.expmaplist:
        # The exposure-map file name is needed whenever -x is given; it used
        # to be computed only together with -v, so "-x" without "-v" ended in
        # a NameError when the map was written.
        frot, frxt = os.path.splitext(options.outfitsfile)
        xoutname = sname+frot+SRPConstants.SRPExpMap+frxt
        if options.verbose:
            print("Saving average exposure map: %s" % xoutname)
    #
    # The output inherits the header of the first input frame.
    nfts = pyfits.PrimaryHDU(newdata, thead[0])
    nftlist = pyfits.HDUList([nfts])
    nftlist[0].header.update('hierarch '+SRPConstants.SRPCategory, SRPConstants.SRPSCIENCE, SRPConstants.SRPCatComm)
    nftlist[0].header.update('hierarch '+SRPConstants.SRPNFiles, len(tdata), SRPConstants.SRPNFilesComm)
    if options.sigmaclip:
        nftlist[0].header.update('hierarch '+SRPConstants.SRPMethod, SRPConstants.SRPAVERAGESC % (options.sigmaclip[0], options.sigmaclip[1]), SRPConstants.SRPMethodComm)
    else:
        nftlist[0].header.update('hierarch '+SRPConstants.SRPMethod, SRPConstants.SRPAVERAGE, SRPConstants.SRPMethodComm)
    nftlist[0].header.update('EXPTIME', stats.mean(etime))
    if options.verbose:
        print("Total observing time: %.1fs" % tottime)
    # Exposure-time-weighted mean epoch; with zero total time the MJD values
    # are summed unweighted, as before.
    if tottime != 0.0:
        obstm = [mjd*exptime/tottime for mjd, exptime in zip(obstm, etime)]
    nftlist[0].header.update('MJD-OBS', stats.sum(obstm))
    # Report FITS-standard violations only in verbose mode.
    if options.verbose:
        verify = 'warn'
    else:
        verify = 'ignore'
    nftlist.writeto(outname, clobber=True, output_verify=verify)
    if options.expmaplist:
        xnfts = pyfits.PrimaryHDU(xnewdata, thead[0])
        xnftlist = pyfits.HDUList([xnfts])
        xnftlist.writeto(xoutname, clobber=True, output_verify=verify)
else:
    parser.print_help()
