""" Code to average FITS frames

Context : SRP
Module  : SRPAdvAverage.py
Version : 1.3.1
Author  : Stefano Covino
Date    : 27/03/2013
E-mail  : stefano.covino@brera.inaf.it
URL:    : http://www.merate.mi.astro.it/~covino
Purpose : Manage the average of frame FITS files.

Usage   : SRPAdvAverage [-v] [-h] [-e] -i arg1 -o arg2 [-s arg3 arg4] [-x arg5]
            -e Weight for exposure time
            -i Input FITS file list
            -s Sigma-clipping levels (left right)
            -x Input FITS exposure map file list
            -o Output FITS file

            The exposure maps, if available, are used to compensate for less-exposed areas.

History : (13/11/2008) First version.
        : (16/11/2008) Management of different exposure times.
        : (24/04/2009) Management of lacking keywords.
        : (11/09/2009) Minor correction.
        : (20/08/2010) New sigma clipping average and exposure maps.
        : (14/10/2010) Better import style.
        : (07/08/2011) Better cosmetics.
        : (12/10/2011) Bug correction in Fits header reading.
        : (27/08/2012) Faster sigma-clipping.
        : (27/03/2013) Minor correction in case of exposure maps.
"""



import os, os.path, string, types
from optparse import OptionParser
import SRP.SRPConstants as SRPConstants
import SRP.SRPFiles as SRPFiles
import SRP.SRPUtil as SRPUtil
import SRP.SRPAstro as SRPAstro
import numpy
import SRP.stats as stats
import pyfits
from SRP.SRPStatistics.AverSigmaClippFrameFast import AverSigmaClippFrameFast
from SRP.SRPStatistics.WeightedMeanFrame import WeightedMeanFrame


parser = OptionParser(usage="usage: %prog [-v] [-h] [-e] -i arg1 -o arg2 [-s arg3 arg4] [-x arg5]", version="%prog 1.3.1")
parser.add_option("-e", "--expweight", action="store_false", dest="expweight", help="Weight for exposure time")
parser.add_option("-i", "--inputlist", action="store", nargs=1, type="string", dest="fitsfilelist", help="Input FITS file list")
parser.add_option("-s", "--sigmaclip", action="store", nargs=2, type="float", dest="sigmaclip", help="Sigma clipping levels (left right)")
parser.add_option("-v", "--verbose", action="store_true", dest="verbose", help="Fully describe operations")
parser.add_option("-x", "--expmaplist", action="store", nargs=1, type="string", dest="expmaplist", help="Input FITS exposure map file list")
parser.add_option("-o", "--outfile", action="store", nargs=1, type="string", dest="outfitsfile", help="Output FITS file")
(options, args) = parser.parse_args()


def _read_file_list(listname, notfound_msg, selected_msg, verbose):
    """Return the file names listed in the SRP list file `listname`.

    The list is opened through SRPFiles (in SRPConstants.SRPLocalDir) and read
    line by line until SRPReadFile returns '' (end of file). The first
    whitespace-separated token of each non-blank line is taken as a file name.
    The list file is always closed, even when the script aborts.

    notfound_msg : %-format message passed to parser.error (which exits the
                   script) when a listed file does not exist.
    selected_msg : %-format message printed for each accepted name when
                   `verbose` is true.
    """
    names = []
    listfile = SRPFiles.SRPFile(SRPConstants.SRPLocalDir, listname, SRPFiles.ReadMode)
    listfile.SRPOpenFile()
    try:
        while True:
            line = listfile.SRPReadFile()
            if line == '':
                break
            tokens = line.split()
            if not tokens:
                # Blank line (e.g. a trailing empty line): nothing to select.
                continue
            name = tokens[0]
            if not os.path.isfile(name):
                parser.error(notfound_msg % name)
            names.append(name)
            if verbose:
                print(selected_msg % name)
    finally:
        listfile.SRPCloseFile()
    return names


def _load_primary(fitsname):
    """Return ``(data, header)`` of the primary HDU of the FITS file `fitsname`.

    The data are copied into an in-memory array so the file can be closed
    right away, instead of keeping one open handle per input frame.
    """
    hdulist = pyfits.open(fitsname)
    try:
        primary = hdulist[0]
        return numpy.array(primary.data), primary.header
    finally:
        hdulist.close()


# Main script: average the frames listed in -i (optionally sigma-clipped,
# exposure-time weighted and normalized by the exposure maps listed in -x)
# and write the result, prefixed by the SRP session name, to -o.
if options.fitsfilelist and options.outfitsfile:
    sname = SRPFiles.getSRPSessionName()
    if options.verbose:
        print("Session name %s retrieved." % sname)
    # parser.error() exits the script, so everything below runs only when
    # the preceding checks pass.
    if not os.path.isfile(options.fitsfilelist):
        parser.error("Input FITS file list %s not found" % options.fitsfilelist)
    if options.verbose:
        print("Input FITS file list is: %s." % options.fitsfilelist)
    if options.expmaplist:
        # Previously a missing list was ignored and the script crashed later.
        if not os.path.isfile(options.expmaplist):
            parser.error("Exposure map file list %s not found" % options.expmaplist)
        if options.verbose:
            print("Exposure map file list is: %s." % options.expmaplist)
    #
    flist = _read_file_list(options.fitsfilelist, "Input FITS file %s not found",
                            "FITS file selected: %s", options.verbose)
    if not flist:
        parser.error("Input FITS file list %s is empty" % options.fitsfilelist)
    if options.expmaplist:
        xflist = _read_file_list(options.expmaplist, "Input exposure map file %s not found",
                                 "Exposure map file selected: %s", options.verbose)
        # Reading the lists independently detects a mismatch in either direction.
        if len(xflist) != len(flist):
            parser.error("FITS file list and exposure maps do not correspond.")
    #
    if options.sigmaclip:
        if options.sigmaclip[0] <= 0.0 or options.sigmaclip[1] <= 0.0:
            parser.error("Sigma clipping parameters must be positive.")
    #
    if options.verbose:
        print("Computing average...")
    tdata = []
    etime = []
    obstm = []
    refhead = None   # header of the first frame, reused for the outputs
    for fitsname in flist:
        data, header = _load_primary(fitsname)
        tdata.append(data)
        if refhead is None:
            refhead = header
        # Missing keywords fall back to neutral values.
        try:
            etime.append(header['EXPTIME'])
        except KeyError:
            etime.append(1.0)
        try:
            mjd = header['MJD-OBS']
        except KeyError:
            mjd = 0.0
        # Only float values can be averaged; anything else is treated as 0.0.
        if not isinstance(mjd, float):
            mjd = 0.0
        obstm.append(mjd)
    #
    # Frames are cropped to the smallest common size before stacking.
    shapey = min(d.shape[0] for d in tdata)
    shapex = min(d.shape[1] for d in tdata)
    tard = numpy.array([d[:shapey, :shapex] for d in tdata])
    del tdata   # release the uncropped copies
    if options.expmaplist:
        xtard = numpy.array([_load_primary(x)[0][:shapey, :shapex] for x in xflist])
    #
    tottime = stats.sum(etime)
    timearray = None
    if options.expweight:
        if tottime == 0:
            parser.error("Total exposure time is zero: cannot weight for exposure time.")
        # Per-frame weight etime[i]/tottime, broadcast to a full
        # (nframes, shapey, shapex) array as the averaging routines expect.
        weights = numpy.asarray(etime, dtype=float) / tottime
        timearray = numpy.ones((len(flist), shapey, shapex)) * weights[:, numpy.newaxis, numpy.newaxis]
    #
    if options.sigmaclip:
        downsig, upsig = options.sigmaclip
        if options.expweight:
            res = AverSigmaClippFrameFast(tard, timearray, downsig=downsig, upsig=upsig)
        else:
            res = AverSigmaClippFrameFast(tard, downsig=downsig, upsig=upsig)
        newdata = res[0]
        if options.expmaplist:
            # The 4th element of the clipping result is used as weights for
            # the exposure maps. Presumably it marks the pixels kept by the
            # clipping; confirm in AverSigmaClippFrameFast.
            xnewdata = WeightedMeanFrame(xtard, res[3])[0]
    else:
        if options.expweight:
            newdata = WeightedMeanFrame(tard, timearray)[0]
        else:
            newdata = WeightedMeanFrame(tard)[0]
        if options.expmaplist:
            if options.expweight:
                xnewdata = WeightedMeanFrame(xtard, timearray)[0]
            else:
                xnewdata = WeightedMeanFrame(xtard)[0]
    #
    if options.expmaplist:
        newdata = numpy.divide(newdata, xnewdata)
    else:
        newdata = numpy.multiply(newdata, 1.0)
    #
    # Output names are computed unconditionally. frot/frxt used to be bound
    # only in verbose mode, which made "-x" without "-v" fail with NameError.
    outname = sname + options.outfitsfile
    if options.expmaplist:
        frot, frxt = os.path.splitext(options.outfitsfile)
        xoutname = sname + frot + SRPConstants.SRPExpMap + frxt
    if options.verbose:
        print("Saving average file: %s" % outname)
        if options.expmaplist:
            print("Saving average exposure map: %s" % xoutname)
    #
    nftlist = pyfits.HDUList([pyfits.PrimaryHDU(newdata, refhead)])
    outhead = nftlist[0].header
    outhead.update('hierarch ' + SRPConstants.SRPCategory, SRPConstants.SRPSCIENCE, SRPConstants.SRPCatComm)
    outhead.update('hierarch ' + SRPConstants.SRPNFiles, len(flist), SRPConstants.SRPNFilesComm)
    if options.sigmaclip:
        method = SRPConstants.SRPAVERAGESC % (options.sigmaclip[0], options.sigmaclip[1])
    else:
        method = SRPConstants.SRPAVERAGE
    outhead.update('hierarch ' + SRPConstants.SRPMethod, method, SRPConstants.SRPMethodComm)
    outhead.update('EXPTIME', stats.mean(etime))
    if options.verbose:
        print("Total observing time: %.1fs" % tottime)
    # MJD-OBS of the result is the exposure-weighted mean epoch. With no
    # exposure time at all, fall back to the plain mean. Summing the raw MJDs
    # (the previous behaviour) produced a meaningless value.
    if tottime != 0.0:
        mjdobs = stats.sum([t * e / tottime for t, e in zip(obstm, etime)])
    else:
        mjdobs = stats.mean(obstm)
    outhead.update('MJD-OBS', mjdobs)
    nftlist.writeto(outname, clobber=True)
    if options.expmaplist:
        xnftlist = pyfits.HDUList([pyfits.PrimaryHDU(xnewdata, refhead)])
        xnftlist.writeto(xoutname, clobber=True)
else:
    parser.print_help()
